From 158d73ebdd61eef33831ae5f6990acf07244fc55 Mon Sep 17 00:00:00 2001
From: Sean Quah
Date: Tue, 7 Dec 2021 16:38:29 +0000
Subject: [PATCH 001/157] Revert accidental fast-forward merge from v1.49.0rc1

Revert "Sort internal changes in changelog"
Revert "Update CHANGES.md"
Revert "1.49.0rc1"
Revert "Revert "Move `glob_to_regex` and `re_word_boundary` to `matrix-python-common` (#11505) (#11527)"
Revert "Refactors in `_generate_sync_entry_for_rooms` (#11515)"
Revert "Correctly register shutdown handler for presence workers (#11518)"
Revert "Fix `ModuleApi.looping_background_call` for non-async functions (#11524)"
Revert "Fix 'delete room' admin api to work on incomplete rooms (#11523)"
Revert "Correctly ignore invites from ignored users (#11511)"
Revert "Fix the test breakage introduced by #11435 as a result of concurrent PRs (#11522)"
Revert "Stabilise support for MSC2918 refresh tokens as they have now been merged into the Matrix specification. (#11435)"
Revert "Save the OIDC session ID (sid) with the device on login (#11482)"
Revert "Add admin API to get some information about federation status (#11407)"
Revert "Include bundled aggregations in /sync and related fixes (#11478)"
Revert "Move `glob_to_regex` and `re_word_boundary` to `matrix-python-common` (#11505)"
Revert "Update backward extremity docs to make it clear that it does not indicate whether we have fetched an events' `prev_events` (#11469)"
Revert "Support configuring the lifetime of non-refreshable access tokens separately to refreshable access tokens. (#11445)"
Revert "Add type hints to `synapse/tests/rest/admin` (#11501)"
Revert "Revert accidental commits to develop."
Revert "Newsfile"
Revert "Give `tests.server.setup_test_homeserver` (nominally!) the same behaviour"
Revert "Move `tests.utils.setup_test_homeserver` to `tests.server`"
Revert "Convert one of the `setup_test_homeserver`s to `make_test_homeserver_synchronous`"
Revert "Disambiguate queries on `state_key` (#11497)"
Revert "Comments on the /sync tentacles (#11494)"
Revert "Clean up tests.storage.test_appservice (#11492)"
Revert "Clean up `tests.storage.test_main` to remove use of legacy code. (#11493)"
Revert "Clean up `tests.test_visibility` to remove legacy code. (#11495)"
Revert "Minor cleanup on recently ported doc pages (#11466)"
Revert "Add most of the missing type hints to `synapse.federation`. (#11483)"
Revert "Avoid waiting for zombie processes in `synctl stop` (#11490)"
Revert "Fix media repository failing when media store path contains symlinks (#11446)"
Revert "Add type annotations to `tests.storage.test_appservice`. (#11488)"
Revert "`scripts-dev/sign_json`: support for signing events (#11486)"
Revert "Add MSC3030 experimental client and federation API endpoints to get the closest event to a given timestamp (#9445)"
Revert "Port wiki pages to documentation website (#11402)"
Revert "Add a license header and comment. (#11479)"
Revert "Clean-up get_version_string (#11468)"
Revert "Link background update controller docs to summary (#11475)"
Revert "Additional type hints for config module. (#11465)"
Revert "Register the login redirect endpoint for v3. (#11451)"
Revert "Update openid.md"
Revert "Remove mention of OIDC certification from Dex (#11470)"
Revert "Add a note about huge pages to our Postgres doc (#11467)"
Revert "Don't start Synapse master process if `worker_app` is set (#11416)"
Revert "Expose worker & homeserver as entrypoints in `setup.py` (#11449)"
Revert "Bundle relations of relations into the `/relations` result. (#11284)"
Revert "Fix `LruCache` corruption bug with a `size_callback` that can return 0 (#11454)"
Revert "Eliminate a few `Any`s in `LruCache` type hints (#11453)"
Revert "Remove unnecessary `json.dumps` from `tests.rest.admin` (#11461)"
Revert "Merge branch 'master' into develop"

This reverts commit 26b5d2320f62b5eb6262c7614fbdfc364a4dfc02.
This reverts commit bce4220f387bf5448387f0ed7d14ed1e41e40747.
This reverts commit 966b5d0fa0893c3b628c942dfc232e285417f46d.
This reverts commit 088d748f2cb51f03f3bcacc0fb3af1e0f9607737.
This reverts commit 14d593f72d10b4d8cb67e3288bb3131ee30ccf59.
This reverts commit 2a3ec6facf79f6aae011d9fb6f9ed5e43c7b6bec.
This reverts commit eccc49d7554d1fab001e1fefb0fda8ffb254b630.
This reverts commit b1ecd19c5d19815b69e425d80f442bf2877cab76.
This reverts commit 9c55dedc8c4484e6269451a8c3c10b3e314aeb4a.
This reverts commit 2d42e586a8c54be1a83643148358b1651c1ca666.
This reverts commit 2f053f3f82ca174cc1c858c75afffae51af8ce0d.
This reverts commit a15a893df8428395df7cb95b729431575001c38a.
This reverts commit 8b4b153c9e86c04c7db8c74fde4b6a04becbc461.
This reverts commit 494ebd7347ba52d702802fba4c3bb13e7bfbc2cf.
This reverts commit a77c36989785c0d5565ab9a1169f4f88e512ce8a.
This reverts commit 4eb77965cd016181d2111f37d93526e9bb0434f0.
This reverts commit 637df95de63196033a6da4a6e286e1d58ea517b6.
This reverts commit e5f426cd54609e7f05f8241d845e6e36c5f10d9a.
This reverts commit 8cd68b8102eeab1b525712097c1b2e9679c11896.
This reverts commit 6cae125e20865c52d770b24278bb7ab8fde5bc0d.
This reverts commit 7be88fbf48156b36b6daefb228e1258e7d48cae4.
This reverts commit b3fd99b74a3f6f42a9afd1b19ee4c60e38e8e91a.
This reverts commit f7ec6e7d9e0dc360d9fb41f3a1afd7bdba1475c7.
This reverts commit 5640992d176a499204a0756b1677c9b1575b0a49.
This reverts commit d26808dd854006bd26a2366c675428ce0737238c.
This reverts commit f91624a5950e14ba9007eed9bfa1c828676d4745.
This reverts commit 16d39a5490ce74c901c7a8dbb990c6e83c379207.
This reverts commit 8a4c2969874c0b7d72003f2523883eba8a348e83.
This reverts commit 49e1356ee3d5d72929c91f778b3a231726c1413c.
This reverts commit d2279f471ba8f44d9f578e62b286897a338d8aa1.
This reverts commit b50e39df578adc3f86c5efa16bee9035cfdab61b.
This reverts commit 858d80bf0f9f656a03992794874081b806e49222.
This reverts commit 435f04480728c5d982e1a63c1b2777784bf9cd26.
This reverts commit f61462e1be36a51dbf571076afa8e1930cb182f4.
This reverts commit a6f1a3abecf8e8fd3e1bff439a06b853df18f194.
This reverts commit 84dc50e160a2ec6590813374b5a1e58b97f7a18d.
This reverts commit ed635d32853ee0a3e5ec1078679b27e7844a4ac7.
This reverts commit 7b62791e001d6a4f8897ed48b3232d7f8fe6aa48.
This reverts commit 153194c7717d8016b0eb974c81b1baee7dc1917d.
This reverts commit f44d729d4ccae61bc0cdd5774acb3233eb5f7c13.
This reverts commit a265fbd397ae72b2d3ea4c9310591ff1d0f3e05c.
This reverts commit b9fef1a7cdfcc128fa589a32160e6aa7ed8964d7.
This reverts commit b0eb64ff7bf6bde42046e091f8bdea9b7aab5f04.
This reverts commit f1795463bf503a6fca909d77f598f641f9349f56.
This reverts commit 70cbb1a5e311f609b624e3fae1a1712db639c51e.
This reverts commit 42bf0204635213e2c75188b19ee66dc7e7d8a35e.
This reverts commit 379f2650cf875f50c59524147ec0e33cfd5ef60c.
This reverts commit 7ff22d6da41cd5ca80db95c18b409aea38e49fcd.
This reverts commit 5a0b652d36ae4b6d423498c1f2c82c97a49c6f75.
This reverts commit 432a174bc192740ac7a0a755009f6099b8363ad9.
This reverts commit b14f8a1baf6f500997ae4c1d6a6d72094ce14270, reversing
changes made to e713855dca17a7605bae99ea8d71bc7f8657e4b8.
---
 .github/workflows/tests.yml                   |   2 +-
 CHANGES.md                                    |  93 ---
 debian/changelog                              |   6 -
 docker/Dockerfile-workers                     |   3 -
 docker/conf-workers/healthcheck.sh.j2         |   6 -
 docker/configure_workers_and_start.py         |  13 -
 docs/SUMMARY.md                               |   8 -
 docs/development/room-dag-concepts.md         |  16 +-
 docs/media_repository.md                      |  89 +--
 .../background_update_controller_callbacks.md |  71 ---
 docs/modules/writing_a_module.md              |  12 +-
 docs/openid.md                                |   4 +-
 ...nning_synapse_on_single_board_computers.md |  74 ---
 docs/postgres.md                              |   3 -
 docs/sample_config.yaml                       |  38 --
 docs/templates.md                             |   5 -
 .../administration/admin_api/federation.md    | 114 ----
 docs/usage/administration/admin_faq.md        | 103 ----
 .../database_maintenance_tools.md             |  18 -
 docs/usage/administration/state_groups.md     |  25 -
 ...standing_synapse_through_grafana_graphs.md |  84 ---
 .../administration/useful_sql_for_admins.md   | 156 -----
 docs/workers.md                               |   4 +-
 mypy.ini                                      |  22 +-
 scripts-dev/complement.sh                     |   2 +-
 scripts-dev/federation_client.py              |  19 -
 scripts-dev/sign_json                         |  24 +-
 setup.py                                      |  10 +-
 synapse/__init__.py                           |   2 +-
 synapse/api/constants.py                      | 198 +++---
 synapse/app/_base.py                          |   3 +-
 synapse/app/generic_worker.py                 |   8 +-
 synapse/app/homeserver.py                     |   8 -
 synapse/appservice/__init__.py                |   3 +-
 synapse/config/__main__.py                    |   3 +-
 synapse/config/appservice.py                  |  23 +-
 synapse/config/cache.py                       |  26 +-
 synapse/config/cas.py                         |   5 +-
 synapse/config/database.py                    |  13 +-
 synapse/config/experimental.py                |   3 -
 synapse/config/logger.py                      |  24 +-
 synapse/config/oidc.py                        |  58 +-
 synapse/config/registration.py                | 116 +---
 synapse/config/repository.py                  |   9 +-
 synapse/config/saml2.py                       |  21 +-
 synapse/config/server.py                      |  20 +-
 synapse/config/sso.py                         |  12 +-
 synapse/config/workers.py                     |   4 +-
 synapse/crypto/keyring.py                     |  30 +-
 synapse/events/snapshot.py                    |   5 -
 synapse/events/utils.py                       |  64 +-
 synapse/federation/federation_client.py       | 112 +---
 synapse/federation/federation_server.py       |  58 +-
 synapse/federation/persistence.py             |   4 +-
 synapse/federation/send_queue.py              |  25 +-
 .../sender/per_destination_queue.py           |  13 +-
 synapse/federation/transport/client.py        |  91 +--
 .../federation/transport/server/__init__.py   |  14 +-
 synapse/federation/transport/server/_base.py  |  48 +-
 .../federation/transport/server/federation.py |  47 +-
 synapse/handlers/auth.py                      | 124 +---
 synapse/handlers/device.py                    |   8 -
 synapse/handlers/events.py                    |   5 +-
 synapse/handlers/federation.py                |  61 +-
 synapse/handlers/initial_sync.py              |  30 +-
 synapse/handlers/message.py                   |   8 +-
 synapse/handlers/oidc.py                      |  58 +-
 synapse/handlers/pagination.py                |   3 +
 synapse/handlers/presence.py                  |   2 +-
 synapse/handlers/register.py                  |  78 +--
 synapse/handlers/room.py                      | 165 +----
 synapse/handlers/room_summary.py              |  14 +-
 synapse/handlers/sso.py                       |   4 -
 synapse/handlers/sync.py                      | 285 +++------
 synapse/http/servlet.py                       |  29 -
 synapse/module_api/__init__.py                | 146 +----
 synapse/push/emailpusher.py                   |  10 +-
 synapse/push/httppusher.py                    |   3 +-
 synapse/push/mailer.py                        |  72 +--
 synapse/push/push_types.py                    | 136 ----
 synapse/python_dependencies.py                |   2 +-
 synapse/replication/http/login.py             |   8 -
 .../slave/storage/_slaved_id_tracker.py       |  22 +-
 .../replication/slave/storage/push_rule.py    |   4 +
 synapse/replication/tcp/streams/events.py     |   6 +-
 synapse/rest/admin/__init__.py                |  25 +-
 synapse/rest/admin/_base.py                   |   3 +-
 synapse/rest/admin/devices.py                 |  21 +-
 synapse/rest/admin/event_reports.py           |  21 +-
 synapse/rest/admin/federation.py              | 135 ----
 synapse/rest/admin/groups.py                  |   5 +-
 synapse/rest/admin/media.py                   |  53 +-
 synapse/rest/admin/registration_tokens.py     |  51 +-
 synapse/rest/admin/rooms.py                   |  84 ++-
 synapse/rest/admin/server_notice_servlet.py   |  11 +-
 synapse/rest/admin/statistics.py              |  21 +-
 synapse/rest/admin/users.py                   | 173 ++----
 synapse/rest/client/login.py                  |  88 +--
 synapse/rest/client/register.py               |  20 +-
 synapse/rest/client/relations.py              |  16 +-
 synapse/rest/client/room.py                   |  67 +-
 synapse/rest/client/sync.py                   |   6 +-
 synapse/rest/media/v1/filepath.py             | 115 ++--
 synapse/server.py                             |   5 -
 synapse/state/__init__.py                     |   2 +-
 synapse/state/v1.py                           |   3 +-
 synapse/storage/_base.py                      |   4 +-
 synapse/storage/background_updates.py         | 192 +-----
 synapse/storage/databases/main/appservice.py  |   6 +-
 synapse/storage/databases/main/devices.py     |  50 +-
 .../databases/main/event_federation.py        |   4 +-
 .../databases/main/event_push_actions.py      |  19 +-
 synapse/storage/databases/main/events.py      | 107 +---
 .../storage/databases/main/events_worker.py   | 581 ++++--------
 .../storage/databases/main/purge_events.py    |   2 +-
 synapse/storage/databases/main/push_rule.py   |  11 +-
 .../storage/databases/main/registration.py    |  28 +-
 synapse/storage/databases/main/roommember.py  |   4 +-
 synapse/storage/databases/main/stream.py      |  15 +-
 .../storage/databases/main/transactions.py    |  70 ---
 synapse/storage/persist_events.py             |   3 +-
 synapse/storage/schema/__init__.py            |   6 +-
 .../delta/65/10_expirable_refresh_tokens.sql  |  28 -
 .../65/11_devices_auth_provider_session.sql   |  27 -
 synapse/storage/util/id_generators.py         | 116 ++--
 synapse/util/caches/deferred_cache.py         |   9 +-
 synapse/util/caches/lrucache.py               |  42 +-
 synapse/util/linked_list.py                   |   4 +-
 synapse/util/versionstring.py                 |  82 ++-
 synctl                                        |  58 +-
 tests/app/test_homeserver_start.py            |  31 -
 tests/config/test_registration_config.py      |  78 ---
 tests/crypto/test_keyring.py                  |  71 ---
 tests/federation/transport/test_client.py     |  64 --
 tests/handlers/test_auth.py                   |   6 +-
 tests/handlers/test_cas.py                    |  40 +-
 tests/handlers/test_oidc.py                   | 135 +---
 tests/handlers/test_room_summary.py           |  94 +--
 tests/handlers/test_saml.py                   |  40 +-
 tests/push/test_email.py                      |   9 +-
 .../test_sharded_event_persister.py           |   6 +-
 tests/rest/admin/test_admin.py                |  74 ++-
 tests/rest/admin/test_background_updates.py   |  25 +-
 tests/rest/admin/test_device.py               | 156 ++---
 tests/rest/admin/test_event_reports.py        | 181 ++----
 tests/rest/admin/test_federation.py           | 456 --------------
 tests/rest/admin/test_media.py                | 260 +++-----
 tests/rest/admin/test_registration_tokens.py  | 288 +++------
 tests/rest/admin/test_room.py                 | 338 +++++-----
 tests/rest/admin/test_server_notice.py        |  72 +--
 tests/rest/admin/test_statistics.py           | 140 ++---
 tests/rest/admin/test_user.py                 | 465 +++++++-------
 tests/rest/admin/test_username_available.py   |  20 +-
 tests/rest/client/test_auth.py                | 265 +-------
 tests/rest/client/test_relations.py           | 239 +------
 tests/rest/media/v1/test_filepath.py          | 109 +---
 .../databases/main/test_events_worker.py      | 139 +----
 tests/storage/test_appservice.py              | 439 ++++------
 tests/storage/test_background_update.py       | 119 +---
 tests/storage/test_event_chain.py             |   4 +-
 tests/storage/test_main.py                    |  27 +-
 tests/storage/test_user_directory.py          |   5 +-
 tests/test_visibility.py                      | 241 ++++--
 tests/unittest.py                             |  10 +-
 tests/util/test_lrucache.py                   |  12 -
 165 files changed, 2708 insertions(+), 7720 deletions(-)
 delete mode 100644 docker/conf-workers/healthcheck.sh.j2
 delete mode 100644 docs/modules/background_update_controller_callbacks.md
 delete mode 100644 docs/other/running_synapse_on_single_board_computers.md
 delete mode 100644 docs/usage/administration/admin_api/federation.md
 delete mode 100644 docs/usage/administration/admin_faq.md
 delete mode 100644 docs/usage/administration/database_maintenance_tools.md
 delete mode 100644 docs/usage/administration/state_groups.md
 delete mode 100644 docs/usage/administration/understanding_synapse_through_grafana_graphs.md
 delete mode 100644 docs/usage/administration/useful_sql_for_admins.md
 delete mode 100644 synapse/push/push_types.py
 delete mode 100644 synapse/rest/admin/federation.py
 delete mode 100644 synapse/storage/schema/main/delta/65/10_expirable_refresh_tokens.sql
 delete mode 100644 synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql
 delete mode 100644 tests/app/test_homeserver_start.py
 delete mode 100644 tests/config/test_registration_config.py
 delete mode 100644 tests/federation/transport/test_client.py
 delete mode 100644 tests/rest/admin/test_federation.py

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 21c9ee7823c7..8d7e8cafd9e0 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -374,7 +374,7 @@ jobs:
         working-directory: complement/dockerfiles
 
       # Run Complement
-      - run: go test -v -tags synapse_blacklist,msc2403 ./tests/...
+      - run: go test -v -tags synapse_blacklist,msc2403,msc2946,msc3083 ./tests/...
         env:
           COMPLEMENT_BASE_IMAGE: complement-synapse:latest
         working-directory: complement
diff --git a/CHANGES.md b/CHANGES.md
index 72e8d64cf750..c283e33876fe 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,96 +1,3 @@
-Synapse 1.49.0rc1 (2021-12-07)
-==============================
-
-We've decided to move the existing, somewhat stagnant pages from the GitHub wiki
-to the [documentation website](https://matrix-org.github.io/synapse/latest/).
-
-This was done for two reasons. The first was to ensure that changes are checked by
-multiple authors before being committed (everyone makes mistakes!) and the second
-was visibility of the documentation. Not everyone knows that Synapse has some very
-useful information hidden away in its GitHub wiki pages. Bringing them to the
-documentation website should help with visibility, as well as keep all Synapse documentation
-in one, easily-searchable location.
-
-Note that contributions to the documentation website happen through [GitHub pull
-requests](https://github.com/matrix-org/synapse/pulls). Please visit [#synapse-dev:matrix.org](https://matrix.to/#/#synapse-dev:matrix.org)
-if you need help with the process!
-
-
-Features
---------
-
-- Add [MSC3030](https://github.com/matrix-org/matrix-doc/pull/3030) experimental client and federation API endpoints to get the closest event to a given timestamp. ([\#9445](https://github.com/matrix-org/synapse/issues/9445))
-- Include bundled relation aggregations during a limited `/sync` request and `/relations` request, per [MSC2675](https://github.com/matrix-org/matrix-doc/pull/2675). ([\#11284](https://github.com/matrix-org/synapse/issues/11284), [\#11478](https://github.com/matrix-org/synapse/issues/11478))
-- Add plugin support for controlling database background updates. ([\#11306](https://github.com/matrix-org/synapse/issues/11306), [\#11475](https://github.com/matrix-org/synapse/issues/11475), [\#11479](https://github.com/matrix-org/synapse/issues/11479))
-- Support the stable API endpoints for [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946): the room `/hierarchy` endpoint. ([\#11329](https://github.com/matrix-org/synapse/issues/11329))
-- Add admin API to get some information about federation status with remote servers.
-  ([\#11407](https://github.com/matrix-org/synapse/issues/11407))
-- Support expiry of refresh tokens and expiry of the overall session when refresh tokens are in use. ([\#11425](https://github.com/matrix-org/synapse/issues/11425))
-- Stabilise support for [MSC2918](https://github.com/matrix-org/matrix-doc/blob/main/proposals/2918-refreshtokens.md#msc2918-refresh-tokens) refresh tokens as they have now been merged into the Matrix specification. ([\#11435](https://github.com/matrix-org/synapse/issues/11435), [\#11522](https://github.com/matrix-org/synapse/issues/11522))
-- Update [MSC2918 refresh token](https://github.com/matrix-org/matrix-doc/blob/main/proposals/2918-refreshtokens.md#msc2918-refresh-tokens) support to confirm with the latest revision: accept the `refresh_tokens` parameter in the request body rather than in the URL parameters. ([\#11430](https://github.com/matrix-org/synapse/issues/11430))
-- Support configuring the lifetime of non-refreshable access tokens separately to refreshable access tokens. ([\#11445](https://github.com/matrix-org/synapse/issues/11445))
-- Expose `synapse_homeserver` and `synapse_worker` commands as entry points to run Synapse's main process and worker processes, respectively. Contributed by @Ma27. ([\#11449](https://github.com/matrix-org/synapse/issues/11449))
-- `synctl stop` will now wait for Synapse to exit before returning. ([\#11459](https://github.com/matrix-org/synapse/issues/11459), [\#11490](https://github.com/matrix-org/synapse/issues/11490))
-- Extend the "delete room" admin api to work correctly on rooms which have previously been partially deleted. ([\#11523](https://github.com/matrix-org/synapse/issues/11523))
-- Add support for the `/_matrix/client/v3/login/sso/redirect/{idpId}` API from Matrix v1.1. This endpoint was overlooked when support for v3 endpoints was added in Synapse 1.48.0rc1. ([\#11451](https://github.com/matrix-org/synapse/issues/11451))
-
-
-Bugfixes
---------
-
-- Fix using [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) batch sending in combination with event persistence workers. Contributed by @tulir at Beeper. ([\#11220](https://github.com/matrix-org/synapse/issues/11220))
-- Fix a long-standing bug where all requests that read events from the database could get stuck as a result of losing the database connection, properly this time. Also fix a race condition introduced in the previous insufficient fix in Synapse 1.47.0. ([\#11376](https://github.com/matrix-org/synapse/issues/11376))
-- The `/send_join` response now includes the stable `event` field instead of the unstable field from [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083). ([\#11413](https://github.com/matrix-org/synapse/issues/11413))
-- Fix a bug introduced in Synapse 1.47.0 where `send_join` could fail due to an outdated `ijson` version. ([\#11439](https://github.com/matrix-org/synapse/issues/11439), [\#11441](https://github.com/matrix-org/synapse/issues/11441), [\#11460](https://github.com/matrix-org/synapse/issues/11460))
-- Fix a bug introduced in Synapse 1.36.0 which could cause problems fetching event-signing keys from trusted key servers. ([\#11440](https://github.com/matrix-org/synapse/issues/11440))
-- Fix a bug introduced in Synapse 1.47.1 where the media repository would fail to work if the media store path contained any symbolic links.
-  ([\#11446](https://github.com/matrix-org/synapse/issues/11446))
-- Fix an `LruCache` corruption bug, introduced in Synapse 1.38.0, that would cause certain requests to fail until the next Synapse restart. ([\#11454](https://github.com/matrix-org/synapse/issues/11454))
-- Fix a long-standing bug where invites from ignored users were included in incremental syncs. ([\#11511](https://github.com/matrix-org/synapse/issues/11511))
-- Fix a regression in Synapse 1.48.0 where presence workers would not clear their presence updates over replication on shutdown. ([\#11518](https://github.com/matrix-org/synapse/issues/11518))
-- Fix a regression in Synapse 1.48.0 where the module API's `looping_background_call` method would spam errors to the logs when given a non-async function. ([\#11524](https://github.com/matrix-org/synapse/issues/11524))
-
-
-Updates to the Docker image
----------------------------
-
-- Update `Dockerfile-workers` to healthcheck all workers in the container. ([\#11429](https://github.com/matrix-org/synapse/issues/11429))
-
-
-Improved Documentation
-----------------------
-
-- Update the media repository documentation. ([\#11415](https://github.com/matrix-org/synapse/issues/11415))
-- Update section about backward extremities in the room DAG concepts doc to correct the misconception about backward extremities indicating whether we have fetched an events' `prev_events`. ([\#11469](https://github.com/matrix-org/synapse/issues/11469))
-
-
-Internal Changes
-----------------
-
-- Add `Final` annotation to string constants in `synapse.api.constants` so that they get typed as `Literal`s. ([\#11356](https://github.com/matrix-org/synapse/issues/11356))
-- Add a check to ensure that users cannot start the Synapse master process when `worker_app` is set. ([\#11416](https://github.com/matrix-org/synapse/issues/11416))
-- Add a note about postgres memory management and hugepages to postgres doc. ([\#11467](https://github.com/matrix-org/synapse/issues/11467))
-- Add missing type hints to `synapse.config` module. ([\#11465](https://github.com/matrix-org/synapse/issues/11465))
-- Add missing type hints to `synapse.federation`. ([\#11483](https://github.com/matrix-org/synapse/issues/11483))
-- Add type annotations to `tests.storage.test_appservice`. ([\#11488](https://github.com/matrix-org/synapse/issues/11488), [\#11492](https://github.com/matrix-org/synapse/issues/11492))
-- Add type annotations to some of the configuration surrounding refresh tokens. ([\#11428](https://github.com/matrix-org/synapse/issues/11428))
-- Add type hints to `synapse/tests/rest/admin`. ([\#11501](https://github.com/matrix-org/synapse/issues/11501))
-- Add type hints to storage classes. ([\#11411](https://github.com/matrix-org/synapse/issues/11411))
-- Add wiki pages to documentation website. ([\#11402](https://github.com/matrix-org/synapse/issues/11402))
-- Clean up `tests.storage.test_main` to remove use of legacy code. ([\#11493](https://github.com/matrix-org/synapse/issues/11493))
-- Clean up `tests.test_visibility` to remove legacy code. ([\#11495](https://github.com/matrix-org/synapse/issues/11495))
-- Convert status codes to `HTTPStatus` in `synapse.rest.admin`. ([\#11452](https://github.com/matrix-org/synapse/issues/11452), [\#11455](https://github.com/matrix-org/synapse/issues/11455))
-- Extend the `scripts-dev/sign_json` script to support signing events. ([\#11486](https://github.com/matrix-org/synapse/issues/11486))
-- Improve internal types in push code.
-  ([\#11409](https://github.com/matrix-org/synapse/issues/11409))
-- Improve type annotations in `synapse.module_api`. ([\#11029](https://github.com/matrix-org/synapse/issues/11029))
-- Improve type hints for `LruCache`. ([\#11453](https://github.com/matrix-org/synapse/issues/11453))
-- Preparation for database schema simplifications: disambiguate queries on `state_key`. ([\#11497](https://github.com/matrix-org/synapse/issues/11497))
-- Refactor `backfilled` into specific behavior function arguments (`_persist_events_and_state_updates` and downstream calls). ([\#11417](https://github.com/matrix-org/synapse/issues/11417))
-- Refactor `get_version_string` to fix-up types and duplicated code. ([\#11468](https://github.com/matrix-org/synapse/issues/11468))
-- Refactor various parts of the `/sync` handler. ([\#11494](https://github.com/matrix-org/synapse/issues/11494), [\#11515](https://github.com/matrix-org/synapse/issues/11515))
-- Remove unnecessary `json.dumps` from `tests.rest.admin`. ([\#11461](https://github.com/matrix-org/synapse/issues/11461))
-- Save the OpenID Connect session ID on login. ([\#11482](https://github.com/matrix-org/synapse/issues/11482))
-- Update and clean up recently ported documentation pages. ([\#11466](https://github.com/matrix-org/synapse/issues/11466))
-
-
 Synapse 1.48.0 (2021-11-30)
 ===========================
 
diff --git a/debian/changelog b/debian/changelog
index acc9f6049eca..7deab5936e58 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,9 +1,3 @@
-matrix-synapse-py3 (1.49.0~rc1) stable; urgency=medium
-
-  * New synapse release 1.49.0~rc1.
-
- -- Synapse Packaging team  Tue, 07 Dec 2021 13:52:21 +0000
-
 matrix-synapse-py3 (1.48.0) stable; urgency=medium
 
   * New synapse release 1.48.0.
diff --git a/docker/Dockerfile-workers b/docker/Dockerfile-workers
index 46f2e17382db..969cf9728658 100644
--- a/docker/Dockerfile-workers
+++ b/docker/Dockerfile-workers
@@ -21,6 +21,3 @@ VOLUME ["/data"]
 # files to run the desired worker configuration. Will start supervisord.
 COPY ./docker/configure_workers_and_start.py /configure_workers_and_start.py
 ENTRYPOINT ["/configure_workers_and_start.py"]
-
-HEALTHCHECK --start-period=5s --interval=15s --timeout=5s \
-    CMD /bin/sh /healthcheck.sh
diff --git a/docker/conf-workers/healthcheck.sh.j2 b/docker/conf-workers/healthcheck.sh.j2
deleted file mode 100644
index 79c621f89ccb..000000000000
--- a/docker/conf-workers/healthcheck.sh.j2
+++ /dev/null
@@ -1,6 +0,0 @@
-#!/bin/sh
-# This healthcheck script is designed to return OK when every
-# host involved returns OK
-{%- for healthcheck_url in healthcheck_urls %}
-curl -fSs {{ healthcheck_url }} || exit 1
-{%- endfor %}
diff --git a/docker/configure_workers_and_start.py b/docker/configure_workers_and_start.py
index adbb551cee7f..f4ac1c22a423 100755
--- a/docker/configure_workers_and_start.py
+++ b/docker/configure_workers_and_start.py
@@ -474,16 +474,10 @@ def generate_worker_files(environ, config_path: str, data_dir: str):
 
     # Determine the load-balancing upstreams to configure
     nginx_upstream_config = ""
-
-    # At the same time, prepare a list of internal endpoints to healthcheck
-    # starting with the main process which exists even if no workers do.
-    healthcheck_urls = ["http://localhost:8080/health"]
-
     for upstream_worker_type, upstream_worker_ports in nginx_upstreams.items():
         body = ""
         for port in upstream_worker_ports:
             body += "    server localhost:%d;\n" % (port,)
-            healthcheck_urls.append("http://localhost:%d/health" % (port,))
 
         # Add to the list of configured upstreams
         nginx_upstream_config += NGINX_UPSTREAM_CONFIG_BLOCK.format(
@@ -516,13 +510,6 @@ def generate_worker_files(environ, config_path: str, data_dir: str):
             worker_config=supervisord_config,
         )
 
-    # healthcheck config
-    convert(
-        "/conf/healthcheck.sh.j2",
-        "/healthcheck.sh",
-        healthcheck_urls=healthcheck_urls,
-    )
-
     # Ensure the logging directory exists
     log_dir = data_dir + "/logs"
     if not os.path.exists(log_dir):
diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md
index b05af6d69051..cdedf8bccc28 100644
--- a/docs/SUMMARY.md
+++ b/docs/SUMMARY.md
@@ -44,7 +44,6 @@
     - [Presence router callbacks](modules/presence_router_callbacks.md)
     - [Account validity callbacks](modules/account_validity_callbacks.md)
     - [Password auth provider callbacks](modules/password_auth_provider_callbacks.md)
-    - [Background update controller callbacks](modules/background_update_controller_callbacks.md)
     - [Porting a legacy module to the new interface](modules/porting_legacy_module.md)
   - [Workers](workers.md)
     - [Using `synctl` with Workers](synctl_workers.md)
@@ -65,15 +64,9 @@
       - [Statistics](admin_api/statistics.md)
       - [Users](admin_api/user_admin_api.md)
       - [Server Version](admin_api/version_api.md)
-      - [Federation](usage/administration/admin_api/federation.md)
   - [Manhole](manhole.md)
   - [Monitoring](metrics-howto.md)
-  - [Understanding Synapse Through Grafana Graphs](usage/administration/understanding_synapse_through_grafana_graphs.md)
-  - [Useful SQL for Admins](usage/administration/useful_sql_for_admins.md)
-  - [Database Maintenance Tools](usage/administration/database_maintenance_tools.md)
-  - [State Groups](usage/administration/state_groups.md)
   - [Request log format](usage/administration/request_log.md)
-  - [Admin FAQ](usage/administration/admin_faq.md)
   - [Scripts]()
 
 # Development
@@ -101,4 +94,3 @@
 
 # Other
   - [Dependency Deprecation Policy](deprecation_policy.md)
-  - [Running Synapse on a Single-Board Computer](other/running_synapse_on_single_board_computers.md)
diff --git a/docs/development/room-dag-concepts.md b/docs/development/room-dag-concepts.md
index cbc7cf29491c..5eed72bec662 100644
--- a/docs/development/room-dag-concepts.md
+++ b/docs/development/room-dag-concepts.md
@@ -38,15 +38,16 @@ Most-recent-in-time events in the DAG which are not referenced by any other even
 
 The forward extremities of a room are used as the `prev_events` when the next event is sent.
 
-## Backward extremity
+## Backwards extremity
 
 The current marker of where we have backfilled up to and will generally be the
-`prev_events` of the oldest-in-time events we have in the DAG. This gives a starting point when
-backfilling history.
+oldest-in-time events we know of in the DAG.
 
-When we persist a non-outlier event, we clear it as a backward extremity and set
-all of its `prev_events` as the new backward extremities if they aren't already
-persisted in the `events` table.
+This is an event where we haven't fetched all of the `prev_events` for.
+
+Once we have fetched all of its `prev_events`, it's unmarked as a backwards
+extremity (although we may have formed new backwards extremities from the prev
+events during the backfilling process).
 
 ## Outliers
 
@@ -55,7 +56,8 @@ We mark an event as an `outlier` when we haven't figured out the state for the
 room at that point in the DAG yet.
 
 We won't *necessarily* have the `prev_events` of an `outlier` in the database,
-but it's entirely possible that we *might*.
+but it's entirely possible that we *might*. The status of whether we have all of
+the `prev_events` is marked as a [backwards extremity](#backwards-extremity).
 
 For example, when we fetch the event auth chain or state for a given event, we
 mark all of those claimed auth events as outliers because we haven't done the
diff --git a/docs/media_repository.md b/docs/media_repository.md
index ba17f8a856f1..99ee8f1ef7ff 100644
--- a/docs/media_repository.md
+++ b/docs/media_repository.md
@@ -2,80 +2,29 @@
 
 *Synapse implementation-specific details for the media repository*
 
-The media repository
- * stores avatars, attachments and their thumbnails for media uploaded by local
-   users.
- * caches avatars, attachments and their thumbnails for media uploaded by remote
-   users.
- * caches resources and thumbnails used for
-   [URL previews](development/url_previews.md).
+The media repository is where attachments and avatar photos are stored.
+It stores attachment content and thumbnails for media uploaded by local users.
+It caches attachment content and thumbnails for media uploaded by remote users.
 
-All media in Matrix can be identified by a unique
-[MXC URI](https://spec.matrix.org/latest/client-server-api/#matrix-content-mxc-uris),
-consisting of a server name and media ID:
-```
-mxc://<server-name>/<media-id>
-```
+## Storage
 
-## Local Media
-Synapse generates 24 character media IDs for content uploaded by local users.
-These media IDs consist of upper and lowercase letters and are case-sensitive.
-Other homeserver implementations may generate media IDs differently.
+Each item of media is assigned a `media_id` when it is uploaded.
+The `media_id` is a randomly chosen, URL safe 24 character string.
 
-Local media is recorded in the `local_media_repository` table, which includes
-metadata such as MIME types, upload times and file sizes.
-Note that this table is shared by the URL cache, which has a different media ID
-scheme.
+Metadata such as the MIME type, upload time and length are stored in the
+sqlite3 database indexed by `media_id`.
 
-### Paths
-A file with media ID `aabbcccccccccccccccccccc` and its `128x96` `image/jpeg`
-thumbnail, created by scaling, would be stored at:
-```
-local_content/aa/bb/cccccccccccccccccccc
-local_thumbnails/aa/bb/cccccccccccccccccccc/128-96-image-jpeg-scale
-```
+Content is stored on the filesystem under a `"local_content"` directory.
 
-## Remote Media
-When media from a remote homeserver is requested from Synapse, it is assigned
-a local `filesystem_id`, with the same format as locally-generated media IDs,
-as described above.
+Thumbnails are stored under a `"local_thumbnails"` directory.
 
-A record of remote media is stored in the `remote_media_cache` table, which
-can be used to map remote MXC URIs (server names and media IDs) to local
-`filesystem_id`s.
+The item with `media_id` `"aabbccccccccdddddddddddd"` is stored under
+`"local_content/aa/bb/ccccccccdddddddddddd"`.
+Its thumbnail with width `128` and height `96` and type `"image/jpeg"`
+is stored under `"local_thumbnails/aa/bb/ccccccccdddddddddddd/128-96-image-jpeg"`
 
-### Paths
-A file from `matrix.org` with `filesystem_id` `aabbcccccccccccccccccccc` and its
-`128x96` `image/jpeg` thumbnail, created by scaling, would be stored at:
-```
-remote_content/matrix.org/aa/bb/cccccccccccccccccccc
-remote_thumbnail/matrix.org/aa/bb/cccccccccccccccccccc/128-96-image-jpeg-scale
-```
-Older thumbnails may omit the thumbnailing method:
-```
-remote_thumbnail/matrix.org/aa/bb/cccccccccccccccccccc/128-96-image-jpeg
-```
-
-Note that `remote_thumbnail/` does not have an `s`.
-
-## URL Previews
-See [URL Previews](development/url_previews.md) for documentation on the URL preview
-process.
-
-When generating previews for URLs, Synapse may download and cache various
-resources, including images. These resources are assigned temporary media IDs
-of the form `yyyy-mm-dd_aaaaaaaaaaaaaaaa`, where `yyyy-mm-dd` is the current
-date and `aaaaaaaaaaaaaaaa` is a random sequence of 16 case-sensitive letters.
-
-The metadata for these cached resources is stored in the
-`local_media_repository` and `local_media_repository_url_cache` tables.
-
-Resources for URL previews are deleted after a few days.
-
-### Paths
-The file with media ID `yyyy-mm-dd_aaaaaaaaaaaaaaaa` and its `128x96`
-`image/jpeg` thumbnail, created by scaling, would be stored at:
-```
-url_cache/yyyy-mm-dd/aaaaaaaaaaaaaaaa
-url_cache_thumbnails/yyyy-mm-dd/aaaaaaaaaaaaaaaa/128-96-image-jpeg-scale
-```
+Remote content is cached under `"remote_content"` directory. Each item of
+remote content is assigned a local `"filesystem_id"` to ensure that the
+directory structure `"remote_content/server_name/aa/bb/ccccccccdddddddddddd"`
+is appropriate. Thumbnails for remote content are stored under
+`"remote_thumbnail/server_name/..."`
diff --git a/docs/modules/background_update_controller_callbacks.md b/docs/modules/background_update_controller_callbacks.md
deleted file mode 100644
index b3e7c259f4ae..000000000000
--- a/docs/modules/background_update_controller_callbacks.md
+++ /dev/null
@@ -1,71 +0,0 @@
-# Background update controller callbacks
-
-Background update controller callbacks allow module developers to control (e.g. rate-limit)
-how database background updates are run. A database background update is an operation
-Synapse runs on its database in the background after it starts. It's usually used to run
-database operations that would take too long if they were run at the same time as schema
-updates (which are run on startup) and delay Synapse's startup too much: populating a
-table with a big amount of data, adding an index on a big table, deleting superfluous data,
-etc.
-
-Background update controller callbacks can be registered using the module API's
-`register_background_update_controller_callbacks` method. Only the first module (in order
-of appearance in Synapse's configuration file) calling this method can register background
-update controller callbacks, subsequent calls are ignored.
-
-The available background update controller callbacks are:
-
-### `on_update`
-
-_First introduced in Synapse v1.49.0_
-
-```python
-def on_update(update_name: str, database_name: str, one_shot: bool) -> AsyncContextManager[int]
-```
-
-Called when about to do an iteration of a background update. The module is given the name
-of the update, the name of the database, and a flag to indicate whether the background
-update will happen in one go and may take a long time (e.g.
-creating indices). If this last
-argument is set to `False`, the update will be run in batches.
-
-The module must return an async context manager. It will be entered before Synapse runs a
-background update; this should return the desired duration of the iteration, in
-milliseconds.
-
-The context manager will be exited when the iteration completes. Note that the duration
-returned by the context manager is a target, and an iteration may take substantially longer
-or shorter. If the `one_shot` flag is set to `True`, the duration returned is ignored.
-
-__Note__: Unlike most module callbacks in Synapse, this one is _synchronous_. This is
-because asynchronous operations are expected to be run by the async context manager.
-
-This callback is required when registering any other background update controller callback.
-
-### `default_batch_size`
-
-_First introduced in Synapse v1.49.0_
-
-```python
-async def default_batch_size(update_name: str, database_name: str) -> int
-```
-
-Called before the first iteration of a background update, with the name of the update and
-of the database. The module must return the number of elements to process in this first
-iteration.
-
-If this callback is not defined, Synapse will use a default value of 100.
-
-### `min_batch_size`
-
-_First introduced in Synapse v1.49.0_
-
-```python
-async def min_batch_size(update_name: str, database_name: str) -> int
-```
-
-Called before running a new batch for a background update, with the name of the update and
-of the database. The module must return an integer representing the minimum number of
-elements to process in this iteration. This number must be at least 1, and is used to
-ensure that progress is always made.
-
-If this callback is not defined, Synapse will use a default value of 100.
diff --git a/docs/modules/writing_a_module.md b/docs/modules/writing_a_module.md
index e7c0ffad58bf..7764e066926b 100644
--- a/docs/modules/writing_a_module.md
+++ b/docs/modules/writing_a_module.md
@@ -71,15 +71,15 @@ Modules **must** register their web resources in their `__init__` method.
 ## Registering a callback
 
 Modules can use Synapse's module API to register callbacks. Callbacks are functions that
-Synapse will call when performing specific actions. Callbacks must be asynchronous (unless
-specified otherwise), and are split in categories. A single module may implement callbacks
-from multiple categories, and is under no obligation to implement all callbacks from the
-categories it registers callbacks for.
+Synapse will call when performing specific actions. Callbacks must be asynchronous, and
+are split in categories. A single module may implement callbacks from multiple categories,
+and is under no obligation to implement all callbacks from the categories it registers
+callbacks for.
 
 Modules can register callbacks using one of the module API's `register_[...]_callbacks`
 methods. The callback functions are passed to these methods as keyword arguments, with
-the callback name as the argument name and the function as its value. A
-`register_[...]_callbacks` method exists for each category.
+the callback name as the argument name and the function as its value. This is demonstrated
+in the example below. A `register_[...]_callbacks` method exists for each category.
 
 Callbacks for each category can be found on their respective page of the
 [Synapse documentation website](https://matrix-org.github.io/synapse).
\ No newline at end of file
diff --git a/docs/openid.md b/docs/openid.md
index ff9de9d5b8bf..c74e8bda606a 100644
--- a/docs/openid.md
+++ b/docs/openid.md
@@ -83,7 +83,7 @@ oidc_providers:
 
 ### Dex
 
-[Dex][dex-idp] is a simple, open-source OpenID Connect Provider.
+[Dex][dex-idp] is a simple, open-source, certified OpenID Connect Provider.
 Although it is designed to help building a full-blown provider with an external
 database, it can be configured with static passwords in a config file.
@@ -523,7 +523,7 @@ The synapse config will look like this:
     email_template: "{{ user.email }}"
 ```
 
-### Django OAuth Toolkit
+## Django OAuth Toolkit
 
 [django-oauth-toolkit](https://github.com/jazzband/django-oauth-toolkit) is a
 Django application providing out of the box all the endpoints, data and logic
diff --git a/docs/other/running_synapse_on_single_board_computers.md b/docs/other/running_synapse_on_single_board_computers.md
deleted file mode 100644
index ea14afa8b2df..000000000000
--- a/docs/other/running_synapse_on_single_board_computers.md
+++ /dev/null
@@ -1,74 +0,0 @@
-## Summary of performance impact of running on resource constrained devices such as SBCs
-
-I've been running my homeserver on a cubietruck at home now for some time and am often replying to statements like "you need loads of ram to join large rooms" with "it works fine for me". I thought it might be useful to curate a summary of the issues you're likely to run into to help as a scaling-down guide, maybe highlight these for development work or end up as documentation. It seems that once you get up to about 4x1.5GHz arm64 4GiB these issues are no longer a problem.
-
-- **Platform**: 2x1GHz armhf 2GiB ram [Single-board computers](https://wiki.debian.org/CheapServerBoxHardware), SSD, postgres.
-
-### Presence
-
-This is the main reason people have a poor matrix experience on resource constrained homeservers. Element web will frequently be saying the server is offline while the python process will be pegged at 100% cpu. This feature is used to tell when other users are active (have a client app in the foreground) and therefore more likely to respond, but requires a lot of network activity to maintain even when nobody is talking in a room.
-
-![Screenshot_2020-10-01_19-29-46](https://user-images.githubusercontent.com/71895/94848963-a47a3580-041c-11eb-8b6e-acb772b4259e.png)
-
-While synapse does have some performance issues with presence [#3971](https://github.com/matrix-org/synapse/issues/3971), the fundamental problem is that this is an easy feature to implement for a centralised service at nearly no overhead, but federation makes it combinatorial [#8055](https://github.com/matrix-org/synapse/issues/8055). There is also a client-side config option which disables the UI and idle tracking [enable_presence_by_hs_url] to blacklist the largest instances but I didn't notice much difference, so I recommend disabling the feature entirely at the server level as well.
-
-[enable_presence_by_hs_url]: https://github.com/vector-im/element-web/blob/v1.7.8/config.sample.json#L45
-
-### Joining
-
-Joining a "large", federated room will initially fail with the below message in Element web, but waiting a while (10-60mins) and trying again will succeed without any issue. What counts as "large" is not message history, user count, connections to homeservers or even a simple count of the state events, it is instead how long the state resolution algorithm takes.
-However, each of those numbers are reasonable proxies, so we can use them as estimates since user count is one of the few things you see before joining.
-
-![Screenshot_2020-10-02_17-15-06](https://user-images.githubusercontent.com/71895/94945781-18771500-04d3-11eb-8419-83c2da73a341.png)
-
-This is [#1211](https://github.com/matrix-org/synapse/issues/1211) and will also hopefully be mitigated by peeking [matrix-org/matrix-doc#2753](https://github.com/matrix-org/matrix-doc/pull/2753) so at least you don't need to wait for a join to complete before finding out if it's the kind of room you want. Note that you should first disable presence, otherwise it'll just make the situation worse [#3120](https://github.com/matrix-org/synapse/issues/3120). There is a lot of database interaction too, so make sure you've [migrated your data](../postgres.md) from the default sqlite to postgresql. Personally, I recommend patience - once the initial join is complete there's rarely any issues with actually interacting with the room, but if you like you can just block "large" rooms entirely.
-
-### Sessions
-
-Anything that requires modifying the device list [#7721](https://github.com/matrix-org/synapse/issues/7721) will take a while to propagate, again taking the client "Offline" until it's complete. This includes signing in and out, editing the public name and verifying e2ee. The main mitigation I recommend is to keep long-running sessions open e.g. by using Firefox SSB "Use this site in App mode" or Chromium PWA "Install Element".
-
-### Recommended configuration
-
-Put the below in a new file at /etc/matrix-synapse/conf.d/sbc.yaml to override the defaults in homeserver.yaml.
-
-```
-# Set to false to disable presence tracking on this homeserver.
-use_presence: false
-
-# When this is enabled, the room "complexity" will be checked before a user
-# joins a new remote room. If it is above the complexity limit, the server will
-# disallow joining, or will instantly leave.
-limit_remote_rooms:
-  # Uncomment to enable room complexity checking.
-  #enabled: true
-  complexity: 3.0
-
-# Database configuration
-database:
-  name: psycopg2
-  args:
-    user: matrix-synapse
-    # Generate a long, secure one with a password manager
-    password: hunter2
-    database: matrix-synapse
-    host: localhost
-    cp_min: 5
-    cp_max: 10
-```
-
-Currently the complexity is measured by [current_state_events / 500](https://github.com/matrix-org/synapse/blob/v1.20.1/synapse/storage/databases/main/events_worker.py#L986).
-You can find join times and your most complex rooms like this:
-
-```
-admin@homeserver:~$ zgrep '/client/r0/join/' /var/log/matrix-synapse/homeserver.log* | awk '{print $18, $25}' | sort --human-numeric-sort
-29.922sec/-0.002sec /_matrix/client/r0/join/%23debian-fasttrack%3Apoddery.com
-182.088sec/0.003sec /_matrix/client/r0/join/%23decentralizedweb-general%3Amatrix.org
-911.625sec/-570.847sec /_matrix/client/r0/join/%23synapse%3Amatrix.org
-
-admin@homeserver:~$ sudo --user postgres psql matrix-synapse --command 'select canonical_alias, joined_members, current_state_events from room_stats_state natural join room_stats_current where canonical_alias is not null order by current_state_events desc fetch first 5 rows only'
-           canonical_alias            | joined_members | current_state_events
---------------------------------------+----------------+----------------------
- #_oftc_#debian:matrix.org            |            871 |                52355
- #matrix:matrix.org                   |           6379 |                10684
- #irc:matrix.org                      |            461 |                 3751
- #decentralizedweb-general:matrix.org |            997 |                 1509
- #whatsapp:maunium.net                |            554 |                  854
-```
\ No newline at end of file
diff --git a/docs/postgres.md b/docs/postgres.md
index e4861c1f127f..083b0aaff01f 100644
--- a/docs/postgres.md
+++ b/docs/postgres.md
@@ -118,9 +118,6 @@ performance:
 Note that the appropriate values for those fields depend on the amount of free memory the
 database host has available.
 
-Additionally, admins of large deployments might want to consider using huge pages
-to help manage memory, especially when using large values of `shared_buffers`. You
-can read more about that [here](https://www.postgresql.org/docs/10/kernel-resources.html#LINUX-HUGE-PAGES).
 
 ## Porting from SQLite
 
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 6696ed5d1ef9..ae476d19ac8e 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -1209,44 +1209,6 @@ oembed:
 #
 #session_lifetime: 24h
 
-# Time that an access token remains valid for, if the session is
-# using refresh tokens.
-# For more information about refresh tokens, please see the manual.
-# Note that this only applies to clients which advertise support for
-# refresh tokens.
-#
-# Note also that this is calculated at login time and refresh time:
-# changes are not applied to existing sessions until they are refreshed.
-#
-# By default, this is 5 minutes.
-#
-#refreshable_access_token_lifetime: 5m
-
-# Time that a refresh token remains valid for (provided that it is not
-# exchanged for another one first).
-# This option can be used to automatically log-out inactive sessions.
-# Please see the manual for more information.
-#
-# Note also that this is calculated at login time and refresh time:
-# changes are not applied to existing sessions until they are refreshed.
-#
-# By default, this is infinite.
-#
-#refresh_token_lifetime: 24h
-
-# Time that an access token remains valid for, if the session is NOT
-# using refresh tokens.
-# Please note that not all clients support refresh tokens, so setting
-# this to a short value may be inconvenient for some users who will
-# then be logged out frequently.
-#
-# Note also that this is calculated at login time: changes are not applied
-# retrospectively to existing sessions for users that have already logged in.
-#
-# By default, this is infinite.
-#
-#nonrefreshable_access_token_lifetime: 24h
-
 # The user must provide all of the below types of 3PID when registering.
 #
 #registrations_require_3pid:
diff --git a/docs/templates.md b/docs/templates.md
index 2b66e9d86294..a240f58b54fd 100644
--- a/docs/templates.md
+++ b/docs/templates.md
@@ -71,12 +71,7 @@ Below are the templates Synapse will look for when generating the content of an
       * `sender_avatar_url`: the avatar URL (as a `mxc://` URL) for the event's sender
       * `sender_hash`: a hash of the user ID of the sender
-      * `msgtype`: the type of the message
-      * `body_text_html`: html representation of the message
-      * `body_text_plain`: plaintext representation of the message
-      * `image_url`: mxc url of an image, when "msgtype" is "m.image"
       * `link`: a `matrix.to` link to the room
-      * `avator_url`: url to the room's avator
   * `reason`: information on the event that triggered the email to be sent. It's an
     object with the following attributes:
     * `room_id`: the ID of the room the event was sent in
diff --git a/docs/usage/administration/admin_api/federation.md b/docs/usage/administration/admin_api/federation.md
deleted file mode 100644
index 8f9535f57b09..000000000000
--- a/docs/usage/administration/admin_api/federation.md
+++ /dev/null
@@ -1,114 +0,0 @@
-# Federation API
-
-This API allows a server administrator to manage Synapse's federation with other homeservers.
-
-Note: This API is new, experimental and "subject to change".
-
-## List of destinations
-
-This API gets the current destination retry timing info for all remote servers.
-
-The list contains all the servers with which the server federates,
-regardless of whether an error occurred or not.
-If an error occurs, it may take up to 20 minutes for the error to be displayed here,
-as a complete retry must have failed.
-
-The API is:
-
-A standard request with no filtering:
-
-```
-GET /_synapse/admin/v1/federation/destinations
-```
-
-A response body like the following is returned:
-
-```json
-{
-   "destinations":[
-     {
-       "destination": "matrix.org",
-       "retry_last_ts": 1557332397936,
-       "retry_interval": 3000000,
-       "failure_ts": 1557329397936,
-       "last_successful_stream_ordering": null
-     }
-   ],
-   "total": 1
-}
-```
-
-To paginate, check for `next_token` and if present, call the endpoint again
-with `from` set to the value of `next_token`. This will return a new page.
-
-If the endpoint does not return a `next_token` then there are no more destinations
-to paginate through.
-
-**Parameters**
-
-The following query parameters are available:
-
-- `from` - Offset in the returned list. Defaults to `0`.
-- `limit` - Maximum amount of destinations to return. Defaults to `100`.
-- `order_by` - The method in which to sort the returned list of destinations.
-  Valid values are:
-  - `destination` - Destinations are ordered alphabetically by remote server name.
-    This is the default.
-  - `retry_last_ts` - Destinations are ordered by time of last retry attempt in ms.
-  - `retry_interval` - Destinations are ordered by how long until next retry in ms.
-  - `failure_ts` - Destinations are ordered by when the server started failing in ms.
-  - `last_successful_stream_ordering` - Destinations are ordered by the stream ordering
-    of the most recent successfully-sent PDU.
-- `dir` - Direction of room order. Either `f` for forwards or `b` for backwards. Setting
-  this value to `b` will reverse the above sort order. Defaults to `f`.
-
-*Caution:* The database only has an index on the column `destination`.
-This means that if a different sort order is used,
-this can cause a large load on the database, especially for large environments.
-
-**Response**
-
-The following fields are returned in the JSON response body:
-
-- `destinations` - An array of objects, each containing information about a destination.
-  Destination objects contain the following fields:
-  - `destination` - string - Name of the remote server to federate.
-  - `retry_last_ts` - integer - The last time Synapse tried and failed to reach the
-    remote server, in ms. This is `0` if the last attempt to communicate with the
-    remote server was successful.
-  - `retry_interval` - integer - How long since the last time Synapse tried to reach
-    the remote server before trying again, in ms. This is `0` if no further retrying occuring.
-  - `failure_ts` - nullable integer - The first time Synapse tried and failed to reach the
-    remote server, in ms. This is `null` if communication with the remote server has never failed.
-  - `last_successful_stream_ordering` - nullable integer - The stream ordering of the most
-    recent successfully-sent [PDU](understanding_synapse_through_grafana_graphs.md#federation)
-    to this destination, or `null` if this information has not been tracked yet.
-- `next_token`: string representing a positive integer - Indication for pagination. See above.
-- `total` - integer - Total number of destinations.
-
-# Destination Details API
-
-This API gets the retry timing info for a specific remote server.
-
-The API is:
-
-```
-GET /_synapse/admin/v1/federation/destinations/<destination>
-```
-
-A response body like the following is returned:
-
-```json
-{
-   "destination": "matrix.org",
-   "retry_last_ts": 1557332397936,
-   "retry_interval": 3000000,
-   "failure_ts": 1557329397936,
-   "last_successful_stream_ordering": null
-}
-```
-
-**Response**
-
-The response fields are the same like in the `destinations` array in
-[List of destinations](#list-of-destinations) response.
diff --git a/docs/usage/administration/admin_faq.md b/docs/usage/administration/admin_faq.md
deleted file mode 100644
index 3dcad4bbef5d..000000000000
--- a/docs/usage/administration/admin_faq.md
+++ /dev/null
@@ -1,103 +0,0 @@
-## Admin FAQ
-
-How do I become a server admin?
----
-If your server already has an admin account you should use the user admin API to promote other accounts to become admins. See [User Admin API](../../admin_api/user_admin_api.md#Change-whether-a-user-is-a-server-administrator-or-not)
-
-If you don't have any admin accounts yet you won't be able to use the admin API so you'll have to edit the database manually. Manually editing the database is generally not recommended so once you have an admin account, use the admin APIs to make further changes.
-
-```sql
-UPDATE users SET admin = 1 WHERE name = '@foo:bar.com';
-```
-What servers are my server talking to?
----
-Run this sql query on your db:
-```sql
-SELECT * FROM destinations;
-```
-
-What servers are currently participating in this room?
----
-Run this sql query on your db:
-```sql
-SELECT DISTINCT split_part(state_key, ':', 2)
-    FROM current_state_events AS c
-    INNER JOIN room_memberships AS m USING (room_id, event_id)
-    WHERE room_id = '!cURbafjkfsMDVwdRDQ:matrix.org' AND membership = 'join';
-```
-
-What users are registered on my server?
----
-```sql
-SELECT NAME from users;
-```
-
-Manually resetting passwords:
----
-See https://github.com/matrix-org/synapse/blob/master/README.rst#password-reset
-
-I have a problem with my server. Can I just delete my database and start again?
----
-Deleting your database is unlikely to make anything better.
- -It's easy to make the mistake of thinking that you can start again from a clean slate by dropping your database, but things don't work like that in a federated network: lots of other servers have information about your server. - -For example: other servers might think that you are in a room, your server will think that you are not, and you'll probably be unable to interact with that room in a sensible way ever again. - -In general, there are better solutions to any problem than dropping the database. Come and seek help in https://matrix.to/#/#synapse:matrix.org. - -There are two exceptions when it might be sensible to delete your database and start again: -* You have *never* joined any rooms which are federated with other servers. For instance, a local deployment which the outside world can't talk to. -* You are changing the `server_name` in the homeserver configuration. In effect this makes your server a completely new one from the point of view of the network, so in this case it makes sense to start with a clean database. -(In both cases you probably also want to clear out the media_store.) - -I've stuffed up access to my room, how can I delete it to free up the alias? ---- -Using the following curl command: -``` -curl -H 'Authorization: Bearer <access-token>' -X DELETE https://matrix.org/_matrix/client/r0/directory/room/<room-alias> -``` -`<access-token>` - can be obtained in Riot by looking in the Riot settings; down at the bottom is: -Access Token - -`<room-alias>` - the room alias, e.g. #my_room:matrix.org. This possibly needs to be URL encoded also, for example %23my_room%3Amatrix.org - -How can I find the lines corresponding to a given HTTP request in my homeserver log? ---- - -Synapse tags each log line according to the HTTP request it is processing. When it finishes processing each request, it logs a line containing the words `Processed request: `. For example: - -``` -2019-02-14 22:35:08,196 - synapse.access.http.8008 - 302 - INFO - GET-37 - ::1 - 8008 - {@richvdh:localhost} Processed request: 0.173sec/0.001sec (0.002sec, 0.000sec) (0.027sec/0.026sec/2) 687B 200 "GET /_matrix/client/r0/sync HTTP/1.1" "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/69.0.3497.100 Safari/537.36" [0 dbevts]" -``` - -Here we can see that the request has been tagged with `GET-37`. (The tag depends on the method of the HTTP request, so might start with `GET-`, `PUT-`, `POST-`, `OPTIONS-` or `DELETE-`.) So to find all lines corresponding to this request, we can do: - -``` -grep 'GET-37' homeserver.log -``` - -If you want to paste that output into a github issue or matrix room, please remember to surround it with triple-backticks (```) to make it legible (see https://help.github.com/en/articles/basic-writing-and-formatting-syntax#quoting-code). - - -What do all those fields in the 'Processed' line mean? ---- -See [Request log format](request_log.md). - - -What are the biggest rooms on my server? ---- - -```sql -SELECT s.canonical_alias, g.room_id, count(*) AS num_rows -FROM - state_groups_state AS g, - room_stats_state AS s -WHERE g.room_id = s.room_id -GROUP BY s.canonical_alias, g.room_id -ORDER BY num_rows DESC -LIMIT 10; -``` - -You can also use the [List Room API](../../admin_api/rooms.md#list-room-api) -and `order_by` `state_events`, as sketched below.
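A rough sketch of that second approach, again assuming the `requests` library, a hypothetical homeserver URL, and a placeholder admin access token:

```python
import requests

resp = requests.get(
    "https://synapse.example.com/_synapse/admin/v1/rooms",  # hypothetical homeserver
    headers={"Authorization": "Bearer <admin_access_token>"},  # placeholder token
    params={"order_by": "state_events", "limit": 10},
)
resp.raise_for_status()
for room in resp.json()["rooms"]:
    print(room["room_id"], room.get("name"), room["state_events"])
```

Depending on the sort direction your Synapse version uses, you may also need to pass a `dir` parameter to get the largest rooms first.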
diff --git a/docs/usage/administration/database_maintenance_tools.md b/docs/usage/administration/database_maintenance_tools.md deleted file mode 100644 index 92b805d413cb..000000000000 --- a/docs/usage/administration/database_maintenance_tools.md +++ /dev/null @@ -1,18 +0,0 @@ -This blog post by Victor Berger explains how to use many of the tools listed on this page: https://levans.fr/shrink-synapse-database.html - -# List of useful tools and scripts for maintaining a Synapse database: - -## [Purge Remote Media API](../../admin_api/media_admin_api.md#purge-remote-media-api) -The purge remote media API allows server admins to purge old cached remote media. - -## [Purge Local Media API](../../admin_api/media_admin_api.md#delete-local-media) -This API deletes the *local* media from the disk of your own server. - -## [Purge History API](../../admin_api/purge_history_api.md) -The purge history API allows server admins to purge historic events from their database, reclaiming disk space. - -## [synapse-compress-state](https://github.com/matrix-org/rust-synapse-compress-state) -A tool for compressing (deduplicating) the `state_groups_state` table. - -## [SQL for analyzing Synapse PostgreSQL database stats](useful_sql_for_admins.md) -Some easy SQL that reports useful stats about your Synapse database. \ No newline at end of file diff --git a/docs/usage/administration/state_groups.md b/docs/usage/administration/state_groups.md deleted file mode 100644 index f1dee7accf0f..000000000000 --- a/docs/usage/administration/state_groups.md +++ /dev/null @@ -1,25 +0,0 @@ -# How do State Groups work? - -As a general rule, I encourage people who want to understand the deepest darkest secrets of the database schema to drop by #synapse-dev:matrix.org and ask questions. - -However, one question that comes up frequently is that of how "state groups" work, and why the `state_groups_state` table gets so big, so here's an attempt to answer that question. - -We need to be able to relatively quickly calculate the state of a room at any point in that room's history. In other words, we need to know the state of the room at each event in that room. This is done as follows: - -A sequence of events where the state is the same are grouped together into a `state_group`; the mapping is recorded in `event_to_state_groups`. (Technically speaking, since a state event usually changes the state in the room, we are recording the state of the room *after* the given event id: which is to say, to a handwavey simplification, the first event in a state group is normally a state event, and others in the same state group are normally non-state-events.) - -`state_groups` records, for each state group, the id of the room that we're looking at, and also the id of the first event in that group. (I'm not sure if that event id is used much in practice.) - -Now, if we stored all the room state for each `state_group`, that would be a huge amount of data. Instead, for each state group, we normally store the difference between the state in that group and some other state group, and only occasionally (every 100 state changes or so) record the full state. - -So, most state groups have an entry in `state_group_edges` (don't ask me why it's not a column in `state_groups`) which records the previous state group in the room, and `state_groups_state` records the differences in state since that previous state group. - -A full state group just records the event id for each piece of state in the room at that point.
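To make the delta scheme concrete, here is illustrative pseudocode (a sketch only, not Synapse's actual implementation; `db.lookup` is a hypothetical helper) showing how the full state at an event could be rebuilt from these tables:

```python
def get_state_for_event(db, event_id):
    """Rebuild the full room state at `event_id` from state-group deltas."""
    state_group = db.lookup("event_to_state_groups", event_id)

    # Follow state_group_edges backwards until we reach a group with no
    # predecessor, i.e. one that records the full state.
    chain = []
    while state_group is not None:
        chain.append(state_group)
        state_group = db.lookup("state_group_edges", state_group)

    # Apply each group's rows oldest-first, so that later deltas overwrite
    # earlier values for the same (event_type, state_key) pair.
    state = {}
    for group in reversed(chain):
        state.update(db.lookup("state_groups_state", group))
    return state
```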
- -## Known bugs with state groups - -There are various reasons that we can end up creating many more state groups than we need: see https://github.com/matrix-org/synapse/issues/3364 for more details. - -## Compression tool - -There is a tool at https://github.com/matrix-org/rust-synapse-compress-state which can compress the `state_groups_state` on a room-by-room basis (essentially, it reduces the number of "full" state groups). This can result in dramatic reductions of the storage used. \ No newline at end of file diff --git a/docs/usage/administration/understanding_synapse_through_grafana_graphs.md b/docs/usage/administration/understanding_synapse_through_grafana_graphs.md deleted file mode 100644 index c365cc392309..000000000000 --- a/docs/usage/administration/understanding_synapse_through_grafana_graphs.md +++ /dev/null @@ -1,84 +0,0 @@ -## Understanding Synapse through Grafana graphs - -It is possible to monitor much of the internal state of Synapse using [Prometheus](https://prometheus.io) -metrics and [Grafana](https://grafana.com/). -A guide for configuring Synapse to provide metrics is available [here](../../metrics-howto.md) -and information on setting up Grafana is [here](https://github.com/matrix-org/synapse/tree/master/contrib/grafana). -In this setup, Prometheus will periodically scrape the information Synapse provides and -store a record of it over time. Grafana is then used as an interface to query and -present this information through a series of pretty graphs. - -Once you have Grafana set up, and assuming you're using [our grafana dashboard template](https://github.com/matrix-org/synapse/blob/master/contrib/grafana/synapse.json), look for the following graphs when debugging a slow/overloaded Synapse: - -## Message Event Send Time - -![image](https://user-images.githubusercontent.com/1342360/82239409-a1c8e900-9930-11ea-8081-e4614e0c63f4.png) - -This, along with the CPU and Memory graphs, is a good way to check the general health of your Synapse instance. It represents how long it takes for a user on your homeserver to send a message. - -## Transaction Count and Transaction Duration - -![image](https://user-images.githubusercontent.com/1342360/82239985-8d392080-9931-11ea-80d0-843ab2f22e1e.png) - -![image](https://user-images.githubusercontent.com/1342360/82240050-ab068580-9931-11ea-98f1-f94671cbac9a.png) - -These graphs show the database transactions that are occurring the most frequently, as well as those that are taking the most time to execute. - -![image](https://user-images.githubusercontent.com/1342360/82240192-e86b1300-9931-11ea-9aac-3e2c9bfa6fdc.png) - -In the first graph, we can see obvious spikes corresponding to lots of `get_user_by_id` transactions. This would be useful information to figure out which part of the Synapse codebase is potentially creating a heavy load on the system. However, be sure to cross-reference this with Transaction Duration, which states that `get_user_by_id` is actually a very quick database transaction and isn't causing as much load as others, like `persist_events`: - -![image](https://user-images.githubusercontent.com/1342360/82240467-62030100-9932-11ea-8db9-917f2d977fe1.png) - -Still, it's probably worth investigating why we're getting users from the database that often, and whether it's possible to reduce the number of queries we make by adjusting our cache factor(s). - -The `persist_events` transaction is responsible for saving new room events to the Synapse database, so can often show a high transaction duration.
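As a toy sketch of the kind of bookkeeping behind these two panels (an illustration only, not Synapse's actual metrics code), one could count and time named database transactions like so:

```python
import time
from collections import defaultdict

txn_count = defaultdict(int)       # would feed a "Transaction Count" panel
txn_duration = defaultdict(float)  # would feed a "Transaction Duration" panel


def run_timed_txn(name, func, *args, **kwargs):
    """Run a named transaction, recording how often and how long it runs."""
    start = time.perf_counter()
    try:
        return func(*args, **kwargs)
    finally:
        txn_count[name] += 1
        txn_duration[name] += time.perf_counter() - start


# e.g. run_timed_txn("get_user_by_id", some_db_call, "@alice:example.com"),
# where some_db_call is a placeholder for the real database function.
```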
- -## Federation - -The charts in the "Federation" section show information about incoming and outgoing federation requests. Federation data can be divided into two basic types: - -- PDU (Persistent Data Unit) - room events: messages, state events (join/leave), etc. These are permanently stored in the database. -- EDU (Ephemeral Data Unit) - other data, which need not be stored permanently, such as read receipts, typing notifications. - -The "Outgoing EDUs by type" chart shows the EDUs within outgoing federation requests by type: `m.device_list_update`, `m.direct_to_device`, `m.presence`, `m.receipt`, `m.typing`. - -If you see a large number of `m.presence` EDUs and are having trouble with too much CPU load, you can disable `presence` in the Synapse config. See also [#3971](https://github.com/matrix-org/synapse/issues/3971). - -## Caches - -![image](https://user-images.githubusercontent.com/1342360/82240572-8b239180-9932-11ea-96ff-6b5f0e57ebe5.png) - -![image](https://user-images.githubusercontent.com/1342360/82240666-b8703f80-9932-11ea-86af-9f663988d8da.png) - -This is quite a useful graph. It shows how many times Synapse attempts to retrieve a piece of data from a cache which the cache did not contain, thus resulting in a call to the database. We can see here that the `_get_joined_profile_from_event_id` cache is being requested a lot, and often the data we're after is not cached. - -Cross-referencing this with the Eviction Rate graph, which shows that entries are being evicted from `_get_joined_profile_from_event_id` quite often: - -![image](https://user-images.githubusercontent.com/1342360/82240766-de95df80-9932-11ea-8c15-5acfc57c48da.png) - -we should probably consider raising the size of that cache by raising its cache factor (a multiplier value for the size of an individual cache). Information on doing so is available [here](https://github.com/matrix-org/synapse/blob/ee421e524478c1ad8d43741c27379499c2f6135c/docs/sample_config.yaml#L608-L642) (note that the configuration of individual cache factors through the configuration file is available in Synapse v1.14.0+, whereas doing so through environment variables has been supported for a very long time). Note that this will increase Synapse's overall memory usage. - -## Forward Extremities - -![image](https://user-images.githubusercontent.com/1342360/82241440-13566680-9934-11ea-8b88-ba468db937ed.png) - -Forward extremities are the leaf events at the end of a DAG in a room, aka events that have no children. The more that exist in a room, the more [state resolution](https://spec.matrix.org/v1.1/server-server-api/#room-state-resolution) that Synapse needs to perform (hint: it's an expensive operation). While Synapse has code to prevent too many of these existing at one time in a room, bugs can sometimes make them crop up again. - -If a room has >10 forward extremities, it's worth checking which room is the culprit and potentially removing them using the SQL queries mentioned in [#1760](https://github.com/matrix-org/synapse/issues/1760). - -## Garbage Collection - -![image](https://user-images.githubusercontent.com/1342360/82241911-da6ac180-9934-11ea-9a0d-a311fe22acd0.png) - -Large spikes in garbage collection times (bigger than shown here, I'm talking in the -multiple seconds range), can cause lots of problems in Synapse performance. It's more an -indicator of problems, and a symptom of other problems though, so check other graphs for what might be causing it. 
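To get a feel for what this panel measures, here is a small self-contained Python experiment (an illustration only) that times a forced full collection after allocating many objects:

```python
import gc
import time

# Allocate a lot of objects so that generation 2 has plenty to scan.
junk = [{"n": i} for i in range(1_000_000)]

start = time.perf_counter()
gc.collect(2)  # force a full (generation 2) collection
print(f"Full GC pass took {time.perf_counter() - start:.4f}s")
```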
- -## Final Thoughts - -If you're still having performance problems with your Synapse instance and you've -tried everything you can, it may just be a lack of system resources. Consider adding -more CPU and RAM, and make use of [worker mode](../../workers.md) -to spread the load over multiple CPU cores / multiple machines for your homeserver. - diff --git a/docs/usage/administration/useful_sql_for_admins.md b/docs/usage/administration/useful_sql_for_admins.md deleted file mode 100644 index d4aada3272d0..000000000000 --- a/docs/usage/administration/useful_sql_for_admins.md +++ /dev/null @@ -1,156 +0,0 @@ -## Some useful SQL queries for Synapse Admins - -## Size of full matrix db -`SELECT pg_size_pretty( pg_database_size( 'matrix' ) );` -### Result example: -``` -pg_size_pretty ----------------- - 6420 MB -(1 row) -``` -## Show top 20 largest rooms by state event count -```sql -SELECT r.name, s.room_id, s.current_state_events - FROM room_stats_current s - LEFT JOIN room_stats_state r USING (room_id) - ORDER BY current_state_events DESC - LIMIT 20; -``` - -and by state_group_events count: -```sql -SELECT rss.name, s.room_id, count(s.room_id) FROM state_groups_state s -LEFT JOIN room_stats_state rss USING (room_id) -GROUP BY s.room_id, rss.name -ORDER BY count(s.room_id) DESC -LIMIT 20; -``` -and the same, but with the join removed for performance reasons: -```sql -SELECT s.room_id, count(s.room_id) FROM state_groups_state s -GROUP BY s.room_id -ORDER BY count(s.room_id) DESC -LIMIT 20; -``` - -## Show top 20 largest tables by row count -```sql -SELECT relname, n_live_tup as rows - FROM pg_stat_user_tables - ORDER BY n_live_tup DESC - LIMIT 20; -``` -This query is quick, but may be very approximate; for the exact number of rows use `SELECT COUNT(*) FROM <table_name>`. -### Result example: -``` -state_groups_state - 161687170 -event_auth - 8584785 -event_edges - 6995633 -event_json - 6585916 -event_reference_hashes - 6580990 -events - 6578879 -received_transactions - 5713989 -event_to_state_groups - 4873377 -stream_ordering_to_exterm - 4136285 -current_state_delta_stream - 3770972 -event_search - 3670521 -state_events - 2845082 -room_memberships - 2785854 -cache_invalidation_stream - 2448218 -state_groups - 1255467 -state_group_edges - 1229849 -current_state_events - 1222905 -users_in_public_rooms - 364059 -device_lists_stream - 326903 -user_directory_search - 316433 -``` - -## Show top 20 rooms by new events count in the last day: -```sql -SELECT e.room_id, r.name, COUNT(e.event_id) cnt FROM events e -LEFT JOIN room_stats_state r USING (room_id) -WHERE e.origin_server_ts >= DATE_PART('epoch', NOW() - INTERVAL '1 day') * 1000 GROUP BY e.room_id, r.name ORDER BY cnt DESC LIMIT 20; -``` - -## Show top 20 users on the homeserver by sent events (messages) in the last month: -```sql -SELECT user_id, SUM(total_events) - FROM user_stats_historical - WHERE TO_TIMESTAMP(end_ts/1000) AT TIME ZONE 'UTC' > date_trunc('day', now() - interval '1 month') - GROUP BY user_id - ORDER BY SUM(total_events) DESC - LIMIT 20; -``` - -## Show last 100 messages from a given user, with room names: -```sql -SELECT e.room_id, r.name, e.event_id, e.type, e.content, j.json FROM events e - LEFT JOIN event_json j USING (event_id) - LEFT JOIN room_stats_state r USING (room_id) - WHERE sender = '@LOGIN:example.com' - AND e.type = 'm.room.message' - ORDER BY stream_ordering DESC - LIMIT 100; -``` - -## Show top 20 largest tables by storage size -```sql -SELECT nspname || '.'
|| relname AS "relation", - pg_size_pretty(pg_total_relation_size(C.oid)) AS "total_size" - FROM pg_class C - LEFT JOIN pg_namespace N ON (N.oid = C.relnamespace) - WHERE nspname NOT IN ('pg_catalog', 'information_schema') - AND C.relkind <> 'i' - AND nspname !~ '^pg_toast' - ORDER BY pg_total_relation_size(C.oid) DESC - LIMIT 20; -``` -### Result example: -``` -public.state_groups_state - 27 GB -public.event_json - 9855 MB -public.events - 3675 MB -public.event_edges - 3404 MB -public.received_transactions - 2745 MB -public.event_reference_hashes - 1864 MB -public.event_auth - 1775 MB -public.stream_ordering_to_exterm - 1663 MB -public.event_search - 1370 MB -public.room_memberships - 1050 MB -public.event_to_state_groups - 948 MB -public.current_state_delta_stream - 711 MB -public.state_events - 611 MB -public.presence_stream - 530 MB -public.current_state_events - 525 MB -public.cache_invalidation_stream - 466 MB -public.receipts_linearized - 279 MB -public.state_groups - 160 MB -public.device_lists_remote_cache - 124 MB -public.state_group_edges - 122 MB -``` - -## Show rooms with names, sorted by events in these rooms -`echo "select event_json.room_id,room_stats_state.name from event_json,room_stats_state where room_stats_state.room_id=event_json.room_id" | psql synapse | sort | uniq -c | sort -n` -### Result example: -``` - 9459 !FPUfgzXYWTKgIrwKxW:matrix.org | This Week in Matrix - 9459 !FPUfgzXYWTKgIrwKxW:matrix.org | This Week in Matrix (TWIM) - 17799 !iDIOImbmXxwNngznsa:matrix.org | Linux in Russian - 18739 !GnEEPYXUhoaHbkFBNX:matrix.org | Riot Android - 23373 !QtykxKocfZaZOUrTwp:matrix.org | Matrix HQ - 39504 !gTQfWzbYncrtNrvEkB:matrix.org | ru.[matrix] - 43601 !iNmaIQExDMeqdITdHH:matrix.org | Riot - 43601 !iNmaIQExDMeqdITdHH:matrix.org | Riot Web/Desktop -``` - -## Lookup room state info by list of room_id -```sql -SELECT rss.room_id, rss.name, rss.canonical_alias, rss.topic, rss.encryption, rsc.joined_members, rsc.local_users_in_room, rss.join_rules -FROM room_stats_state rss -LEFT JOIN room_stats_current rsc USING (room_id) -WHERE room_id IN ( - '!OGEhHVWSdvArJzumhm:matrix.org', - '!YTvKGNlinIzlkMTVRl:matrix.org' -); -``` \ No newline at end of file diff --git a/docs/workers.md b/docs/workers.md index fd83e2ddeb1f..17c8bfeef6e2 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -210,7 +210,7 @@ expressions: ^/_matrix/federation/v1/get_groups_publicised$ ^/_matrix/key/v2/query ^/_matrix/federation/unstable/org.matrix.msc2946/spaces/ - ^/_matrix/federation/(v1|unstable/org.matrix.msc2946)/hierarchy/ + ^/_matrix/federation/unstable/org.matrix.msc2946/hierarchy/ # Inbound federation transaction request ^/_matrix/federation/v1/send/ @@ -223,7 +223,7 @@ expressions: ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/members$ ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/state$ ^/_matrix/client/unstable/org.matrix.msc2946/rooms/.*/spaces$ - ^/_matrix/client/(v1|unstable/org.matrix.msc2946)/rooms/.*/hierarchy$ + ^/_matrix/client/unstable/org.matrix.msc2946/rooms/.*/hierarchy$ ^/_matrix/client/unstable/im.nheko.summary/rooms/.*/summary$ ^/_matrix/client/(api/v1|r0|v3|unstable)/account/3pid$ ^/_matrix/client/(api/v1|r0|v3|unstable)/devices$ diff --git a/mypy.ini b/mypy.ini index 1caf807e8505..bc4f59154d9e 100644 --- a/mypy.ini +++ b/mypy.ini @@ -33,6 +33,7 @@ exclude = (?x) |synapse/storage/databases/main/event_federation.py |synapse/storage/databases/main/event_push_actions.py |synapse/storage/databases/main/events_bg_updates.py +
|synapse/storage/databases/main/events_worker.py |synapse/storage/databases/main/group_server.py |synapse/storage/databases/main/metrics.py |synapse/storage/databases/main/monthly_active_users.py @@ -86,6 +87,9 @@ exclude = (?x) |tests/push/test_presentable_names.py |tests/push/test_push_rule_evaluator.py |tests/rest/admin/test_admin.py + |tests/rest/admin/test_device.py + |tests/rest/admin/test_media.py + |tests/rest/admin/test_server_notice.py |tests/rest/admin/test_user.py |tests/rest/admin/test_username_available.py |tests/rest/client/test_account.py @@ -108,6 +112,7 @@ exclude = (?x) |tests/server_notices/test_resource_limits_server_notices.py |tests/state/test_v2.py |tests/storage/test_account_data.py + |tests/storage/test_appservice.py |tests/storage/test_background_update.py |tests/storage/test_base.py |tests/storage/test_client_ips.py @@ -120,6 +125,7 @@ exclude = (?x) |tests/test_server.py |tests/test_state.py |tests/test_terms_auth.py + |tests/test_visibility.py |tests/unittest.py |tests/util/caches/test_cached_call.py |tests/util/caches/test_deferred_cache.py @@ -154,21 +160,12 @@ disallow_untyped_defs = True [mypy-synapse.events.*] disallow_untyped_defs = True -[mypy-synapse.federation.*] -disallow_untyped_defs = True - -[mypy-synapse.federation.transport.client] -disallow_untyped_defs = False - [mypy-synapse.handlers.*] disallow_untyped_defs = True [mypy-synapse.metrics.*] disallow_untyped_defs = True -[mypy-synapse.module_api.*] -disallow_untyped_defs = True - [mypy-synapse.push.*] disallow_untyped_defs = True @@ -187,9 +184,6 @@ disallow_untyped_defs = True [mypy-synapse.storage.databases.main.directory] disallow_untyped_defs = True -[mypy-synapse.storage.databases.main.events_worker] -disallow_untyped_defs = True - [mypy-synapse.storage.databases.main.room_batch] disallow_untyped_defs = True @@ -226,10 +220,6 @@ disallow_untyped_defs = True [mypy-tests.rest.client.test_directory] disallow_untyped_defs = True -[mypy-tests.federation.transport.test_client] -disallow_untyped_defs = True - - ;; Dependencies without annotations ;; Before ignoring a module, check to see if type stubs are available. ;; The `typeshed` project maintains stubs here: diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index 53295b58fca9..29568eded849 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -65,4 +65,4 @@ if [[ -n "$1" ]]; then fi # Run the tests! -go test -v -tags synapse_blacklist,msc2403 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests/... +go test -v -tags synapse_blacklist,msc2946,msc3083,msc2403 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests/... diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py index c72e19f61d62..6f76c08fcff2 100755 --- a/scripts-dev/federation_client.py +++ b/scripts-dev/federation_client.py @@ -15,25 +15,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - -""" -Script for signing and sending federation requests. - -Some tips on doing the join dance with this: - - room_id=... - user_id=... 
- - # make_join - federation_client.py "/_matrix/federation/v1/make_join/$room_id/$user_id?ver=5" > make_join.json - - # sign - jq -M .event make_join.json | sign_json --sign-event-room-version=$(jq -r .room_version make_join.json) -o signed-join.json - - # send_join - federation_client.py -X PUT "/_matrix/federation/v2/send_join/$room_id/x" --body $(<signed-join.json) > send_join.json -""" - import argparse import base64 import json diff --git a/scripts-dev/sign_json b/scripts-dev/sign_json index 945954310610..6ac55ef2f704 100755 --- a/scripts-dev/sign_json +++ b/scripts-dev/sign_json @@ -22,8 +22,6 @@ import yaml from signedjson.key import read_signing_keys from signedjson.sign import sign_json -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS -from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.util import json_encoder @@ -70,16 +68,6 @@ Example usage: ), ) - parser.add_argument( - "--sign-event-room-version", - type=str, - help=( - "Sign the JSON as an event for the given room version, rather than raw JSON. " - "This means that we will add a 'hashes' object, and redact the event before " - "signing." - ), - ) - input_args = parser.add_mutually_exclusive_group() input_args.add_argument("input_data", nargs="?", help="Raw JSON to be signed.") @@ -128,17 +116,7 @@ Example usage: print("Input json was not an object", file=sys.stderr) sys.exit(1) - if args.sign_event_room_version: - room_version = KNOWN_ROOM_VERSIONS.get(args.sign_event_room_version) - if not room_version: - print( - f"Unknown room version {args.sign_event_room_version}", file=sys.stderr - ) - sys.exit(1) - add_hashes_and_signatures(room_version, obj, args.server_name, keys[0]) - else: - sign_json(obj, args.server_name, keys[0]) - + sign_json(obj, args.server_name, keys[0]) for c in json_encoder.iterencode(obj): args.output.write(c) args.output.write("\n") diff --git a/setup.py b/setup.py index 2c6fb9aacb45..0ce8beb004a8 100755 --- a/setup.py +++ b/setup.py @@ -119,9 +119,7 @@ def exec_file(path_segments): # Tests assume that all optional dependencies are installed.
# # parameterized_class decorator was introduced in parameterized 0.7.0 -# -# We use `mock` library as that backports `AsyncMock` to Python 3.6 -CONDITIONAL_REQUIREMENTS["test"] = ["parameterized>=0.7.0", "mock>=4.0.0"] +CONDITIONAL_REQUIREMENTS["test"] = ["parameterized>=0.7.0"] CONDITIONAL_REQUIREMENTS["dev"] = ( CONDITIONAL_REQUIREMENTS["lint"] @@ -152,12 +150,6 @@ def exec_file(path_segments): long_description=long_description, long_description_content_type="text/x-rst", python_requires="~=3.6", - entry_points={ - "console_scripts": [ - "synapse_homeserver = synapse.app.homeserver:main", - "synapse_worker = synapse.app.generic_worker:main", - ] - }, classifiers=[ "Development Status :: 5 - Production/Stable", "Topic :: Communications :: Chat", diff --git a/synapse/__init__.py b/synapse/__init__.py index 6369f18a535a..3cd1ce6070f4 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -47,7 +47,7 @@ except ImportError: pass -__version__ = "1.49.0rc1" +__version__ = "1.48.0" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when diff --git a/synapse/api/constants.py b/synapse/api/constants.py index f7d29b431936..a33ac341614a 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -17,8 +17,6 @@ """Contains constants from the specification.""" -from typing_extensions import Final - # the max size of a (canonical-json-encoded) event MAX_PDU_SIZE = 65536 @@ -41,125 +39,125 @@ class Membership: """Represents the membership states of a user in a room.""" - INVITE: Final = "invite" - JOIN: Final = "join" - KNOCK: Final = "knock" - LEAVE: Final = "leave" - BAN: Final = "ban" - LIST: Final = (INVITE, JOIN, KNOCK, LEAVE, BAN) + INVITE = "invite" + JOIN = "join" + KNOCK = "knock" + LEAVE = "leave" + BAN = "ban" + LIST = (INVITE, JOIN, KNOCK, LEAVE, BAN) class PresenceState: """Represents the presence state of a user.""" - OFFLINE: Final = "offline" - UNAVAILABLE: Final = "unavailable" - ONLINE: Final = "online" - BUSY: Final = "org.matrix.msc3026.busy" + OFFLINE = "offline" + UNAVAILABLE = "unavailable" + ONLINE = "online" + BUSY = "org.matrix.msc3026.busy" class JoinRules: - PUBLIC: Final = "public" - KNOCK: Final = "knock" - INVITE: Final = "invite" - PRIVATE: Final = "private" + PUBLIC = "public" + KNOCK = "knock" + INVITE = "invite" + PRIVATE = "private" # As defined for MSC3083. - RESTRICTED: Final = "restricted" + RESTRICTED = "restricted" class RestrictedJoinRuleTypes: """Understood types for the allow rules in restricted join rules.""" - ROOM_MEMBERSHIP: Final = "m.room_membership" + ROOM_MEMBERSHIP = "m.room_membership" class LoginType: - PASSWORD: Final = "m.login.password" - EMAIL_IDENTITY: Final = "m.login.email.identity" - MSISDN: Final = "m.login.msisdn" - RECAPTCHA: Final = "m.login.recaptcha" - TERMS: Final = "m.login.terms" - SSO: Final = "m.login.sso" - DUMMY: Final = "m.login.dummy" - REGISTRATION_TOKEN: Final = "org.matrix.msc3231.login.registration_token" + PASSWORD = "m.login.password" + EMAIL_IDENTITY = "m.login.email.identity" + MSISDN = "m.login.msisdn" + RECAPTCHA = "m.login.recaptcha" + TERMS = "m.login.terms" + SSO = "m.login.sso" + DUMMY = "m.login.dummy" + REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token" # This is used in the `type` parameter for /register when called by # an appservice to register a new user. 
-APP_SERVICE_REGISTRATION_TYPE: Final = "m.login.application_service" +APP_SERVICE_REGISTRATION_TYPE = "m.login.application_service" class EventTypes: - Member: Final = "m.room.member" - Create: Final = "m.room.create" - Tombstone: Final = "m.room.tombstone" - JoinRules: Final = "m.room.join_rules" - PowerLevels: Final = "m.room.power_levels" - Aliases: Final = "m.room.aliases" - Redaction: Final = "m.room.redaction" - ThirdPartyInvite: Final = "m.room.third_party_invite" - RelatedGroups: Final = "m.room.related_groups" - - RoomHistoryVisibility: Final = "m.room.history_visibility" - CanonicalAlias: Final = "m.room.canonical_alias" - Encrypted: Final = "m.room.encrypted" - RoomAvatar: Final = "m.room.avatar" - RoomEncryption: Final = "m.room.encryption" - GuestAccess: Final = "m.room.guest_access" + Member = "m.room.member" + Create = "m.room.create" + Tombstone = "m.room.tombstone" + JoinRules = "m.room.join_rules" + PowerLevels = "m.room.power_levels" + Aliases = "m.room.aliases" + Redaction = "m.room.redaction" + ThirdPartyInvite = "m.room.third_party_invite" + RelatedGroups = "m.room.related_groups" + + RoomHistoryVisibility = "m.room.history_visibility" + CanonicalAlias = "m.room.canonical_alias" + Encrypted = "m.room.encrypted" + RoomAvatar = "m.room.avatar" + RoomEncryption = "m.room.encryption" + GuestAccess = "m.room.guest_access" # These are used for validation - Message: Final = "m.room.message" - Topic: Final = "m.room.topic" - Name: Final = "m.room.name" + Message = "m.room.message" + Topic = "m.room.topic" + Name = "m.room.name" - ServerACL: Final = "m.room.server_acl" - Pinned: Final = "m.room.pinned_events" + ServerACL = "m.room.server_acl" + Pinned = "m.room.pinned_events" - Retention: Final = "m.room.retention" + Retention = "m.room.retention" - Dummy: Final = "org.matrix.dummy_event" + Dummy = "org.matrix.dummy_event" - SpaceChild: Final = "m.space.child" - SpaceParent: Final = "m.space.parent" + SpaceChild = "m.space.child" + SpaceParent = "m.space.parent" - MSC2716_INSERTION: Final = "org.matrix.msc2716.insertion" - MSC2716_BATCH: Final = "org.matrix.msc2716.batch" - MSC2716_MARKER: Final = "org.matrix.msc2716.marker" + MSC2716_INSERTION = "org.matrix.msc2716.insertion" + MSC2716_BATCH = "org.matrix.msc2716.batch" + MSC2716_MARKER = "org.matrix.msc2716.marker" class ToDeviceEventTypes: - RoomKeyRequest: Final = "m.room_key_request" + RoomKeyRequest = "m.room_key_request" class DeviceKeyAlgorithms: """Spec'd algorithms for the generation of per-device keys""" - ED25519: Final = "ed25519" - CURVE25519: Final = "curve25519" - SIGNED_CURVE25519: Final = "signed_curve25519" + ED25519 = "ed25519" + CURVE25519 = "curve25519" + SIGNED_CURVE25519 = "signed_curve25519" class EduTypes: - Presence: Final = "m.presence" + Presence = "m.presence" class RejectedReason: - AUTH_ERROR: Final = "auth_error" + AUTH_ERROR = "auth_error" class RoomCreationPreset: - PRIVATE_CHAT: Final = "private_chat" - PUBLIC_CHAT: Final = "public_chat" - TRUSTED_PRIVATE_CHAT: Final = "trusted_private_chat" + PRIVATE_CHAT = "private_chat" + PUBLIC_CHAT = "public_chat" + TRUSTED_PRIVATE_CHAT = "trusted_private_chat" class ThirdPartyEntityKind: - USER: Final = "user" - LOCATION: Final = "location" + USER = "user" + LOCATION = "location" -ServerNoticeMsgType: Final = "m.server_notice" -ServerNoticeLimitReached: Final = "m.server_notice.usage_limit_reached" +ServerNoticeMsgType = "m.server_notice" +ServerNoticeLimitReached = "m.server_notice.usage_limit_reached" class UserTypes: @@ -167,91 +165,91 @@ class 
UserTypes: 'admin' and 'guest' users should also be UserTypes. Normal users are type None """ - SUPPORT: Final = "support" - BOT: Final = "bot" - ALL_USER_TYPES: Final = (SUPPORT, BOT) + SUPPORT = "support" + BOT = "bot" + ALL_USER_TYPES = (SUPPORT, BOT) class RelationTypes: """The types of relations known to this server.""" - ANNOTATION: Final = "m.annotation" - REPLACE: Final = "m.replace" - REFERENCE: Final = "m.reference" - THREAD: Final = "io.element.thread" + ANNOTATION = "m.annotation" + REPLACE = "m.replace" + REFERENCE = "m.reference" + THREAD = "io.element.thread" class LimitBlockingTypes: """Reasons that a server may be blocked""" - MONTHLY_ACTIVE_USER: Final = "monthly_active_user" - HS_DISABLED: Final = "hs_disabled" + MONTHLY_ACTIVE_USER = "monthly_active_user" + HS_DISABLED = "hs_disabled" class EventContentFields: """Fields found in events' content, regardless of type.""" # Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326 - LABELS: Final = "org.matrix.labels" + LABELS = "org.matrix.labels" # Timestamp to delete the event after # cf https://github.com/matrix-org/matrix-doc/pull/2228 - SELF_DESTRUCT_AFTER: Final = "org.matrix.self_destruct_after" + SELF_DESTRUCT_AFTER = "org.matrix.self_destruct_after" # cf https://github.com/matrix-org/matrix-doc/pull/1772 - ROOM_TYPE: Final = "type" + ROOM_TYPE = "type" # Whether a room can federate. - FEDERATE: Final = "m.federate" + FEDERATE = "m.federate" # The creator of the room, as used in `m.room.create` events. - ROOM_CREATOR: Final = "creator" + ROOM_CREATOR = "creator" # Used in m.room.guest_access events. - GUEST_ACCESS: Final = "guest_access" + GUEST_ACCESS = "guest_access" # Used on normal messages to indicate they were historically imported after the fact - MSC2716_HISTORICAL: Final = "org.matrix.msc2716.historical" + MSC2716_HISTORICAL = "org.matrix.msc2716.historical" # For "insertion" events to indicate what the next batch ID should be in # order to connect to it - MSC2716_NEXT_BATCH_ID: Final = "org.matrix.msc2716.next_batch_id" + MSC2716_NEXT_BATCH_ID = "org.matrix.msc2716.next_batch_id" # Used on "batch" events to indicate which insertion event it connects to - MSC2716_BATCH_ID: Final = "org.matrix.msc2716.batch_id" + MSC2716_BATCH_ID = "org.matrix.msc2716.batch_id" # For "marker" events - MSC2716_MARKER_INSERTION: Final = "org.matrix.msc2716.marker.insertion" + MSC2716_MARKER_INSERTION = "org.matrix.msc2716.marker.insertion" # The authorising user for joining a restricted room. 
- AUTHORISING_USER: Final = "join_authorised_via_users_server" + AUTHORISING_USER = "join_authorised_via_users_server" class RoomTypes: """Understood values of the room_type field of m.room.create events.""" - SPACE: Final = "m.space" + SPACE = "m.space" class RoomEncryptionAlgorithms: - MEGOLM_V1_AES_SHA2: Final = "m.megolm.v1.aes-sha2" - DEFAULT: Final = MEGOLM_V1_AES_SHA2 + MEGOLM_V1_AES_SHA2 = "m.megolm.v1.aes-sha2" + DEFAULT = MEGOLM_V1_AES_SHA2 class AccountDataTypes: - DIRECT: Final = "m.direct" - IGNORED_USER_LIST: Final = "m.ignored_user_list" + DIRECT = "m.direct" + IGNORED_USER_LIST = "m.ignored_user_list" class HistoryVisibility: - INVITED: Final = "invited" - JOINED: Final = "joined" - SHARED: Final = "shared" - WORLD_READABLE: Final = "world_readable" + INVITED = "invited" + JOINED = "joined" + SHARED = "shared" + WORLD_READABLE = "world_readable" class GuestAccess: - CAN_JOIN: Final = "can_join" + CAN_JOIN = "can_join" # anything that is not "can_join" is considered "forbidden", but for completeness: - FORBIDDEN: Final = "forbidden" + FORBIDDEN = "forbidden" class ReadReceiptEventFields: - MSC2285_HIDDEN: Final = "org.matrix.msc2285.hidden" + MSC2285_HIDDEN = "org.matrix.msc2285.hidden" diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 5fc59c1be11d..807ee3d46ef1 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -32,7 +32,6 @@ Iterable, List, NoReturn, - Optional, Tuple, cast, ) @@ -130,7 +129,7 @@ def start_worker_reactor( def start_reactor( appname: str, soft_file_limit: int, - gc_thresholds: Optional[Tuple[int, int, int]], + gc_thresholds: Tuple[int, int, int], pid_file: str, daemonize: bool, print_pidfile: bool, diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index e256de200355..502cc8e8d1e7 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -113,7 +113,6 @@ ) from synapse.storage.databases.main.presence import PresenceStore from synapse.storage.databases.main.room import RoomWorkerStore -from synapse.storage.databases.main.room_batch import RoomBatchStore from synapse.storage.databases.main.search import SearchStore from synapse.storage.databases.main.session import SessionStore from synapse.storage.databases.main.stats import StatsStore @@ -241,7 +240,6 @@ class GenericWorkerSlavedStore( SlavedEventStore, SlavedKeyStore, RoomWorkerStore, - RoomBatchStore, DirectoryStore, SlavedApplicationServiceStore, SlavedRegistrationStore, @@ -505,10 +503,6 @@ def start(config_options: List[str]) -> None: _base.start_worker_reactor("synapse-generic-worker", config) -def main() -> None: +if __name__ == "__main__": with LoggingContext("main"): start(sys.argv[1:]) - - -if __name__ == "__main__": - main() diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index dd76e0732108..7e09530ad23c 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -194,7 +194,6 @@ def _configure_named_resource( { "/_matrix/client/api/v1": client_resource, "/_matrix/client/r0": client_resource, - "/_matrix/client/v1": client_resource, "/_matrix/client/v3": client_resource, "/_matrix/client/unstable": client_resource, "/_matrix/client/v2_alpha": client_resource, @@ -358,13 +357,6 @@ def setup(config_options: List[str]) -> SynapseHomeServer: # generating config files and shouldn't try to continue. sys.exit(0) - if config.worker.worker_app: - raise ConfigError( - "You have specified `worker_app` in the config but are attempting to start a non-worker " - "instance. 
Please use `python -m synapse.app.generic_worker` instead (or remove the option if this is the main process)." - ) - sys.exit(1) - events.USE_FROZEN_DICTS = config.server.use_frozen_dicts synapse.util.caches.TRACK_MEMORY_USAGE = config.caches.track_memory_usage diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index f9d3bd337d3b..6504c6bd3f59 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -13,7 +13,6 @@ # limitations under the License. import logging import re -from enum import Enum from typing import TYPE_CHECKING, Iterable, List, Match, Optional from synapse.api.constants import EventTypes @@ -28,7 +27,7 @@ logger = logging.getLogger(__name__) -class ApplicationServiceState(Enum): +class ApplicationServiceState: DOWN = "down" UP = "up" diff --git a/synapse/config/__main__.py b/synapse/config/__main__.py index b2a7a89a3563..c555f5f91407 100644 --- a/synapse/config/__main__.py +++ b/synapse/config/__main__.py @@ -13,13 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import sys -from typing import List from synapse.config._base import ConfigError from synapse.config.homeserver import HomeServerConfig -def main(args: List[str]) -> None: +def main(args): action = args[1] if len(args) > 1 and args[1] == "read" else None # If we're reading a key in the config file, then `args[1]` will be `read` and `args[2]` # will be the key to read. diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index e4bb7224a410..1ebea88db2a2 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -1,5 +1,4 @@ # Copyright 2015, 2016 OpenMarket Ltd -# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,14 +13,14 @@ # limitations under the License. 
import logging -from typing import Dict, List +from typing import Dict from urllib import parse as urlparse import yaml from netaddr import IPSet from synapse.appservice import ApplicationService -from synapse.types import JsonDict, UserID +from synapse.types import UserID from ._base import Config, ConfigError @@ -31,12 +30,12 @@ class AppServiceConfig(Config): section = "appservice" - def read_config(self, config, **kwargs) -> None: + def read_config(self, config, **kwargs): self.app_service_config_files = config.get("app_service_config_files", []) self.notify_appservices = config.get("notify_appservices", True) self.track_appservice_user_ips = config.get("track_appservice_user_ips", False) - def generate_config_section(cls, **kwargs) -> str: + def generate_config_section(cls, **kwargs): return """\ # A list of application service config files to use # @@ -51,9 +50,7 @@ def generate_config_section(cls, **kwargs) -> str: """ -def load_appservices( - hostname: str, config_files: List[str] -) -> List[ApplicationService]: +def load_appservices(hostname, config_files): """Returns a list of Application Services from the config files.""" if not isinstance(config_files, list): logger.warning("Expected %s to be a list of AS config files.", config_files) @@ -96,9 +93,7 @@ def load_appservices( return appservices -def _load_appservice( - hostname: str, as_info: JsonDict, config_filename: str -) -> ApplicationService: +def _load_appservice(hostname, as_info, config_filename): required_string_fields = ["id", "as_token", "hs_token", "sender_localpart"] for field in required_string_fields: if not isinstance(as_info.get(field), str): @@ -120,9 +115,9 @@ def _load_appservice( user_id = user.to_string() # Rate limiting for users of this AS is on by default (excludes sender) - rate_limited = as_info.get("rate_limited") - if not isinstance(rate_limited, bool): - rate_limited = True + rate_limited = True + if isinstance(as_info.get("rate_limited"), bool): + rate_limited = as_info.get("rate_limited") # namespace checks if not isinstance(as_info.get("namespaces"), dict): diff --git a/synapse/config/cache.py b/synapse/config/cache.py index d9d85f98e155..f05445553473 100644 --- a/synapse/config/cache.py +++ b/synapse/config/cache.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 Matrix.org Foundation C.I.C. +# Copyright 2019 Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -17,8 +17,6 @@ import threading from typing import Callable, Dict, Optional -import attr - from synapse.python_dependencies import DependencyException, check_requirements from ._base import Config, ConfigError @@ -36,13 +34,13 @@ _DEFAULT_EVENT_CACHE_SIZE = "10K" -@attr.s(slots=True, auto_attribs=True) class CacheProperties: - # The default factor size for all caches - default_factor_size: float = float( - os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE) - ) - resize_all_caches_func: Optional[Callable[[], None]] = None + def __init__(self): + # The default factor size for all caches + self.default_factor_size = float( + os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE) + ) + self.resize_all_caches_func = None properties = CacheProperties() @@ -64,7 +62,7 @@ def _canonicalise_cache_name(cache_name: str) -> str: def add_resizable_cache( cache_name: str, cache_resize_callback: Callable[[float], None] -) -> None: +): """Register a cache that's size can dynamically change Args: @@ -93,7 +91,7 @@ class CacheConfig(Config): _environ = os.environ @staticmethod - def reset() -> None: + def reset(): """Resets the caches to their defaults. Used for tests.""" properties.default_factor_size = float( os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE) @@ -102,7 +100,7 @@ def reset() -> None: with _CACHES_LOCK: _CACHES.clear() - def generate_config_section(self, **kwargs) -> str: + def generate_config_section(self, **kwargs): return """\ ## Caching ## @@ -164,7 +162,7 @@ def generate_config_section(self, **kwargs) -> str: #sync_response_cache_duration: 2m """ - def read_config(self, config, **kwargs) -> None: + def read_config(self, config, **kwargs): self.event_cache_size = self.parse_size( config.get("event_cache_size", _DEFAULT_EVENT_CACHE_SIZE) ) @@ -234,7 +232,7 @@ def read_config(self, config, **kwargs) -> None: # needing an instance of Config properties.resize_all_caches_func = self.resize_all_caches - def resize_all_caches(self) -> None: + def resize_all_caches(self): """Ensure all cache sizes are up to date For each cache, run the mapped callback function with either diff --git a/synapse/config/cas.py b/synapse/config/cas.py index 6f2754092e76..3f818140432f 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -1,5 +1,4 @@ # Copyright 2015, 2016 OpenMarket Ltd -# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -29,7 +28,7 @@ class CasConfig(Config): section = "cas" - def read_config(self, config, **kwargs) -> None: + def read_config(self, config, **kwargs): cas_config = config.get("cas_config", None) self.cas_enabled = cas_config and cas_config.get("enabled", True) @@ -52,7 +51,7 @@ def read_config(self, config, **kwargs) -> None: self.cas_displayname_attribute = None self.cas_required_attributes = [] - def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str: + def generate_config_section(self, config_dir_path, server_name, **kwargs): return """\ # Enable Central Authentication Service (CAS) for registration and login. # diff --git a/synapse/config/database.py b/synapse/config/database.py index 06ccf15cd9f8..651e31b57621 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -1,5 +1,5 @@ # Copyright 2014-2016 OpenMarket Ltd -# Copyright 2020-2021 The Matrix.org Foundation C.I.C. +# Copyright 2020 The Matrix.org Foundation C.I.C. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import argparse import logging import os @@ -120,7 +119,7 @@ def __init__(self, *args, **kwargs): self.databases = [] - def read_config(self, config, **kwargs) -> None: + def read_config(self, config, **kwargs): # We *experimentally* support specifying multiple databases via the # `databases` key. This is a map from a label to database config in the # same format as the `database` config option, plus an extra @@ -164,12 +163,12 @@ def read_config(self, config, **kwargs) -> None: self.databases = [DatabaseConnectionConfig("master", database_config)] self.set_databasepath(database_path) - def generate_config_section(self, data_dir_path, **kwargs) -> str: + def generate_config_section(self, data_dir_path, **kwargs): return DEFAULT_CONFIG % { "database_path": os.path.join(data_dir_path, "homeserver.db") } - def read_arguments(self, args: argparse.Namespace) -> None: + def read_arguments(self, args): """ Cases for the cli input: - If no databases are configured and no database_path is set, raise. @@ -195,7 +194,7 @@ def read_arguments(self, args: argparse.Namespace) -> None: else: logger.warning(NON_SQLITE_DATABASE_PATH_WARNING) - def set_databasepath(self, database_path: str) -> None: + def set_databasepath(self, database_path): if database_path != ":memory:": database_path = self.abspath(database_path) @@ -203,7 +202,7 @@ def set_databasepath(self, database_path: str) -> None: self.databases[0].config["args"]["database"] = database_path @staticmethod - def add_arguments(parser: argparse.ArgumentParser) -> None: + def add_arguments(parser): db_group = parser.add_argument_group("database") db_group.add_argument( "-d", diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index d78a15097c87..8b098ad48d56 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -46,6 +46,3 @@ def read_config(self, config: JsonDict, **kwargs): # MSC3266 (room summary api) self.msc3266_enabled: bool = experimental.get("msc3266_enabled", False) - - # MSC3030 (Jump to date API endpoint) - self.msc3030_enabled: bool = experimental.get("msc3030_enabled", False) diff --git a/synapse/config/logger.py b/synapse/config/logger.py index ea69b9bd9b50..63aab0babe66 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -1,5 +1,4 @@ # Copyright 2014-2016 OpenMarket Ltd -# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -19,7 +18,7 @@ import sys import threading from string import Template -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict import yaml from zope.interface import implementer @@ -41,7 +40,6 @@ from ._base import Config, ConfigError if TYPE_CHECKING: - from synapse.config.homeserver import HomeServerConfig from synapse.server import HomeServer DEFAULT_LOG_CONFIG = Template( @@ -143,13 +141,13 @@ class LoggingConfig(Config): section = "logging" - def read_config(self, config, **kwargs) -> None: + def read_config(self, config, **kwargs): if config.get("log_file"): raise ConfigError(LOG_FILE_ERROR) self.log_config = self.abspath(config.get("log_config")) self.no_redirect_stdio = config.get("no_redirect_stdio", False) - def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str: + def generate_config_section(self, config_dir_path, server_name, **kwargs): log_config = os.path.join(config_dir_path, server_name + ".log.config") return ( """\ @@ -163,14 +161,14 @@ def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str % locals() ) - def read_arguments(self, args: argparse.Namespace) -> None: + def read_arguments(self, args): if args.no_redirect_stdio is not None: self.no_redirect_stdio = args.no_redirect_stdio if args.log_file is not None: raise ConfigError(LOG_FILE_ERROR) @staticmethod - def add_arguments(parser: argparse.ArgumentParser) -> None: + def add_arguments(parser): logging_group = parser.add_argument_group("logging") logging_group.add_argument( "-n", @@ -199,9 +197,7 @@ def generate_files(self, config: Dict[str, Any], config_dir_path: str) -> None: log_config_file.write(DEFAULT_LOG_CONFIG.substitute(log_file=log_file)) -def _setup_stdlib_logging( - config: "HomeServerConfig", log_config_path: Optional[str], logBeginner: LogBeginner -) -> None: +def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) -> None: """ Set up Python standard library logging. """ @@ -234,7 +230,7 @@ def _setup_stdlib_logging( log_metadata_filter = MetadataFilter({"server_name": config.server.server_name}) old_factory = logging.getLogRecordFactory() - def factory(*args: Any, **kwargs: Any) -> logging.LogRecord: + def factory(*args, **kwargs): record = old_factory(*args, **kwargs) log_context_filter.filter(record) log_metadata_filter.filter(record) @@ -301,7 +297,7 @@ def _load_logging_config(log_config_path: str) -> None: logging.config.dictConfig(log_config) -def _reload_logging_config(log_config_path: Optional[str]) -> None: +def _reload_logging_config(log_config_path): """ Reload the log configuration from the file and apply it. """ @@ -315,8 +311,8 @@ def _reload_logging_config(log_config_path: Optional[str]) -> None: def setup_logging( hs: "HomeServer", - config: "HomeServerConfig", - use_worker_options: bool = False, + config, + use_worker_options=False, logBeginner: LogBeginner = globalLogBeginner, ) -> None: """ diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py index 79c400fe30b8..42f113cd249d 100644 --- a/synapse/config/oidc.py +++ b/synapse/config/oidc.py @@ -14,7 +14,7 @@ # limitations under the License. 
from collections import Counter -from typing import Any, Collection, Iterable, List, Mapping, Optional, Tuple, Type +from typing import Collection, Iterable, List, Mapping, Optional, Tuple, Type import attr @@ -36,7 +36,7 @@ class OIDCConfig(Config): section = "oidc" - def read_config(self, config, **kwargs) -> None: + def read_config(self, config, **kwargs): self.oidc_providers = tuple(_parse_oidc_provider_configs(config)) if not self.oidc_providers: return @@ -66,7 +66,7 @@ def oidc_enabled(self) -> bool: # OIDC is enabled if we have a provider return bool(self.oidc_providers) - def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str: + def generate_config_section(self, config_dir_path, server_name, **kwargs): return """\ # List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration # and login. @@ -495,89 +495,89 @@ def _parse_oidc_config_dict( ) -@attr.s(slots=True, frozen=True, auto_attribs=True) +@attr.s(slots=True, frozen=True) class OidcProviderClientSecretJwtKey: # a pem-encoded signing key - key: str + key = attr.ib(type=str) # properties to include in the JWT header - jwt_header: Mapping[str, str] + jwt_header = attr.ib(type=Mapping[str, str]) # properties to include in the JWT payload. - jwt_payload: Mapping[str, str] + jwt_payload = attr.ib(type=Mapping[str, str]) -@attr.s(slots=True, frozen=True, auto_attribs=True) +@attr.s(slots=True, frozen=True) class OidcProviderConfig: # a unique identifier for this identity provider. Used in the 'user_external_ids' # table, as well as the query/path parameter used in the login protocol. - idp_id: str + idp_id = attr.ib(type=str) # user-facing name for this identity provider. - idp_name: str + idp_name = attr.ib(type=str) # Optional MXC URI for icon for this IdP. - idp_icon: Optional[str] + idp_icon = attr.ib(type=Optional[str]) # Optional brand identifier for this IdP. - idp_brand: Optional[str] + idp_brand = attr.ib(type=Optional[str]) # whether the OIDC discovery mechanism is used to discover endpoints - discover: bool + discover = attr.ib(type=bool) # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to # discover the provider's endpoints. - issuer: str + issuer = attr.ib(type=str) # oauth2 client id to use - client_id: str + client_id = attr.ib(type=str) # oauth2 client secret to use. if `None`, use client_secret_jwt_key to generate # a secret. - client_secret: Optional[str] + client_secret = attr.ib(type=Optional[str]) # key to use to construct a JWT to use as a client secret. May be `None` if # `client_secret` is set. - client_secret_jwt_key: Optional[OidcProviderClientSecretJwtKey] + client_secret_jwt_key = attr.ib(type=Optional[OidcProviderClientSecretJwtKey]) # auth method to use when exchanging the token. # Valid values are 'client_secret_basic', 'client_secret_post' and # 'none'. - client_auth_method: str + client_auth_method = attr.ib(type=str) # list of scopes to request - scopes: Collection[str] + scopes = attr.ib(type=Collection[str]) # the oauth2 authorization endpoint. Required if discovery is disabled. - authorization_endpoint: Optional[str] + authorization_endpoint = attr.ib(type=Optional[str]) # the oauth2 token endpoint. Required if discovery is disabled. - token_endpoint: Optional[str] + token_endpoint = attr.ib(type=Optional[str]) # the OIDC userinfo endpoint. Required if discovery is disabled and the # "openid" scope is not requested. - userinfo_endpoint: Optional[str] + userinfo_endpoint = attr.ib(type=Optional[str]) # URI where to fetch the JWKS. 
Required if discovery is disabled and the # "openid" scope is used. - jwks_uri: Optional[str] + jwks_uri = attr.ib(type=Optional[str]) # Whether to skip metadata verification - skip_verification: bool + skip_verification = attr.ib(type=bool) # Whether to fetch the user profile from the userinfo endpoint. Valid # values are: "auto" or "userinfo_endpoint". - user_profile_method: str + user_profile_method = attr.ib(type=str) # whether to allow a user logging in via OIDC to match a pre-existing account # instead of failing - allow_existing_users: bool + allow_existing_users = attr.ib(type=bool) # the class of the user mapping provider - user_mapping_provider_class: Type + user_mapping_provider_class = attr.ib(type=Type) # the config of the user mapping provider - user_mapping_provider_config: Any + user_mapping_provider_config = attr.ib() # required attributes to require in userinfo to allow login/registration - attribute_requirements: List[SsoAttributeRequirement] + attribute_requirements = attr.ib(type=List[SsoAttributeRequirement]) diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 7a059c6decc0..61e569d412e5 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -1,5 +1,4 @@ # Copyright 2015, 2016 OpenMarket Ltd -# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import argparse -from typing import Optional from synapse.api.constants import RoomCreationPreset from synapse.config._base import Config, ConfigError @@ -116,73 +113,32 @@ def read_config(self, config, **kwargs): self.session_lifetime = session_lifetime # The `refreshable_access_token_lifetime` applies for tokens that can be renewed - # using a refresh token, as per MSC2918. - # If it is `None`, the refresh token mechanism is disabled. + # using a refresh token, as per MSC2918. If it is `None`, the refresh + # token mechanism is disabled. + # + # Since it is incompatible with the `session_lifetime` mechanism, it is set to + # `None` by default if a `session_lifetime` is set. refreshable_access_token_lifetime = config.get( "refreshable_access_token_lifetime", - "5m", + "5m" if session_lifetime is None else None, ) if refreshable_access_token_lifetime is not None: refreshable_access_token_lifetime = self.parse_duration( refreshable_access_token_lifetime ) - self.refreshable_access_token_lifetime: Optional[ - int - ] = refreshable_access_token_lifetime + self.refreshable_access_token_lifetime = refreshable_access_token_lifetime if ( - self.session_lifetime is not None - and "refreshable_access_token_lifetime" in config + session_lifetime is not None + and refreshable_access_token_lifetime is not None ): - if self.session_lifetime < self.refreshable_access_token_lifetime: - raise ConfigError( - "Both `session_lifetime` and `refreshable_access_token_lifetime` " - "configuration options have been set, but `refreshable_access_token_lifetime` " - " exceeds `session_lifetime`!" - ) - - # The `nonrefreshable_access_token_lifetime` applies for tokens that can NOT be - # refreshed using a refresh token. - # If it is None, then these tokens last for the entire length of the session, - # which is infinite by default. 
- # The intention behind this configuration option is to help with requiring - # all clients to use refresh tokens, if the homeserver administrator requires. - nonrefreshable_access_token_lifetime = config.get( - "nonrefreshable_access_token_lifetime", - None, - ) - if nonrefreshable_access_token_lifetime is not None: - nonrefreshable_access_token_lifetime = self.parse_duration( - nonrefreshable_access_token_lifetime + raise ConfigError( + "The refresh token mechanism is incompatible with the " + "`session_lifetime` option. Consider disabling the " + "`session_lifetime` option or disabling the refresh token " + "mechanism by removing the `refreshable_access_token_lifetime` " + "option." ) - self.nonrefreshable_access_token_lifetime = nonrefreshable_access_token_lifetime - - if ( - self.session_lifetime is not None - and self.nonrefreshable_access_token_lifetime is not None - ): - if self.session_lifetime < self.nonrefreshable_access_token_lifetime: - raise ConfigError( - "Both `session_lifetime` and `nonrefreshable_access_token_lifetime` " - "configuration options have been set, but `nonrefreshable_access_token_lifetime` " - " exceeds `session_lifetime`!" - ) - - refresh_token_lifetime = config.get("refresh_token_lifetime") - if refresh_token_lifetime is not None: - refresh_token_lifetime = self.parse_duration(refresh_token_lifetime) - self.refresh_token_lifetime: Optional[int] = refresh_token_lifetime - - if ( - self.session_lifetime is not None - and self.refresh_token_lifetime is not None - ): - if self.session_lifetime < self.refresh_token_lifetime: - raise ConfigError( - "Both `session_lifetime` and `refresh_token_lifetime` " - "configuration options have been set, but `refresh_token_lifetime` " - " exceeds `session_lifetime`!" - ) # The fallback template used for authenticating using a registration token self.registration_token_template = self.read_template("registration_token.html") @@ -220,44 +176,6 @@ def generate_config_section(self, generate_secrets=False, **kwargs): # #session_lifetime: 24h - # Time that an access token remains valid for, if the session is - # using refresh tokens. - # For more information about refresh tokens, please see the manual. - # Note that this only applies to clients which advertise support for - # refresh tokens. - # - # Note also that this is calculated at login time and refresh time: - # changes are not applied to existing sessions until they are refreshed. - # - # By default, this is 5 minutes. - # - #refreshable_access_token_lifetime: 5m - - # Time that a refresh token remains valid for (provided that it is not - # exchanged for another one first). - # This option can be used to automatically log-out inactive sessions. - # Please see the manual for more information. - # - # Note also that this is calculated at login time and refresh time: - # changes are not applied to existing sessions until they are refreshed. - # - # By default, this is infinite. - # - #refresh_token_lifetime: 24h - - # Time that an access token remains valid for, if the session is NOT - # using refresh tokens. - # Please note that not all clients support refresh tokens, so setting - # this to a short value may be inconvenient for some users who will - # then be logged out frequently. - # - # Note also that this is calculated at login time: changes are not applied - # retrospectively to existing sessions for users that have already logged in. - # - # By default, this is infinite. 
- # - #nonrefreshable_access_token_lifetime: 24h - # The user must provide all of the below types of 3PID when registering. # #registrations_require_3pid: @@ -451,7 +369,7 @@ def generate_config_section(self, generate_secrets=False, **kwargs): ) @staticmethod - def add_arguments(parser: argparse.ArgumentParser) -> None: + def add_arguments(parser): reg_group = parser.add_argument_group("registration") reg_group.add_argument( "--enable-registration", @@ -460,6 +378,6 @@ def add_arguments(parser: argparse.ArgumentParser) -> None: help="Enable registration for new users.", ) - def read_arguments(self, args: argparse.Namespace) -> None: + def read_arguments(self, args): if args.enable_registration is not None: self.enable_registration = strtobool(str(args.enable_registration)) diff --git a/synapse/config/repository.py b/synapse/config/repository.py index b129b9dd681c..69906a98d48a 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -15,12 +15,11 @@ import logging import os from collections import namedtuple -from typing import Dict, List, Tuple +from typing import Dict, List from urllib.request import getproxies_environment # type: ignore from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST, generate_ip_set from synapse.python_dependencies import DependencyException, check_requirements -from synapse.types import JsonDict from synapse.util.module_loader import load_module from ._base import Config, ConfigError @@ -58,9 +57,7 @@ ) -def parse_thumbnail_requirements( - thumbnail_sizes: List[JsonDict], -) -> Dict[str, Tuple[ThumbnailRequirement, ...]]: +def parse_thumbnail_requirements(thumbnail_sizes): """Takes a list of dictionaries with "width", "height", and "method" keys and creates a map from image media types to the thumbnail size, thumbnailing method, and thumbnail media type to precalculate @@ -72,7 +69,7 @@ def parse_thumbnail_requirements( Dictionary mapping from media type string to list of ThumbnailRequirement tuples. """ - requirements: Dict[str, List[ThumbnailRequirement]] = {} + requirements: Dict[str, List] = {} for size in thumbnail_sizes: width = size["width"] height = size["height"] diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py index ec9d9f65e7b5..ba2b0905ffe8 100644 --- a/synapse/config/saml2.py +++ b/synapse/config/saml2.py @@ -1,5 +1,5 @@ # Copyright 2018 New Vector Ltd -# Copyright 2019-2021 The Matrix.org Foundation C.I.C. +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,11 +14,10 @@ # limitations under the License. import logging -from typing import Any, List, Set +from typing import Any, List from synapse.config.sso import SsoAttributeRequirement from synapse.python_dependencies import DependencyException, check_requirements -from synapse.types import JsonDict from synapse.util.module_loader import load_module, load_python_module from ._base import Config, ConfigError @@ -34,7 +33,7 @@ ) -def _dict_merge(merge_dict: dict, into_dict: dict) -> None: +def _dict_merge(merge_dict, into_dict): """Do a deep merge of two dicts Recursively merges `merge_dict` into `into_dict`: @@ -44,8 +43,8 @@ def _dict_merge(merge_dict: dict, into_dict: dict) -> None: the value from `merge_dict`. 
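`_dict_merge` above recurses wherever both sides hold a dict and otherwise lets `merge_dict` win; a runnable sketch of that contract, written independently of the function in the hunk:

def dict_merge(merge_dict: dict, into_dict: dict) -> None:
    # deep-merge merge_dict into into_dict, in place
    for k, v in merge_dict.items():
        if isinstance(v, dict) and isinstance(into_dict.get(k), dict):
            dict_merge(v, into_dict[k])  # both sides are dicts: recurse
        else:
            into_dict[k] = v  # otherwise the merged-in value wins

target = {"service": {"sp": {"a": 1}}, "metadata": {}}
dict_merge({"service": {"sp": {"b": 2}}}, target)
assert target == {"service": {"sp": {"a": 1, "b": 2}}, "metadata": {}}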
Args: - merge_dict: dict to merge - into_dict: target dict to be modified + merge_dict (dict): dict to merge + into_dict (dict): target dict """ for k, v in merge_dict.items(): if k not in into_dict: @@ -65,7 +64,7 @@ def _dict_merge(merge_dict: dict, into_dict: dict) -> None: class SAML2Config(Config): section = "saml2" - def read_config(self, config, **kwargs) -> None: + def read_config(self, config, **kwargs): self.saml2_enabled = False saml2_config = config.get("saml2_config") @@ -184,8 +183,8 @@ def read_config(self, config, **kwargs) -> None: ) def _default_saml_config_dict( - self, required_attributes: Set[str], optional_attributes: Set[str] - ) -> JsonDict: + self, required_attributes: set, optional_attributes: set + ): """Generate a configuration dictionary with required and optional attributes that will be needed to process new user registration @@ -196,7 +195,7 @@ def _default_saml_config_dict( additional information to Synapse user accounts, but are not required Returns: - A SAML configuration dictionary + dict: A SAML configuration dictionary """ import saml2 @@ -223,7 +222,7 @@ def _default_saml_config_dict( }, } - def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str: + def generate_config_section(self, config_dir_path, server_name, **kwargs): return """\ ## Single sign-on integration ## diff --git a/synapse/config/server.py b/synapse/config/server.py index ba5b95426338..8445e9dd0509 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import argparse import itertools import logging import os.path @@ -28,7 +27,6 @@ from twisted.conch.ssh.keys import Key from synapse.api.room_versions import KNOWN_ROOM_VERSIONS -from synapse.types import JsonDict from synapse.util.module_loader import load_module from synapse.util.stringutils import parse_and_validate_server_name @@ -1225,7 +1223,7 @@ def generate_config_section( % locals() ) - def read_arguments(self, args: argparse.Namespace) -> None: + def read_arguments(self, args): if args.manhole is not None: self.manhole = args.manhole if args.daemonize is not None: @@ -1234,7 +1232,7 @@ def read_arguments(self, args: argparse.Namespace) -> None: self.print_pidfile = args.print_pidfile @staticmethod - def add_arguments(parser: argparse.ArgumentParser) -> None: + def add_arguments(parser): server_group = parser.add_argument_group("server") server_group.add_argument( "-D", @@ -1276,16 +1274,14 @@ def read_gc_intervals(self, durations) -> Optional[Tuple[float, float, float]]: ) -def is_threepid_reserved( - reserved_threepids: List[JsonDict], threepid: JsonDict -) -> bool: +def is_threepid_reserved(reserved_threepids, threepid): """Check the threepid against the reserved threepid config Args: - reserved_threepids: List of reserved threepids - threepid: The threepid to test for + reserved_threepids([dict]) - list of reserved threepids + threepid(dict) - The threepid to test for Returns: - Is the threepid under test reserved_user + boolean Is the threepid under test reserved_user """ for tp in reserved_threepids: @@ -1294,9 +1290,7 @@ def is_threepid_reserved( return False -def read_gc_thresholds( - thresholds: Optional[List[Any]], -) -> Optional[Tuple[int, int, int]]: +def read_gc_thresholds(thresholds): """Reads the three integer thresholds for garbage collection. Ensures that the thresholds are integers if thresholds are supplied.
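`read_gc_thresholds` has a small contract: `None` passes through, anything else must be exactly three values coercible to int. A sketch of that contract (the real function wraps failures in a ConfigError):

from typing import Any, List, Optional, Tuple

def gc_thresholds(thresholds: Optional[List[Any]]) -> Optional[Tuple[int, int, int]]:
    if thresholds is None:
        return None
    g0, g1, g2 = thresholds  # raises if not exactly three values
    return (int(g0), int(g1), int(g2))

assert gc_thresholds(None) is None
assert gc_thresholds(["700", 10, 10]) == (700, 10, 10)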
""" diff --git a/synapse/config/sso.py b/synapse/config/sso.py index e4a424326124..60aacb13ea40 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -1,4 +1,4 @@ -# Copyright 2020-2021 The Matrix.org Foundation C.I.C. +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -29,13 +29,13 @@ ---------------------------------------------------------------------------------------""" -@attr.s(frozen=True, auto_attribs=True) +@attr.s(frozen=True) class SsoAttributeRequirement: """Object describing a single requirement for SSO attributes.""" - attribute: str + attribute = attr.ib(type=str) # If a value is not given, than the attribute must simply exist. - value: Optional[str] + value = attr.ib(type=Optional[str]) JSON_SCHEMA = { "type": "object", @@ -49,7 +49,7 @@ class SSOConfig(Config): section = "sso" - def read_config(self, config, **kwargs) -> None: + def read_config(self, config, **kwargs): sso_config: Dict[str, Any] = config.get("sso") or {} # The sso-specific template_dir @@ -106,7 +106,7 @@ def read_config(self, config, **kwargs) -> None: ) self.sso_client_whitelist.append(login_fallback_url) - def generate_config_section(self, **kwargs) -> str: + def generate_config_section(self, **kwargs): return """\ # Additional settings to use with single-sign on systems such as OpenID Connect, # SAML2 and CAS. diff --git a/synapse/config/workers.py b/synapse/config/workers.py index 576f519188bb..450799203112 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -1,5 +1,4 @@ # Copyright 2016 OpenMarket Ltd -# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import argparse from typing import List, Union import attr @@ -345,7 +343,7 @@ def generate_config_section(self, config_dir_path, server_name, **kwargs): #worker_replication_secret: "" """ - def read_arguments(self, args: argparse.Namespace) -> None: + def read_arguments(self, args): # We support a bunch of command line arguments that override options in # the config. A lot of these options have a worker_* prefix when running # on workers so we also have to override them when command line options diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 993b04099e28..4cda439ad927 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -667,25 +667,21 @@ async def get_server_verify_key_v2_indirect( perspective_name, ) - request: JsonDict = {} - for queue_value in keys_to_fetch: - # there may be multiple requests for each server, so we have to merge - # them intelligently. 
- request_for_server = { - key_id: { - "minimum_valid_until_ts": queue_value.minimum_valid_until_ts, - } - for key_id in queue_value.key_ids - } - request.setdefault(queue_value.server_name, {}).update(request_for_server) - - logger.debug("Request to notary server %s: %s", perspective_name, request) - try: query_response = await self.client.post_json( destination=perspective_name, path="/_matrix/key/v2/query", - data={"server_keys": request}, + data={ + "server_keys": { + queue_value.server_name: { + key_id: { + "minimum_valid_until_ts": queue_value.minimum_valid_until_ts, + } + for key_id in queue_value.key_ids + } + for queue_value in keys_to_fetch + } + }, ) except (NotRetryingDestination, RequestSendFailed) as e: # these both have str() representations which we can't really improve upon @@ -693,10 +689,6 @@ async def get_server_verify_key_v2_indirect( except HttpResponseException as e: raise KeyLookupError("Remote server returned an error: %s" % (e,)) - logger.debug( - "Response from notary server %s: %s", perspective_name, query_response - ) - keys: Dict[str, Dict[str, FetchKeyResult]] = {} added_keys: List[Tuple[str, str, FetchKeyResult]] = [] diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index f251402ed8f2..d7527008c443 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -322,11 +322,6 @@ async def _fill_out_state(self) -> None: attributes by loading from the database. """ if self.state_group is None: - # No state group means the event is an outlier. Usually the state_ids dicts are also - # pre-set to empty dicts, but they get reset when the context is serialized, so set - # them to empty dicts again here. - self._current_state_ids = {} - self._prev_state_ids = {} return current_state_ids = await self._storage.state.get_state_ids_for_group( diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 84ef69df679b..e5967c995e8a 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -306,7 +306,6 @@ def format_event_for_client_v2_without_room_id(d: JsonDict) -> JsonDict: def serialize_event( e: Union[JsonDict, EventBase], time_now_ms: int, - *, as_client_event: bool = True, event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1, token_id: Optional[str] = None, @@ -394,8 +393,7 @@ async def serialize_event( self, event: Union[JsonDict, EventBase], time_now: int, - *, - bundle_aggregations: bool = True, + bundle_relations: bool = True, **kwargs: Any, ) -> JsonDict: """Serializes a single event. @@ -403,9 +401,8 @@ async def serialize_event( Args: event: The event being serialized. time_now: The current time in milliseconds - bundle_aggregations: Whether to include the bundled aggregations for this - event. Only applies to non-state events. (State events never include - bundled aggregations.) + bundle_relations: Whether to include the bundled relations for this + event. **kwargs: Arguments to pass to `serialize_event` Returns: @@ -417,27 +414,20 @@ async def serialize_event( serialized_event = serialize_event(event, time_now, **kwargs) - # Check if there are any bundled aggregations to include with the event. - # - # Do not bundle aggregations if any of the following are true: - # - # * Support is disabled via the configuration or the caller. - # * The event is a state event. - # * The event has been redacted.
- if ( - self._msc1849_enabled - and bundle_aggregations - and not event.is_state() - and not event.internal_metadata.is_redacted() + # If MSC1849 is enabled then we need to look if there are any relations + # we need to bundle in with the event. + # Do not bundle relations if the event has been redacted + if not event.internal_metadata.is_redacted() and ( + self._msc1849_enabled and bundle_relations ): - await self._injected_bundled_aggregations(event, time_now, serialized_event) + await self._injected_bundled_relations(event, time_now, serialized_event) return serialized_event - async def _injected_bundled_aggregations( + async def _injected_bundled_relations( self, event: EventBase, time_now: int, serialized_event: JsonDict ) -> None: - """Potentially injects bundled aggregations into the unsigned portion of the serialized event. + """Potentially injects bundled relations into the unsigned portion of the serialized event. Args: event: The event being serialized. @@ -445,28 +435,20 @@ async def _injected_bundled_aggregations( serialized_event: The serialized event which may be modified. """ - # Do not bundle aggregations for an event which represents an edit or an - # annotation. It does not make sense for them to have related events. - relates_to = event.content.get("m.relates_to") - if isinstance(relates_to, (dict, frozendict)): - relation_type = relates_to.get("rel_type") - if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE): - return - event_id = event.event_id - # The bundled aggregations to include. - aggregations = {} + # The bundled relations to include. + relations = {} annotations = await self.store.get_aggregation_groups_for_event(event_id) if annotations.chunk: - aggregations[RelationTypes.ANNOTATION] = annotations.to_dict() + relations[RelationTypes.ANNOTATION] = annotations.to_dict() references = await self.store.get_relations_for_event( event_id, RelationTypes.REFERENCE, direction="f" ) if references.chunk: - aggregations[RelationTypes.REFERENCE] = references.to_dict() + relations[RelationTypes.REFERENCE] = references.to_dict() edit = None if event.type == EventTypes.Message: @@ -492,7 +474,7 @@ async def _injected_bundled_aggregations( else: serialized_event["content"].pop("m.relates_to", None) - aggregations[RelationTypes.REPLACE] = { + relations[RelationTypes.REPLACE] = { "event_id": edit.event_id, "origin_server_ts": edit.origin_server_ts, "sender": edit.sender, @@ -505,19 +487,17 @@ async def _injected_bundled_aggregations( latest_thread_event, ) = await self.store.get_thread_summary(event_id) if latest_thread_event: - aggregations[RelationTypes.THREAD] = { - # Don't bundle aggregations as this could recurse forever. + relations[RelationTypes.THREAD] = { + # Don't bundle relations as this could recurse forever. "latest_event": await self.serialize_event( - latest_thread_event, time_now, bundle_aggregations=False + latest_thread_event, time_now, bundle_relations=False ), "count": thread_count, } - # If any bundled aggregations were found, include them. - if aggregations: - serialized_event["unsigned"].setdefault("m.relations", {}).update( - aggregations - ) + # If any bundled relations were found, include them. 
+ if relations: + serialized_event["unsigned"].setdefault("m.relations", {}).update(relations) async def serialize_events( self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index fee1477ab684..3b85b135e0d3 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -128,7 +128,7 @@ def __init__(self, hs: "HomeServer"): reset_expiry_on_get=False, ) - def _clear_tried_cache(self) -> None: + def _clear_tried_cache(self): """Clear pdu_destination_tried cache""" now = self._clock.time_msec() @@ -800,7 +800,7 @@ async def send_join( no servers successfully handle the request. """ - async def send_request(destination: str) -> SendJoinResult: + async def send_request(destination) -> SendJoinResult: response = await self._do_send_join(room_version, destination, pdu) # If an event was returned (and expected to be returned): @@ -1395,28 +1395,11 @@ async def get_room_hierarchy( async def send_request( destination: str, ) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[str]]: - try: - res = await self.transport_layer.get_room_hierarchy( - destination=destination, - room_id=room_id, - suggested_only=suggested_only, - ) - except HttpResponseException as e: - # If an error is received that is due to an unrecognised endpoint, - # fallback to the unstable endpoint. Otherwise consider it a - # legitimate error and raise. - if not self._is_unknown_endpoint(e): - raise - - logger.debug( - "Couldn't fetch room hierarchy with the v1 API, falling back to the unstable API" - ) - - res = await self.transport_layer.get_room_hierarchy_unstable( - destination=destination, - room_id=room_id, - suggested_only=suggested_only, - ) + res = await self.transport_layer.get_room_hierarchy( + destination=destination, + room_id=room_id, + suggested_only=suggested_only, + ) room = res.get("room") if not isinstance(room, dict): @@ -1466,10 +1449,6 @@ async def send_request( if e.code != 502: raise - logger.debug( - "Couldn't fetch room hierarchy, falling back to the spaces API" - ) - # Fallback to the old federation API and translate the results if # no servers implement the new API. # @@ -1517,83 +1496,6 @@ async def send_request( self._get_room_hierarchy_cache[(room_id, suggested_only)] = result return result - async def timestamp_to_event( - self, destination: str, room_id: str, timestamp: int, direction: str - ) -> "TimestampToEventResponse": - """ - Calls a remote federating server at `destination` asking for their - closest event to the given timestamp in the given direction. Also - validates the response to always return the expected keys or raises an - error. - - Args: - destination: Domain name of the remote homeserver - room_id: Room to fetch the event from - timestamp: The point in time (inclusive) we should navigate from in - the given direction to find the closest event. - direction: ["f"|"b"] to indicate whether we should navigate forward - or backward from the given timestamp to find the closest event.
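The `get_room_hierarchy` hunk above strips out a stable-then-unstable endpoint fallback. The shape of that pattern, reduced to a runnable sketch; `v1`, `unstable`, and the error predicate are illustrative stand-ins rather than Synapse's transport API:

import asyncio

async def call_with_fallback(fetch_v1, fetch_unstable, is_unknown_endpoint):
    try:
        return await fetch_v1()
    except RuntimeError as e:
        # fall back only when the remote clearly lacks the stable endpoint;
        # anything else is a legitimate error and is re-raised
        if not is_unknown_endpoint(e):
            raise
        return await fetch_unstable()

async def demo() -> str:
    async def v1():
        raise RuntimeError("M_UNRECOGNIZED")

    async def unstable():
        return "hierarchy"

    return await call_with_fallback(v1, unstable, lambda e: "M_UNRECOGNIZED" in str(e))

assert asyncio.run(demo()) == "hierarchy"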
- - Returns: - A parsed TimestampToEventResponse including the closest event_id - and origin_server_ts - - Raises: - Various exceptions when the request fails - InvalidResponseError when the response does not have the correct - keys or wrong types - """ - remote_response = await self.transport_layer.timestamp_to_event( - destination, room_id, timestamp, direction - ) - - if not isinstance(remote_response, dict): - raise InvalidResponseError( - "Response must be a JSON dictionary but received %r" % remote_response - ) - - try: - return TimestampToEventResponse.from_json_dict(remote_response) - except ValueError as e: - raise InvalidResponseError(str(e)) - - -@attr.s(frozen=True, slots=True, auto_attribs=True) -class TimestampToEventResponse: - """Typed response dictionary for the federation /timestamp_to_event endpoint""" - - event_id: str - origin_server_ts: int - - # the raw data, including the above keys - data: JsonDict - - @classmethod - def from_json_dict(cls, d: JsonDict) -> "TimestampToEventResponse": - """Parsed response from the federation /timestamp_to_event endpoint - - Args: - d: JSON object response to be parsed - - Raises: - ValueError if d does not have the correct keys or they are the wrong types - """ - - event_id = d.get("event_id") - if not isinstance(event_id, str): - raise ValueError( - "Invalid response: 'event_id' must be a str but received %r" % event_id - ) - - origin_server_ts = d.get("origin_server_ts") - if not isinstance(origin_server_ts, int): - raise ValueError( - "Invalid response: 'origin_server_ts' must be an int but received %r" - % origin_server_ts - ) - - return cls(event_id, origin_server_ts, d) - @attr.s(frozen=True, slots=True, auto_attribs=True) class FederationSpaceSummaryEventResult: diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 8e37e76206ac..9a8758e9a6d3 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -1,6 +1,6 @@ # Copyright 2015, 2016 OpenMarket Ltd # Copyright 2018 New Vector Ltd -# Copyright 2019-2021 Matrix.org Federation C.I.C +# Copyright 2019 Matrix.org Federation C.I.C # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -110,7 +110,6 @@ def __init__(self, hs: "HomeServer"): super().__init__(hs) self.handler = hs.get_federation_handler() - self.storage = hs.get_storage() self._federation_event_handler = hs.get_federation_event_handler() self.state = hs.get_state_handler() self._event_auth_handler = hs.get_event_auth_handler() @@ -201,48 +200,6 @@ async def on_backfill_request( return 200, res - async def on_timestamp_to_event_request( - self, origin: str, room_id: str, timestamp: int, direction: str - ) -> Tuple[int, Dict[str, Any]]: - """When we receive a federated `/timestamp_to_event` request, - handle all of the logic for validating and fetching the event. - - Args: - origin: The server we received the event from - room_id: Room to fetch the event from - timestamp: The point in time (inclusive) we should navigate from in - the given direction to find the closest event. - direction: ["f"|"b"] to indicate whether we should navigate forward - or backward from the given timestamp to find the closest event. - - Returns: - Tuple indicating the response status code and dictionary response - body including `event_id`.
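The removed `TimestampToEventResponse` validates each field before constructing the typed result. The same validate-then-construct pattern in a self-contained form, mirroring the removed class but using a plain dict in place of Synapse's `JsonDict` alias:

import attr

@attr.s(frozen=True, slots=True, auto_attribs=True)
class TimestampResponse:
    event_id: str
    origin_server_ts: int
    data: dict  # the raw response, including the two keys above

    @classmethod
    def from_json_dict(cls, d: dict) -> "TimestampResponse":
        event_id = d.get("event_id")
        if not isinstance(event_id, str):
            raise ValueError("'event_id' must be a str but received %r" % (event_id,))
        ts = d.get("origin_server_ts")
        if not isinstance(ts, int):
            raise ValueError("'origin_server_ts' must be an int but received %r" % (ts,))
        return cls(event_id, ts, d)

resp = TimestampResponse.from_json_dict({"event_id": "$e", "origin_server_ts": 1})
assert resp.event_id == "$e" and resp.origin_server_ts == 1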
- """ - with (await self._server_linearizer.queue((origin, room_id))): - origin_host, _ = parse_server_name(origin) - await self.check_server_matches_acl(origin_host, room_id) - - # We only try to fetch data from the local database - event_id = await self.store.get_event_id_for_timestamp( - room_id, timestamp, direction - ) - if event_id: - event = await self.store.get_event( - event_id, allow_none=False, allow_rejected=False - ) - - return 200, { - "event_id": event_id, - "origin_server_ts": event.origin_server_ts, - } - - raise SynapseError( - 404, - "Unable to find event from %s in direction %s" % (timestamp, direction), - errcode=Codes.NOT_FOUND, - ) - async def on_incoming_transaction( self, origin: str, @@ -450,7 +407,7 @@ async def _handle_pdus_in_txn( # require callouts to other servers to fetch missing events), but # impose a limit to avoid going too crazy with ram/cpu. - async def process_pdus_for_room(room_id: str) -> None: + async def process_pdus_for_room(room_id: str): with nested_logging_context(room_id): logger.debug("Processing PDUs for %s", room_id) @@ -547,7 +504,7 @@ async def on_room_state_request( async def on_state_ids_request( self, origin: str, room_id: str, event_id: str - ) -> Tuple[int, JsonDict]: + ) -> Tuple[int, Dict[str, Any]]: if not event_id: raise NotImplementedError("Specify an event") @@ -567,9 +524,7 @@ async def on_state_ids_request( return 200, resp - async def _on_state_ids_request_compute( - self, room_id: str, event_id: str - ) -> JsonDict: + async def _on_state_ids_request_compute(self, room_id, event_id): state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id) auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids) return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids} @@ -658,11 +613,8 @@ async def on_send_join_request( state = await self.store.get_events(state_ids) time_now = self._clock.time_msec() - event_json = event.get_pdu_json() return { - # TODO Remove the unstable prefix when servers have updated. - "org.matrix.msc3083.v2.event": event_json, - "event": event_json, + "org.matrix.msc3083.v2.event": event.get_pdu_json(), "state": [p.get_pdu_json(time_now) for p in state.values()], "auth_chain": [p.get_pdu_json(time_now) for p in auth_chain], } diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index 523ab1c51ed1..4fead6ca2954 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -1,5 +1,4 @@ # Copyright 2014-2016 OpenMarket Ltd -# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,7 +23,6 @@ from synapse.federation.units import Transaction from synapse.logging.utils import log_function -from synapse.storage.databases.main import DataStore from synapse.types import JsonDict logger = logging.getLogger(__name__) @@ -33,7 +31,7 @@ class TransactionActions: """Defines persistence actions that relate to handling Transactions.""" - def __init__(self, datastore: DataStore): + def __init__(self, datastore): self.store = datastore @log_function diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 63289a5a334f..1fbf325fdc9a 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -1,5 +1,4 @@ # Copyright 2014-2016 OpenMarket Ltd -# Copyright 2021 The Matrix.org Foundation C.I.C. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -351,7 +350,7 @@ class BaseFederationRow: TypeId = "" # Unique string that ids the type. Must be overridden in sub classes. @staticmethod - def from_data(data: JsonDict) -> "BaseFederationRow": + def from_data(data): """Parse the data from the federation stream into a row. Args: @@ -360,7 +359,7 @@ def from_data(data: JsonDict) -> "BaseFederationRow": """ raise NotImplementedError() - def to_data(self) -> JsonDict: + def to_data(self): """Serialize this row to be sent over the federation stream. Returns: @@ -369,7 +368,7 @@ def to_data(self) -> JsonDict: """ raise NotImplementedError() - def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None: + def add_to_buffer(self, buff): """Add this row to the appropriate field in the buffer ready for this to be sent over federation. @@ -392,15 +391,15 @@ class PresenceDestinationsRow( TypeId = "pd" @staticmethod - def from_data(data: JsonDict) -> "PresenceDestinationsRow": + def from_data(data): return PresenceDestinationsRow( state=UserPresenceState.from_dict(data["state"]), destinations=data["dests"] ) - def to_data(self) -> JsonDict: + def to_data(self): return {"state": self.state.as_dict(), "dests": self.destinations} - def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None: + def add_to_buffer(self, buff): buff.presence_destinations.append((self.state, self.destinations)) @@ -418,13 +417,13 @@ class KeyedEduRow( TypeId = "k" @staticmethod - def from_data(data: JsonDict) -> "KeyedEduRow": + def from_data(data): return KeyedEduRow(key=tuple(data["key"]), edu=Edu(**data["edu"])) - def to_data(self) -> JsonDict: + def to_data(self): return {"key": self.key, "edu": self.edu.get_internal_dict()} - def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None: + def add_to_buffer(self, buff): buff.keyed_edus.setdefault(self.edu.destination, {})[self.key] = self.edu @@ -434,13 +433,13 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu TypeId = "e" @staticmethod - def from_data(data: JsonDict) -> "EduRow": + def from_data(data): return EduRow(Edu(**data)) - def to_data(self) -> JsonDict: + def to_data(self): return self.edu.get_internal_dict() - def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None: + def add_to_buffer(self, buff): buff.edus.setdefault(self.edu.destination, []).append(self.edu) diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 391b30fbb559..afe35e72b6ba 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -1,6 +1,5 @@ # Copyright 2014-2016 OpenMarket Ltd # Copyright 2019 New Vector Ltd -# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,8 +14,7 @@ # limitations under the License. 
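Every row type in the `send_queue` hunk above follows the same contract: `from_data` and `to_data` must round-trip, since rows are serialized onto the replication stream and parsed back on the other side. A sketch with an illustrative stand-in row (not one of the real row classes):

from dataclasses import dataclass
from typing import Dict

@dataclass(frozen=True)
class EduRowDemo:
    destination: str
    edu_type: str
    content: Dict[str, int]

    TypeId = "e"  # unique string identifying the row type on the stream

    @staticmethod
    def from_data(data: dict) -> "EduRowDemo":
        return EduRowDemo(data["destination"], data["edu_type"], data["content"])

    def to_data(self) -> dict:
        return {
            "destination": self.destination,
            "edu_type": self.edu_type,
            "content": self.content,
        }

row = EduRowDemo("remote.example", "m.typing", {"typing": 1})
assert EduRowDemo.from_data(row.to_data()) == row  # rows must round-trip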
import datetime import logging -from types import TracebackType -from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, Type +from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple import attr from prometheus_client import Counter @@ -215,7 +213,7 @@ def send_keyed_edu(self, edu: Edu, key: Hashable) -> None: self._pending_edus_keyed[(edu.edu_type, key)] = edu self.attempt_new_transaction() - def send_edu(self, edu: Edu) -> None: + def send_edu(self, edu) -> None: self._pending_edus.append(edu) self.attempt_new_transaction() @@ -703,12 +701,7 @@ async def __aenter__(self) -> Tuple[List[EventBase], List[Edu]]: return self._pdus, pending_edus - async def __aexit__( - self, - exc_type: Optional[Type[BaseException]], - exc: Optional[BaseException], - tb: Optional[TracebackType], - ) -> None: + async def __aexit__(self, exc_type, exc, tb): if exc_type is not None: # Failed to send transaction, so we bail out. return diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 9fc4c31c93f6..10b5aa5af824 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -21,7 +21,6 @@ Callable, Collection, Dict, - Generator, Iterable, List, Mapping, @@ -149,42 +148,6 @@ async def backfill( destination, path=path, args=args, try_trailing_slash_on_400=True ) - @log_function - async def timestamp_to_event( - self, destination: str, room_id: str, timestamp: int, direction: str - ) -> Union[JsonDict, List]: - """ - Calls a remote federating server at `destination` asking for their - closest event to the given timestamp in the given direction. - - Args: - destination: Domain name of the remote homeserver - room_id: Room to fetch the event from - timestamp: The point in time (inclusive) we should navigate from in - the given direction to find the closest event. - direction: ["f"|"b"] to indicate whether we should navigate forward - or backward from the given timestamp to find the closest event. - - Returns: - Response dict received from the remote homeserver. - - Raises: - Various exceptions when the request fails - """ - path = _create_path( - FEDERATION_UNSTABLE_PREFIX, - "/org.matrix.msc3030/timestamp_to_event/%s", - room_id, - ) - - args = {"ts": [str(timestamp)], "dir": [direction]} - - remote_response = await self.client.get_json( - destination, path=path, args=args, try_trailing_slash_on_400=True - ) - - return remote_response - @log_function async def send_transaction( self, @@ -236,16 +199,11 @@ async def send_transaction( @log_function async def make_query( - self, - destination: str, - query_type: str, - args: dict, - retry_on_dns_fail: bool, - ignore_backoff: bool = False, - ) -> JsonDict: + self, destination, query_type, args, retry_on_dns_fail, ignore_backoff=False + ): path = _create_v1_path("/query/%s", query_type) - return await self.client.get_json( + content = await self.client.get_json( destination=destination, path=path, args=args, @@ -254,6 +212,8 @@ async def make_query( ignore_backoff=ignore_backoff, ) + return content + @log_function async def make_membership_event( self, @@ -1232,24 +1192,10 @@ async def get_space_summary( ) async def get_room_hierarchy( - self, destination: str, room_id: str, suggested_only: bool - ) -> JsonDict: - """ - Args: - destination: The remote server - room_id: The room ID to ask about. 
- suggested_only: if True, only suggested rooms will be returned - """ - path = _create_v1_path("/hierarchy/%s", room_id) - - return await self.client.get_json( - destination=destination, - path=path, - args={"suggested_only": "true" if suggested_only else "false"}, - ) - - async def get_room_hierarchy_unstable( - self, destination: str, room_id: str, suggested_only: bool + self, + destination: str, + room_id: str, + suggested_only: bool, ) -> JsonDict: """ Args: @@ -1321,7 +1267,7 @@ class SendJoinResponse: @ijson.coroutine -def _event_parser(event_dict: JsonDict) -> Generator[None, Tuple[str, Any], None]: +def _event_parser(event_dict: JsonDict): """Helper function for use with `ijson.kvitems_coro` to parse key-value pairs to add them to a given dictionary. """ @@ -1332,9 +1278,7 @@ def _event_parser(event_dict: JsonDict) -> Generator[None, Tuple[str, Any], None @ijson.coroutine -def _event_list_parser( - room_version: RoomVersion, events: List[EventBase] -) -> Generator[None, JsonDict, None]: +def _event_list_parser(room_version: RoomVersion, events: List[EventBase]): """Helper function for use with `ijson.items_coro` to parse an array of events and add them to the given list. """ @@ -1373,26 +1317,15 @@ def __init__(self, room_version: RoomVersion, v1_api: bool): prefix + "auth_chain.item", use_float=True, ) - # TODO Remove the unstable prefix when servers have updated. - # - # By re-using the same event dictionary this will cause the parsing of - # org.matrix.msc3083.v2.event and event to stomp over each other. - # Generally this should be fine. - self._coro_unstable_event = ijson.kvitems_coro( - _event_parser(self._response.event_dict), - prefix + "org.matrix.msc3083.v2.event", - use_float=True, - ) self._coro_event = ijson.kvitems_coro( _event_parser(self._response.event_dict), - prefix + "event", + prefix + "org.matrix.msc3083.v2.event", use_float=True, ) def write(self, data: bytes) -> int: self._coro_state.send(data) self._coro_auth.send(data) - self._coro_unstable_event.send(data) self._coro_event.send(data) return len(data) diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py index 77b936361a4a..c32539bf5a52 100644 --- a/synapse/federation/transport/server/__init__.py +++ b/synapse/federation/transport/server/__init__.py @@ -22,10 +22,7 @@ Authenticator, BaseFederationServlet, ) -from synapse.federation.transport.server.federation import ( - FEDERATION_SERVLET_CLASSES, - FederationTimestampLookupServlet, -) +from synapse.federation.transport.server.federation import FEDERATION_SERVLET_CLASSES from synapse.federation.transport.server.groups_local import GROUP_LOCAL_SERVLET_CLASSES from synapse.federation.transport.server.groups_server import ( GROUP_SERVER_SERVLET_CLASSES, @@ -302,7 +299,7 @@ def register_servlets( authenticator: Authenticator, ratelimiter: FederationRateLimiter, servlet_groups: Optional[Iterable[str]] = None, -) -> None: +): """Initialize and register servlet classes. Will by default register all servlets. 
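`register_servlets` above loses the check that skipped `FederationTimestampLookupServlet` when the MSC3030 experimental flag was off. The gating pattern itself, reduced to a sketch with illustrative servlet names:

from typing import List, Type

class ServletA: ...
class ServletB: ...  # stands in for a servlet behind an experimental flag

def register(classes: List[Type], flag_enabled: bool, registry: List[Type]) -> None:
    for servletclass in classes:
        # skip servlets whose experimental feature flag is disabled
        if servletclass is ServletB and not flag_enabled:
            continue
        registry.append(servletclass)

registered: List[Type] = []
register([ServletA, ServletB], flag_enabled=False, registry=registered)
assert registered == [ServletA]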
For custom behaviour, pass in @@ -327,13 +324,6 @@ def register_servlets( ) for servletclass in DEFAULT_SERVLET_GROUPS[servlet_group]: - # Only allow the `/timestamp_to_event` servlet if msc3030 is enabled - if ( - servletclass == FederationTimestampLookupServlet - and not hs.config.experimental.msc3030_enabled - ): - continue - servletclass( hs=hs, authenticator=authenticator, diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py index dc39e3537bf6..cef65929c529 100644 --- a/synapse/federation/transport/server/_base.py +++ b/synapse/federation/transport/server/_base.py @@ -15,13 +15,10 @@ import functools import logging import re -from typing import Any, Awaitable, Callable, Optional, Tuple, cast from synapse.api.errors import Codes, FederationDeniedError, SynapseError from synapse.api.urls import FEDERATION_V1_PREFIX -from synapse.http.server import HttpServer, ServletCallback from synapse.http.servlet import parse_json_object_from_request -from synapse.http.site import SynapseRequest from synapse.logging import opentracing from synapse.logging.context import run_in_background from synapse.logging.opentracing import ( @@ -32,7 +29,6 @@ whitelisted_homeserver, ) from synapse.server import HomeServer -from synapse.types import JsonDict from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.stringutils import parse_and_validate_server_name @@ -63,11 +59,9 @@ def __init__(self, hs: HomeServer): self.replication_client = hs.get_tcp_replication() # A method just so we can pass 'self' as the authenticator to the Servlets - async def authenticate_request( - self, request: SynapseRequest, content: Optional[JsonDict] - ) -> str: + async def authenticate_request(self, request, content): now = self._clock.time_msec() - json_request: JsonDict = { + json_request = { "method": request.method.decode("ascii"), "uri": request.uri.decode("ascii"), "destination": self.server_name, @@ -120,7 +114,7 @@ async def authenticate_request( return origin - async def _reset_retry_timings(self, origin: str) -> None: + async def _reset_retry_timings(self, origin): try: logger.info("Marking origin %r as up", origin) await self.store.set_destination_retry_timings(origin, None, 0, 0) @@ -139,14 +133,14 @@ async def _reset_retry_timings(self, origin: str) -> None: logger.exception("Error resetting retry timings on %s", origin) -def _parse_auth_header(header_bytes: bytes) -> Tuple[str, str, str]: +def _parse_auth_header(header_bytes): """Parse an X-Matrix auth header Args: - header_bytes: header value + header_bytes (bytes): header value Returns: - origin, key id, signature. + Tuple[str, str, str]: origin, key id, signature. 
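The `_parse_auth_header` hunk above reverts `kv.split("=", maxsplit=1)` back to a plain `kv.split("=")`. A small demonstration of why the split limit matters: base64 signature values carry `=` padding, so splitting on every `=` breaks the key/value pairing (header contents here are dummy values):

def parse_auth_params(header: str) -> dict:
    params = header.split(" ")[1].split(",")
    # maxsplit=1 keeps any '=' inside the value intact
    return {k: v for k, v in (kv.split("=", maxsplit=1) for kv in params)}

hdr = 'X-Matrix origin=remote.example,key="ed25519:1",sig="abc=="'
assert parse_auth_params(hdr)["sig"] == '"abc=="'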
Raises: AuthenticationError if the header could not be parsed @@ -154,9 +148,9 @@ def _parse_auth_header(header_bytes: bytes) -> Tuple[str, str, str]: try: header_str = header_bytes.decode("utf-8") params = header_str.split(" ")[1].split(",") - param_dict = {k: v for k, v in (kv.split("=", maxsplit=1) for kv in params)} + param_dict = dict(kv.split("=") for kv in params) - def strip_quotes(value: str) -> str: + def strip_quotes(value): if value.startswith('"'): return value[1:-1] else: @@ -239,25 +233,23 @@ def __init__( self.ratelimiter = ratelimiter self.server_name = server_name - def _wrap(self, func: Callable[..., Awaitable[Tuple[int, Any]]]) -> ServletCallback: + def _wrap(self, func): authenticator = self.authenticator ratelimiter = self.ratelimiter @functools.wraps(func) - async def new_func( - request: SynapseRequest, *args: Any, **kwargs: str - ) -> Optional[Tuple[int, Any]]: + async def new_func(request, *args, **kwargs): """A callback which can be passed to HttpServer.RegisterPaths Args: - request: + request (twisted.web.http.Request): *args: unused? - **kwargs: the dict mapping keys to path components as specified - in the path match regexp. + **kwargs (dict[unicode, unicode]): the dict mapping keys to path + components as specified in the path match regexp. Returns: - (response code, response object) as returned by the callback method. - None if the request has already been handled. + Tuple[int, object]|None: (response code, response object) as returned by + the callback method. None if the request has already been handled. """ content = None if request.method in [b"PUT", b"POST"]: @@ -265,9 +257,7 @@ async def new_func( content = parse_json_object_from_request(request) try: - origin: Optional[str] = await authenticator.authenticate_request( - request, content - ) + origin = await authenticator.authenticate_request(request, content) except NoAuthenticationError: origin = None if self.REQUIRE_AUTH: @@ -311,7 +301,7 @@ async def new_func( "client disconnected before we started processing " "request" ) - return None + return -1, None response = await func( origin, content, request.args, *args, **kwargs ) @@ -322,9 +312,9 @@ async def new_func( return response - return cast(ServletCallback, new_func) + return new_func - def register(self, server: HttpServer) -> None: + def register(self, server): pattern = re.compile("^" + self.PREFIX + self.PATH + "$") for method in ("GET", "PUT", "POST"): diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 77bfd88ad052..2fdf6cc99e49 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -174,46 +174,6 @@ async def on_GET( return await self.handler.on_backfill_request(origin, room_id, versions, limit) -class FederationTimestampLookupServlet(BaseFederationServerServlet): - """ - API endpoint to fetch the `event_id` of the closest event to the given - timestamp (`ts` query parameter) in the given direction (`dir` query - parameter). - - Useful for other homeservers when they're unable to find an event locally. - - `ts` is a timestamp in milliseconds where we will find the closest event in - the given direction. - - `dir` can be `f` or `b` to indicate forwards and backwards in time from the - given timestamp. - - GET /_matrix/federation/unstable/org.matrix.msc3030/timestamp_to_event/?ts=&dir= - { - "event_id": ... - } - """ - - PATH = "/timestamp_to_event/(?P[^/]*)/?" 
- PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc3030" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - room_id: str, - ) -> Tuple[int, JsonDict]: - timestamp = parse_integer_from_args(query, "ts", required=True) - direction = parse_string_from_args( - query, "dir", default="f", allowed_values=["f", "b"], required=True - ) - - return await self.handler.on_timestamp_to_event_request( - origin, room_id, timestamp, direction - ) - - class FederationQueryServlet(BaseFederationServerServlet): PATH = "/query/(?P[^/]*)" @@ -651,6 +611,7 @@ async def on_POST( class FederationRoomHierarchyServlet(BaseFederationServlet): + PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946" PATH = "/hierarchy/(?P[^/]*)" def __init__( @@ -676,10 +637,6 @@ async def on_GET( ) -class FederationRoomHierarchyUnstableServlet(FederationRoomHierarchyServlet): - PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946" - - class RoomComplexityServlet(BaseFederationServlet): """ Indicates to other servers how complex (and therefore likely @@ -723,7 +680,6 @@ async def on_GET( FederationStateV1Servlet, FederationStateIdsServlet, FederationBackfillServlet, - FederationTimestampLookupServlet, FederationQueryServlet, FederationMakeJoinServlet, FederationMakeLeaveServlet, @@ -745,7 +701,6 @@ async def on_GET( RoomComplexityServlet, FederationSpaceSummaryServlet, FederationRoomHierarchyServlet, - FederationRoomHierarchyUnstableServlet, FederationV1SendKnockServlet, FederationMakeKnockServlet, ) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 61607cf2bad7..4b66a9862f1b 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -18,7 +18,6 @@ import unicodedata import urllib.parse from binascii import crc32 -from http import HTTPStatus from typing import ( TYPE_CHECKING, Any, @@ -39,7 +38,6 @@ import bcrypt import pymacaroons import unpaddedbase64 -from pymacaroons.exceptions import MacaroonVerificationFailedException from twisted.web.server import Request @@ -183,11 +181,8 @@ class LoginTokenAttributes: user_id = attr.ib(type=str) + # the SSO Identity Provider that the user authenticated with, to get this token auth_provider_id = attr.ib(type=str) - """The SSO Identity Provider that the user authenticated with, to get this token.""" - - auth_provider_session_id = attr.ib(type=Optional[str]) - """The session ID advertised by the SSO Identity Provider.""" class AuthHandler: @@ -761,109 +756,53 @@ def _auth_dict_for_flows( async def refresh_token( self, refresh_token: str, - access_token_valid_until_ms: Optional[int], - refresh_token_valid_until_ms: Optional[int], - ) -> Tuple[str, str, Optional[int]]: + valid_until_ms: Optional[int], + ) -> Tuple[str, str]: """ Consumes a refresh token and generate both a new access token and a new refresh token from it. The consumed refresh token is considered invalid after the first use of the new access token or the new refresh token. - The lifetime of both the access token and refresh token will be capped so that they - do not exceed the session's ultimate expiry time, if applicable. - Args: refresh_token: The token to consume. - access_token_valid_until_ms: The expiration timestamp of the new access token. - None if the access token does not expire. - refresh_token_valid_until_ms: The expiration timestamp of the new refresh token. - None if the refresh token does not expire. + valid_until_ms: The expiration timestamp of the new access token. 
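The removed branches above clamp both new tokens to the session's `ultimate_session_expiry_ts` when one is set. The clamping rule in isolation, as a sketch rather than the handler's API:

from typing import Optional

def cap_to_session(valid_until_ms: Optional[int], session_expiry_ms: Optional[int]) -> Optional[int]:
    # None means "no expiry", so the session bound (if any) takes over
    if session_expiry_ms is None:
        return valid_until_ms
    if valid_until_ms is None:
        return session_expiry_ms
    return min(valid_until_ms, session_expiry_ms)

assert cap_to_session(1_000, 500) == 500
assert cap_to_session(None, 500) == 500
assert cap_to_session(1_000, None) == 1_000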
+ Returns: - A tuple containing: - - the new access token - - the new refresh token - - the actual expiry time of the access token, which may be earlier than - `access_token_valid_until_ms`. + A tuple containing the new access token and refresh token """ # Verify the token signature first before looking up the token if not self._verify_refresh_token(refresh_token): - raise SynapseError( - HTTPStatus.UNAUTHORIZED, "invalid refresh token", Codes.UNKNOWN_TOKEN - ) + raise SynapseError(401, "invalid refresh token", Codes.UNKNOWN_TOKEN) existing_token = await self.store.lookup_refresh_token(refresh_token) if existing_token is None: - raise SynapseError( - HTTPStatus.UNAUTHORIZED, - "refresh token does not exist", - Codes.UNKNOWN_TOKEN, - ) + raise SynapseError(401, "refresh token does not exist", Codes.UNKNOWN_TOKEN) if ( existing_token.has_next_access_token_been_used or existing_token.has_next_refresh_token_been_refreshed ): raise SynapseError( - HTTPStatus.FORBIDDEN, - "refresh token isn't valid anymore", - Codes.FORBIDDEN, + 403, "refresh token isn't valid anymore", Codes.FORBIDDEN ) - now_ms = self._clock.time_msec() - - if existing_token.expiry_ts is not None and existing_token.expiry_ts < now_ms: - - raise SynapseError( - HTTPStatus.FORBIDDEN, - "The supplied refresh token has expired", - Codes.FORBIDDEN, - ) - - if existing_token.ultimate_session_expiry_ts is not None: - # This session has a bounded lifetime, even across refreshes. - - if access_token_valid_until_ms is not None: - access_token_valid_until_ms = min( - access_token_valid_until_ms, - existing_token.ultimate_session_expiry_ts, - ) - else: - access_token_valid_until_ms = existing_token.ultimate_session_expiry_ts - - if refresh_token_valid_until_ms is not None: - refresh_token_valid_until_ms = min( - refresh_token_valid_until_ms, - existing_token.ultimate_session_expiry_ts, - ) - else: - refresh_token_valid_until_ms = existing_token.ultimate_session_expiry_ts - if existing_token.ultimate_session_expiry_ts < now_ms: - raise SynapseError( - HTTPStatus.FORBIDDEN, - "The session has expired and can no longer be refreshed", - Codes.FORBIDDEN, - ) - ( new_refresh_token, new_refresh_token_id, ) = await self.create_refresh_token_for_user_id( - user_id=existing_token.user_id, - device_id=existing_token.device_id, - expiry_ts=refresh_token_valid_until_ms, - ultimate_session_expiry_ts=existing_token.ultimate_session_expiry_ts, + user_id=existing_token.user_id, device_id=existing_token.device_id ) access_token = await self.create_access_token_for_user_id( user_id=existing_token.user_id, device_id=existing_token.device_id, - valid_until_ms=access_token_valid_until_ms, + valid_until_ms=valid_until_ms, refresh_token_id=new_refresh_token_id, ) await self.store.replace_refresh_token( existing_token.token_id, new_refresh_token_id ) - return access_token, new_refresh_token, access_token_valid_until_ms + return access_token, new_refresh_token def _verify_refresh_token(self, token: str) -> bool: """ @@ -897,8 +836,6 @@ async def create_refresh_token_for_user_id( self, user_id: str, device_id: str, - expiry_ts: Optional[int], - ultimate_session_expiry_ts: Optional[int], ) -> Tuple[str, int]: """ Creates a new refresh token for the user with the given user ID. @@ -906,13 +843,6 @@ async def create_refresh_token_for_user_id( Args: user_id: canonical user ID device_id: the device ID to associate with the token. - expiry_ts (milliseconds since the epoch): Time after which the - refresh token cannot be used. 
- If None, the refresh token never expires until it has been used. - ultimate_session_expiry_ts (milliseconds since the epoch): - Time at which the session will end and can not be extended any - further. - If None, the session can be refreshed indefinitely. Returns: The newly created refresh token and its ID in the database @@ -922,8 +852,6 @@ async def create_refresh_token_for_user_id( user_id=user_id, token=refresh_token, device_id=device_id, - expiry_ts=expiry_ts, - ultimate_session_expiry_ts=ultimate_session_expiry_ts, ) return refresh_token, refresh_token_id @@ -1654,7 +1582,6 @@ async def complete_sso_login( client_redirect_url: str, extra_attributes: Optional[JsonDict] = None, new_user: bool = False, - auth_provider_session_id: Optional[str] = None, ) -> None: """Having figured out a mxid for this user, complete the HTTP request @@ -1670,7 +1597,6 @@ async def complete_sso_login( during successful login. Must be JSON serializable. new_user: True if we should use wording appropriate to a user who has just registered. - auth_provider_session_id: The session ID from the SSO IdP received during login. """ # If the account has been deactivated, do not proceed with the login # flow. @@ -1691,7 +1617,6 @@ async def complete_sso_login( extra_attributes, new_user=new_user, user_profile_data=profile, - auth_provider_session_id=auth_provider_session_id, ) def _complete_sso_login( @@ -1703,7 +1628,6 @@ def _complete_sso_login( extra_attributes: Optional[JsonDict] = None, new_user: bool = False, user_profile_data: Optional[ProfileInfo] = None, - auth_provider_session_id: Optional[str] = None, ) -> None: """ The synchronous portion of complete_sso_login. @@ -1725,9 +1649,7 @@ def _complete_sso_login( # Create a login token login_token = self.macaroon_gen.generate_short_term_login_token( - registered_user_id, - auth_provider_id=auth_provider_id, - auth_provider_session_id=auth_provider_session_id, + registered_user_id, auth_provider_id=auth_provider_id ) # Append the login token to the original redirect URL (i.e. 
with its query @@ -1832,7 +1754,6 @@ def generate_short_term_login_token( self, user_id: str, auth_provider_id: str, - auth_provider_session_id: Optional[str] = None, duration_in_ms: int = (2 * 60 * 1000), ) -> str: macaroon = self._generate_base_macaroon(user_id) @@ -1841,10 +1762,6 @@ def generate_short_term_login_token( expiry = now + duration_in_ms macaroon.add_first_party_caveat("time < %d" % (expiry,)) macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,)) - if auth_provider_session_id is not None: - macaroon.add_first_party_caveat( - "auth_provider_session_id = %s" % (auth_provider_session_id,) - ) return macaroon.serialize() def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes: @@ -1866,28 +1783,15 @@ def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes: user_id = get_value_from_macaroon(macaroon, "user_id") auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id") - auth_provider_session_id: Optional[str] = None - try: - auth_provider_session_id = get_value_from_macaroon( - macaroon, "auth_provider_session_id" - ) - except MacaroonVerificationFailedException: - pass - v = pymacaroons.Verifier() v.satisfy_exact("gen = 1") v.satisfy_exact("type = login") v.satisfy_general(lambda c: c.startswith("user_id = ")) v.satisfy_general(lambda c: c.startswith("auth_provider_id = ")) - v.satisfy_general(lambda c: c.startswith("auth_provider_session_id = ")) satisfy_expiry(v, self.hs.get_clock().time_msec) v.verify(macaroon, self.hs.config.key.macaroon_secret_key) - return LoginTokenAttributes( - user_id=user_id, - auth_provider_id=auth_provider_id, - auth_provider_session_id=auth_provider_session_id, - ) + return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id) def generate_delete_pusher_token(self, user_id: str) -> str: macaroon = self._generate_base_macaroon(user_id) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 82ee11e921e6..68b446eb66c8 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -301,8 +301,6 @@ async def check_device_registered( user_id: str, device_id: Optional[str], initial_device_display_name: Optional[str] = None, - auth_provider_id: Optional[str] = None, - auth_provider_session_id: Optional[str] = None, ) -> str: """ If the given device has not been registered, register it with the @@ -314,8 +312,6 @@ async def check_device_registered( user_id: @user:id device_id: device id supplied by client initial_device_display_name: device display name from client - auth_provider_id: The SSO IdP the user used, if any. - auth_provider_session_id: The session ID (sid) got from the SSO IdP. 
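The macaroon changes earlier in this hunk drop the optional `auth_provider_session_id` caveat. A self-contained example of the caveat-and-verifier pattern the surrounding code relies on, using pymacaroons directly with dummy key material:

import pymacaroons

m = pymacaroons.Macaroon(location="synapse", identifier="key", key="secret")
m.add_first_party_caveat("gen = 1")
m.add_first_party_caveat("type = login")
m.add_first_party_caveat("user_id = @alice:example.com")

v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1")
v.satisfy_exact("type = login")
# general caveats only need a predicate, so optional caveats stay verifiable
v.satisfy_general(lambda c: c.startswith("user_id = "))
assert v.verify(m, "secret")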
Returns: device id (generated if none was supplied) """ @@ -327,8 +323,6 @@ async def check_device_registered( user_id=user_id, device_id=device_id, initial_device_display_name=initial_device_display_name, - auth_provider_id=auth_provider_id, - auth_provider_session_id=auth_provider_session_id, ) if new_device: await self.notify_device_update(user_id, [device_id]) @@ -343,8 +337,6 @@ async def check_device_registered( user_id=user_id, device_id=new_device_id, initial_device_display_name=initial_device_display_name, - auth_provider_id=auth_provider_id, - auth_provider_session_id=auth_provider_session_id, ) if new_device: await self.notify_device_update(user_id, [new_device_id]) diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 32b0254c5f08..b4ff935546c8 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -122,8 +122,9 @@ async def get_stream( events, time_now, as_client_event=as_client_event, - # Don't bundle aggregations as this is a deprecated API. - bundle_aggregations=False, + # We don't bundle "live" events, as otherwise clients + # will end up double counting annotations. + bundle_relations=False, ) chunk = { diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 1ea837d08211..3112cc88b1cc 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -68,37 +68,6 @@ logger = logging.getLogger(__name__) -def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]: - """Get joined domains from state - - Args: - state: State map from type/state key to event. - - Returns: - Returns a list of servers with the lowest depth of their joins. - Sorted by lowest depth first. - """ - joined_users = [ - (state_key, int(event.depth)) - for (e_type, state_key), event in state.items() - if e_type == EventTypes.Member and event.membership == Membership.JOIN - ] - - joined_domains: Dict[str, int] = {} - for u, d in joined_users: - try: - dom = get_domain_from_id(u) - old_d = joined_domains.get(dom) - if old_d: - joined_domains[dom] = min(d, old_d) - else: - joined_domains[dom] = d - except Exception: - pass - - return sorted(joined_domains.items(), key=lambda d: d[1]) - - class FederationHandler: """Handles general incoming federation requests @@ -299,6 +268,36 @@ async def _maybe_backfill_inner( curr_state = await self.state_handler.get_current_state(room_id) + def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]: + """Get joined domains from state + + Args: + state: State map from type/state key to event. + + Returns: + Returns a list of servers with the lowest depth of their joins. + Sorted by lowest depth first. 
+ """ + joined_users = [ + (state_key, int(event.depth)) + for (e_type, state_key), event in state.items() + if e_type == EventTypes.Member and event.membership == Membership.JOIN + ] + + joined_domains: Dict[str, int] = {} + for u, d in joined_users: + try: + dom = get_domain_from_id(u) + old_d = joined_domains.get(dom) + if old_d: + joined_domains[dom] = min(d, old_d) + else: + joined_domains[dom] = d + except Exception: + pass + + return sorted(joined_domains.items(), key=lambda d: d[1]) + curr_domains = get_domains_from_state(curr_state) likely_domains = [ diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 9cd21e7f2b3c..d4e45561555c 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -165,11 +165,7 @@ async def handle_room(event: RoomsForUser) -> None: invite_event = await self.store.get_event(event.event_id) d["invite"] = await self._event_serializer.serialize_event( - invite_event, - time_now, - # Don't bundle aggregations as this is a deprecated API. - bundle_aggregations=False, - as_client_event=as_client_event, + invite_event, time_now, as_client_event ) rooms_ret.append(d) @@ -220,11 +216,7 @@ async def handle_room(event: RoomsForUser) -> None: d["messages"] = { "chunk": ( await self._event_serializer.serialize_events( - messages, - time_now=time_now, - # Don't bundle aggregations as this is a deprecated API. - bundle_aggregations=False, - as_client_event=as_client_event, + messages, time_now=time_now, as_client_event=as_client_event ) ), "start": await start_token.to_string(self.store), @@ -234,8 +226,6 @@ async def handle_room(event: RoomsForUser) -> None: d["state"] = await self._event_serializer.serialize_events( current_state.values(), time_now=time_now, - # Don't bundle aggregations as this is a deprecated API. - bundle_aggregations=False, as_client_event=as_client_event, ) @@ -376,18 +366,14 @@ async def _room_initial_sync_parted( "room_id": room_id, "messages": { "chunk": ( - # Don't bundle aggregations as this is a deprecated API. - await self._event_serializer.serialize_events( - messages, time_now, bundle_aggregations=False - ) + await self._event_serializer.serialize_events(messages, time_now) ), "start": await start_token.to_string(self.store), "end": await end_token.to_string(self.store), }, "state": ( - # Don't bundle aggregations as this is a deprecated API. await self._event_serializer.serialize_events( - room_state.values(), time_now, bundle_aggregations=False + room_state.values(), time_now ) ), "presence": [], @@ -406,9 +392,8 @@ async def _room_initial_sync_joined( # TODO: These concurrently time_now = self.clock.time_msec() - # Don't bundle aggregations as this is a deprecated API. state = await self._event_serializer.serialize_events( - current_state.values(), time_now, bundle_aggregations=False + current_state.values(), time_now ) now_token = self.hs.get_event_sources().get_current_token() @@ -482,10 +467,7 @@ async def get_receipts() -> List[JsonDict]: "room_id": room_id, "messages": { "chunk": ( - # Don't bundle aggregations as this is a deprecated API. 
- await self._event_serializer.serialize_events( - messages, time_now, bundle_aggregations=False - ) + await self._event_serializer.serialize_events(messages, time_now) ), "start": await start_token.to_string(self.store), "end": await end_token.to_string(self.store), diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 87f671708c4e..95b4fad3c68b 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -247,7 +247,13 @@ async def get_state_events( room_state = room_state_events[membership_event_id] now = self.clock.time_msec() - events = await self._event_serializer.serialize_events(room_state.values(), now) + events = await self._event_serializer.serialize_events( + room_state.values(), + now, + # We don't bother bundling aggregations in when asked for state + # events, as clients won't use them. + bundle_relations=False, + ) return events async def get_joined_members(self, requester: Requester, room_id: str) -> dict: diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index deb353975143..3665d915133a 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -23,7 +23,7 @@ from authlib.jose import JsonWebToken, jwt from authlib.oauth2.auth import ClientAuth from authlib.oauth2.rfc6749.parameters import prepare_grant_uri -from authlib.oidc.core import CodeIDToken, UserInfo +from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url from jinja2 import Environment, Template from pymacaroons.exceptions import ( @@ -117,8 +117,7 @@ async def load_metadata(self) -> None: for idp_id, p in self._providers.items(): try: await p.load_metadata() - if not p._uses_userinfo: - await p.load_jwks() + await p.load_jwks() except Exception as e: raise Exception( "Error while initialising OIDC provider %r" % (idp_id,) @@ -499,6 +498,10 @@ async def load_jwks(self, force: bool = False) -> JWKS: return await self._jwks.get() async def _load_jwks(self) -> JWKS: + if self._uses_userinfo: + # We're not using jwt signing, return an empty jwk set + return {"keys": []} + metadata = await self.load_metadata() # Load the JWKS using the `jwks_uri` metadata. @@ -660,7 +663,7 @@ async def _fetch_userinfo(self, token: Token) -> UserInfo: return UserInfo(resp) - async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken: + async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo: """Return an instance of UserInfo from token's ``id_token``. Args: @@ -670,7 +673,7 @@ async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken: request. This value should match the one inside the token. Returns: - The decoded claims in the ID token. + An object representing the user. """ metadata = await self.load_metadata() claims_params = { @@ -681,6 +684,9 @@ async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken: # If we got an `access_token`, there should be an `at_hash` claim # in the `id_token` that we can check against. 
claims_params["access_token"] = token["access_token"] + claims_cls = CodeIDToken + else: + claims_cls = ImplicitIDToken alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"]) jwt = JsonWebToken(alg_values) @@ -697,7 +703,7 @@ async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken: claims = jwt.decode( id_token, key=jwk_set, - claims_cls=CodeIDToken, + claims_cls=claims_cls, claims_options=claim_options, claims_params=claims_params, ) @@ -707,7 +713,7 @@ async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken: claims = jwt.decode( id_token, key=jwk_set, - claims_cls=CodeIDToken, + claims_cls=claims_cls, claims_options=claim_options, claims_params=claims_params, ) @@ -715,8 +721,7 @@ async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken: logger.debug("Decoded id_token JWT %r; validating", claims) claims.validate(leeway=120) # allows 2 min of clock skew - - return claims + return UserInfo(claims) async def handle_redirect_request( self, @@ -832,22 +837,8 @@ async def handle_oidc_callback( logger.debug("Successfully obtained OAuth2 token data: %r", token) - # If there is an id_token, it should be validated, regardless of the - # userinfo endpoint is used or not. - if token.get("id_token") is not None: - try: - id_token = await self._parse_id_token(token, nonce=session_data.nonce) - sid = id_token.get("sid") - except Exception as e: - logger.exception("Invalid id_token") - self._sso_handler.render_error(request, "invalid_token", str(e)) - return - else: - id_token = None - sid = None - - # Now that we have a token, get the userinfo either from the `id_token` - # claims or by fetching the `userinfo_endpoint`. + # Now that we have a token, get the userinfo, either by decoding the + # `id_token` or by fetching the `userinfo_endpoint`. 
if self._uses_userinfo: try: userinfo = await self._fetch_userinfo(token) @@ -855,14 +846,13 @@ async def handle_oidc_callback( logger.exception("Could not fetch userinfo") self._sso_handler.render_error(request, "fetch_error", str(e)) return - elif id_token is not None: - userinfo = UserInfo(id_token) else: - logger.error("Missing id_token in token response") - self._sso_handler.render_error( - request, "invalid_token", "Missing id_token in token response" - ) - return + try: + userinfo = await self._parse_id_token(token, nonce=session_data.nonce) + except Exception as e: + logger.exception("Invalid id_token") + self._sso_handler.render_error(request, "invalid_token", str(e)) + return # first check if we're doing a UIA if session_data.ui_auth_session_id: @@ -894,7 +884,7 @@ async def handle_oidc_callback( # Call the mapper to register/login the user try: await self._complete_oidc_login( - userinfo, token, request, session_data.client_redirect_url, sid + userinfo, token, request, session_data.client_redirect_url ) except MappingException as e: logger.exception("Could not map user") @@ -906,7 +896,6 @@ async def _complete_oidc_login( token: Token, request: SynapseRequest, client_redirect_url: str, - sid: Optional[str], ) -> None: """Given a UserInfo response, complete the login flow @@ -1019,7 +1008,6 @@ async def grandfather_existing_users() -> Optional[str]: oidc_response_to_user_attributes, grandfather_existing_users, extra_attributes, - auth_provider_session_id=sid, ) def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str: diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 4f424380533b..cd64142735de 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -406,6 +406,9 @@ async def purge_room(self, room_id: str, force: bool = False) -> None: force: set true to skip checking for joined users. """ with await self.pagination_lock.write(room_id): + # check we know about the room + await self.store.get_room_version_id(room_id) + # first check that we have no users in this room if not force: joined = await self.store.is_host_joined(room_id, self._server_name) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 454d06c9733d..3df872c578b5 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -421,7 +421,7 @@ def __init__(self, hs: "HomeServer"): self._on_shutdown, ) - async def _on_shutdown(self) -> None: + def _on_shutdown(self) -> None: if self._presence_enabled: self.hs.get_tcp_replication().send_command( ClearUserSyncsCommand(self.instance_id) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index f08a516a7588..448a36108e59 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -1,5 +1,4 @@ # Copyright 2014 - 2016 OpenMarket Ltd -# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
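One subtlety in the purge_room hunk above: the re-added get_room_version_id call doubles as an existence check, because the storage layer raises NotFoundError for a room the server has never seen. A sketch of the pattern, with the two store calls taken from the diff and everything else illustrative:

    async def purge_room(store, server_name: str, room_id: str, force: bool = False) -> None:
        # Raises NotFoundError for unknown rooms, so the purge fails fast
        # instead of silently deleting nothing.
        await store.get_room_version_id(room_id)

        if not force:
            # Refuse to purge a room that still has local members joined.
            if await store.is_host_joined(room_id, server_name):
                raise Exception("Users are still joined to this room")

        # ... proceed to delete events, state and media references ...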
@@ -117,13 +116,9 @@ def __init__(self, hs: "HomeServer"): self.pusher_pool = hs.get_pusherpool() self.session_lifetime = hs.config.registration.session_lifetime - self.nonrefreshable_access_token_lifetime = ( - hs.config.registration.nonrefreshable_access_token_lifetime - ) self.refreshable_access_token_lifetime = ( hs.config.registration.refreshable_access_token_lifetime ) - self.refresh_token_lifetime = hs.config.registration.refresh_token_lifetime init_counters_for_auth_provider("") @@ -746,7 +741,6 @@ async def register_device( is_appservice_ghost: bool = False, auth_provider_id: Optional[str] = None, should_issue_refresh_token: bool = False, - auth_provider_session_id: Optional[str] = None, ) -> Tuple[str, str, Optional[int], Optional[str]]: """Register a device for a user and generate an access token. @@ -757,9 +751,9 @@ async def register_device( device_id: The device ID to check, or None to generate a new one. initial_display_name: An optional display name for the device. is_guest: Whether this is a guest account - auth_provider_id: The SSO IdP the user used, if any. + auth_provider_id: The SSO IdP the user used, if any (just used for the + prometheus metrics). should_issue_refresh_token: Whether it should also issue a refresh token - auth_provider_session_id: The session ID received during login from the SSO IdP. Returns: Tuple of device ID, access token, access token expiration time and refresh token """ @@ -770,8 +764,6 @@ async def register_device( is_guest=is_guest, is_appservice_ghost=is_appservice_ghost, should_issue_refresh_token=should_issue_refresh_token, - auth_provider_id=auth_provider_id, - auth_provider_session_id=auth_provider_session_id, ) login_counter.labels( @@ -794,8 +786,6 @@ async def register_device_inner( is_guest: bool = False, is_appservice_ghost: bool = False, should_issue_refresh_token: bool = False, - auth_provider_id: Optional[str] = None, - auth_provider_session_id: Optional[str] = None, ) -> LoginDict: """Helper for register_device @@ -803,86 +793,40 @@ async def register_device_inner( class and RegisterDeviceReplicationServlet. """ assert not self.hs.config.worker.worker_app - now_ms = self.clock.time_msec() - access_token_expiry = None + valid_until_ms = None if self.session_lifetime is not None: if is_guest: raise Exception( "session_lifetime is not currently implemented for guest access" ) - access_token_expiry = now_ms + self.session_lifetime - - if self.nonrefreshable_access_token_lifetime is not None: - if access_token_expiry is not None: - # Don't allow the non-refreshable access token to outlive the - # session. - access_token_expiry = min( - now_ms + self.nonrefreshable_access_token_lifetime, - access_token_expiry, - ) - else: - access_token_expiry = now_ms + self.nonrefreshable_access_token_lifetime + valid_until_ms = self.clock.time_msec() + self.session_lifetime refresh_token = None refresh_token_id = None registered_device_id = await self.device_handler.check_device_registered( - user_id, - device_id, - initial_display_name, - auth_provider_id=auth_provider_id, - auth_provider_session_id=auth_provider_session_id, + user_id, device_id, initial_display_name ) if is_guest: - assert access_token_expiry is None + assert valid_until_ms is None access_token = self.macaroon_gen.generate_guest_access_token(user_id) else: if should_issue_refresh_token: - # A refreshable access token lifetime must be configured - # since we're told to issue a refresh token (the caller checks - # that this value is set before setting this flag). 
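Comparing the two register_device_inner variants, the post-revert expiry logic collapses to a single valid_until_ms value. A toy sketch of that arithmetic, with illustrative argument names (the real code reads these off the registration config):

    from typing import Optional

    def access_token_expiry(
        now_ms: int,
        session_lifetime_ms: Optional[int],
        refreshable_lifetime_ms: Optional[int],
        issues_refresh_token: bool,
    ) -> Optional[int]:
        valid_until_ms = None
        if session_lifetime_ms is not None:
            valid_until_ms = now_ms + session_lifetime_ms
        if issues_refresh_token and refreshable_lifetime_ms is not None:
            # Post-revert, issuing a refresh token simply overrides the
            # session-lifetime expiry computed above.
            valid_until_ms = now_ms + refreshable_lifetime_ms
        return valid_until_ms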
- assert self.refreshable_access_token_lifetime is not None - - # Set the expiry time of the refreshable access token - access_token_expiry = now_ms + self.refreshable_access_token_lifetime - - # Set the refresh token expiry time (if configured) - refresh_token_expiry = None - if self.refresh_token_lifetime is not None: - refresh_token_expiry = now_ms + self.refresh_token_lifetime - - # Set an ultimate session expiry time (if configured) - ultimate_session_expiry_ts = None - if self.session_lifetime is not None: - ultimate_session_expiry_ts = now_ms + self.session_lifetime - - # Also ensure that the issued tokens don't outlive the - # session. - # (It would be weird to configure a homeserver with a shorter - # session lifetime than token lifetime, but may as well handle - # it.) - access_token_expiry = min( - access_token_expiry, ultimate_session_expiry_ts - ) - if refresh_token_expiry is not None: - refresh_token_expiry = min( - refresh_token_expiry, ultimate_session_expiry_ts - ) - ( refresh_token, refresh_token_id, ) = await self._auth_handler.create_refresh_token_for_user_id( user_id, device_id=registered_device_id, - expiry_ts=refresh_token_expiry, - ultimate_session_expiry_ts=ultimate_session_expiry_ts, + ) + valid_until_ms = ( + self.clock.time_msec() + self.refreshable_access_token_lifetime ) access_token = await self._auth_handler.create_access_token_for_user_id( user_id, device_id=registered_device_id, - valid_until_ms=access_token_expiry, + valid_until_ms=valid_until_ms, is_appservice_ghost=is_appservice_ghost, refresh_token_id=refresh_token_id, ) @@ -890,7 +834,7 @@ class and RegisterDeviceReplicationServlet. return { "device_id": registered_device_id, "access_token": access_token, - "valid_until_ms": access_token_expiry, + "valid_until_ms": valid_until_ms, "refresh_token": refresh_token, } diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index ead2198e14fe..88053f986997 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -46,7 +46,6 @@ from synapse.api.errors import ( AuthError, Codes, - HttpResponseException, LimitExceededError, NotFoundError, StoreError, @@ -57,8 +56,6 @@ from synapse.event_auth import validate_event_for_room_version from synapse.events import EventBase from synapse.events.utils import copy_power_levels_contents -from synapse.federation.federation_client import InvalidResponseError -from synapse.handlers.federation import get_domains_from_state from synapse.rest.admin._base import assert_user_is_admin from synapse.storage.state import StateFilter from synapse.streams import EventSource @@ -1223,147 +1220,6 @@ async def filter_evts(events: List[EventBase]) -> List[EventBase]: return results -class TimestampLookupHandler: - def __init__(self, hs: "HomeServer"): - self.server_name = hs.hostname - self.store = hs.get_datastore() - self.state_handler = hs.get_state_handler() - self.federation_client = hs.get_federation_client() - - async def get_event_for_timestamp( - self, - requester: Requester, - room_id: str, - timestamp: int, - direction: str, - ) -> Tuple[str, int]: - """Find the closest event to the given timestamp in the given direction. - If we can't find an event locally or the event we have locally is next to a gap, - it will ask other federated homeservers for an event. - - Args: - requester: The user making the request according to the access token - room_id: Room to fetch the event from - timestamp: The point in time (inclusive) we should navigate from in - the given direction to find the closest event. 
- direction: ["f"|"b"] to indicate whether we should navigate forward - or backward from the given timestamp to find the closest event. - - Returns: - A tuple containing the `event_id` closest to the given timestamp in - the given direction and the `origin_server_ts`. - - Raises: - SynapseError if unable to find any event locally in the given direction - """ - - local_event_id = await self.store.get_event_id_for_timestamp( - room_id, timestamp, direction - ) - logger.debug( - "get_event_for_timestamp: locally, we found event_id=%s closest to timestamp=%s", - local_event_id, - timestamp, - ) - - # Check for gaps in the history where events could be hiding in between - # the timestamp given and the event we were able to find locally - is_event_next_to_backward_gap = False - is_event_next_to_forward_gap = False - if local_event_id: - local_event = await self.store.get_event( - local_event_id, allow_none=False, allow_rejected=False - ) - - if direction == "f": - # We only need to check for a backward gap if we're looking forwards - # to ensure there is nothing in between. - is_event_next_to_backward_gap = ( - await self.store.is_event_next_to_backward_gap(local_event) - ) - elif direction == "b": - # We only need to check for a forward gap if we're looking backwards - # to ensure there is nothing in between - is_event_next_to_forward_gap = ( - await self.store.is_event_next_to_forward_gap(local_event) - ) - - # If we found a gap, we should probably ask another homeserver first - # about more history in between - if ( - not local_event_id - or is_event_next_to_backward_gap - or is_event_next_to_forward_gap - ): - logger.debug( - "get_event_for_timestamp: locally, we found event_id=%s closest to timestamp=%s which is next to a gap in event history so we're asking other homeservers first", - local_event_id, - timestamp, - ) - - # Find other homeservers from the given state in the room - curr_state = await self.state_handler.get_current_state(room_id) - curr_domains = get_domains_from_state(curr_state) - likely_domains = [ - domain for domain, depth in curr_domains if domain != self.server_name - ] - - # Loop through each homeserver candidate until we get a succesful response - for domain in likely_domains: - try: - remote_response = await self.federation_client.timestamp_to_event( - domain, room_id, timestamp, direction - ) - logger.debug( - "get_event_for_timestamp: response from domain(%s)=%s", - domain, - remote_response, - ) - - # TODO: Do we want to persist this as an extremity? - # TODO: I think ideally, we would try to backfill from - # this event and run this whole - # `get_event_for_timestamp` function again to make sure - # they didn't give us an event from their gappy history. 
- remote_event_id = remote_response.event_id - origin_server_ts = remote_response.origin_server_ts - - # Only return the remote event if it's closer than the local event - if not local_event or ( - abs(origin_server_ts - timestamp) - < abs(local_event.origin_server_ts - timestamp) - ): - return remote_event_id, origin_server_ts - except (HttpResponseException, InvalidResponseError) as ex: - # Let's not put a high priority on some other homeserver - # failing to respond or giving a random response - logger.debug( - "Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s", - domain, - type(ex).__name__, - ex, - ex.args, - ) - except Exception as ex: - # But we do want to see some exceptions in our code - logger.warning( - "Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s", - domain, - type(ex).__name__, - ex, - ex.args, - ) - - if not local_event_id: - raise SynapseError( - 404, - "Unable to find event from %s in direction %s" % (timestamp, direction), - errcode=Codes.NOT_FOUND, - ) - - return local_event_id, local_event.origin_server_ts - - class RoomEventSource(EventSource[RoomStreamToken, EventBase]): def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -1535,13 +1391,20 @@ async def shutdown_room( await self.store.block_room(room_id, requester_user_id) if not await self.store.get_room(room_id): - # if we don't know about the room, there is nothing left to do. - return { - "kicked_users": [], - "failed_to_kick_users": [], - "local_aliases": [], - "new_room_id": None, - } + if block: + # We allow you to block an unknown room. + return { + "kicked_users": [], + "failed_to_kick_users": [], + "local_aliases": [], + "new_room_id": None, + } + else: + # But if you don't want to preventatively block another room, + # this function can't do anything useful. + raise NotFoundError( + "Cannot shut down room: unknown room id %s" % (room_id,) + ) if new_room_user_id is not None: if not self.hs.is_mine_id(new_room_user_id): diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index b2cfe537dfb1..8181cc0b5267 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -36,9 +36,8 @@ SynapseError, UnsupportedRoomVersionError, ) -from synapse.api.ratelimiting import Ratelimiter from synapse.events import EventBase -from synapse.types import JsonDict, Requester +from synapse.types import JsonDict from synapse.util.caches.response_cache import ResponseCache if TYPE_CHECKING: @@ -94,9 +93,6 @@ def __init__(self, hs: "HomeServer"): self._event_serializer = hs.get_event_client_serializer() self._server_name = hs.hostname self._federation_client = hs.get_federation_client() - self._ratelimiter = Ratelimiter( - store=self._store, clock=hs.get_clock(), rate_hz=5, burst_count=10 - ) # If a user tries to fetch the same page multiple times in quick succession, # only process the first attempt and return its result to subsequent requests. @@ -253,7 +249,7 @@ async def get_space_summary( async def get_room_hierarchy( self, - requester: Requester, + requester: str, requested_room_id: str, suggested_only: bool = False, max_depth: Optional[int] = None, @@ -280,8 +276,6 @@ async def get_room_hierarchy( Returns: The JSON hierarchy dictionary. """ - await self._ratelimiter.ratelimit(requester) - # If a user tries to fetch the same page multiple times in quick succession, # only process the first attempt and return its result to subsequent requests. 
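The comment above describes classic request coalescing. A minimal sketch of the idea in plain asyncio, independent of Synapse's ResponseCache (the class and method names here are invented for illustration):

    import asyncio
    from typing import Any, Awaitable, Callable, Dict, Hashable

    class RequestCoalescer:
        """First caller for a key computes the result; concurrent callers
        with the same key await the same in-flight task."""

        def __init__(self) -> None:
            self._pending: Dict[Hashable, "asyncio.Task[Any]"] = {}

        async def wrap(
            self, key: Hashable, callback: Callable[..., Awaitable[Any]], *args: Any
        ) -> Any:
            task = self._pending.get(key)
            if task is None:
                task = asyncio.ensure_future(callback(*args))
                self._pending[key] = task
                # Drop the entry once done so later calls compute fresh data.
                task.add_done_callback(lambda _: self._pending.pop(key, None))
            return await task

Keying on the full argument tuple, as the hunk below does, is what guarantees that only byte-identical requests share a result.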
# @@ -289,7 +283,7 @@ async def get_room_hierarchy( # to process multiple requests for the same page will result in errors. return await self._pagination_response_cache.wrap( ( - requester.user.to_string(), + requester, requested_room_id, suggested_only, max_depth, @@ -297,7 +291,7 @@ async def get_room_hierarchy( from_token, ), self._get_room_hierarchy, - requester.user.to_string(), + requester, requested_room_id, suggested_only, max_depth, diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 65c27bc64a5e..49fde01cf0e7 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -365,7 +365,6 @@ async def complete_sso_login_request( sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]], grandfather_existing_users: Callable[[], Awaitable[Optional[str]]], extra_login_attributes: Optional[JsonDict] = None, - auth_provider_session_id: Optional[str] = None, ) -> None: """ Given an SSO ID, retrieve the user ID for it and possibly register the user. @@ -416,8 +415,6 @@ async def complete_sso_login_request( extra_login_attributes: An optional dictionary of extra attributes to be provided to the client in the login response. - auth_provider_session_id: An optional session ID from the IdP. - Raises: MappingException if there was a problem mapping the response to a user. RedirectException: if the mapping provider needs to redirect the user @@ -493,7 +490,6 @@ async def complete_sso_login_request( client_redirect_url, extra_login_attributes, new_user=new_user, - auth_provider_session_id=auth_provider_session_id, ) async def _call_attribute_mapper( diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index f3039c3c3fb7..891435c14daf 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -334,19 +334,6 @@ async def _wait_for_sync_for_user( full_state: bool, cache_context: ResponseCacheContext[SyncRequestKey], ) -> SyncResult: - """The start of the machinery that produces a /sync response. - - See https://spec.matrix.org/v1.1/client-server-api/#syncing for full details. - - This method does high-level bookkeeping: - - tracking the kind of sync in the logging context - - deleting any to_device messages whose delivery has been acknowledged. - - deciding if we should dispatch an instant or delayed response - - marking the sync as being lazily loaded, if appropriate - - Computing the body of the response begins in the next method, - `current_sync_for_user`. - """ if since_token is None: sync_type = "initial_sync" elif full_state: @@ -376,7 +363,7 @@ async def _wait_for_sync_for_user( sync_config, since_token, full_state=full_state ) else: - # Otherwise, we wait for something to happen and report it to the user. + async def current_sync_callback( before_token: StreamToken, after_token: StreamToken ) -> SyncResult: @@ -415,12 +402,7 @@ async def current_sync_for_user( since_token: Optional[StreamToken] = None, full_state: bool = False, ) -> SyncResult: - """Generates the response body of a sync result, represented as a SyncResult. - - This is a wrapper around `generate_sync_result` which starts an open tracing - span to track the sync. See `generate_sync_result` for the next part of your - indoctrination. 
- """ + """Get the sync for client needed to match what the server has now.""" with start_active_span("current_sync_for_user"): log_kv({"since_token": since_token}) sync_result = await self.generate_sync_result( @@ -578,7 +560,7 @@ async def _load_filtered_recents( # that have happened since `since_key` up to `end_key`, so we # can just use `get_room_events_stream_for_room`. # Otherwise, we want to return the last N events in the room - # in topological ordering. + # in toplogical ordering. if since_key: events, end_key = await self.store.get_room_events_stream_for_room( room_id, @@ -1060,18 +1042,7 @@ async def generate_sync_result( since_token: Optional[StreamToken] = None, full_state: bool = False, ) -> SyncResult: - """Generates the response body of a sync result. - - This is represented by a `SyncResult` struct, which is built from small pieces - using a `SyncResultBuilder`. See also - https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3sync - the `sync_result_builder` is passed as a mutable ("inout") parameter to various - helper functions. These retrieve and process the data which forms the sync body, - often writing to the `sync_result_builder` to store their output. - - At the end, we transfer data from the `sync_result_builder` to a new `SyncResult` - instance to signify that the sync calculation is complete. - """ + """Generates a sync result.""" # NB: The now_token gets changed by some of the generate_sync_* methods, # this is due to some of the underlying streams not supporting the ability # to query up to a given point. @@ -1373,22 +1344,14 @@ async def _generate_sync_entry_for_to_device( async def _generate_sync_entry_for_account_data( self, sync_result_builder: "SyncResultBuilder" ) -> Dict[str, Dict[str, JsonDict]]: - """Generates the account data portion of the sync response. - - Account data (called "Client Config" in the spec) can be set either globally - or for a specific room. Account data consists of a list of events which - accumulate state, much like a room. - - This function retrieves global and per-room account data. The former is written - to the given `sync_result_builder`. The latter is returned directly, to be - later written to the `sync_result_builder` on a room-by-room basis. + """Generates the account data portion of the sync response. Populates + `sync_result_builder` with the result. Args: sync_result_builder Returns: - A dictionary whose keys (room ids) map to the per room account data for that - room. + A dictionary containing the per room account data. 
""" sync_config = sync_result_builder.sync_config user_id = sync_result_builder.sync_config.user.to_string() @@ -1396,7 +1359,7 @@ async def _generate_sync_entry_for_account_data( if since_token and not sync_result_builder.full_state: ( - global_account_data, + account_data, account_data_by_room, ) = await self.store.get_updated_account_data_for_user( user_id, since_token.account_data_key @@ -1407,23 +1370,23 @@ async def _generate_sync_entry_for_account_data( ) if push_rules_changed: - global_account_data["m.push_rules"] = await self.push_rules_for_user( + account_data["m.push_rules"] = await self.push_rules_for_user( sync_config.user ) else: ( - global_account_data, + account_data, account_data_by_room, ) = await self.store.get_account_data_for_user(sync_config.user.to_string()) - global_account_data["m.push_rules"] = await self.push_rules_for_user( + account_data["m.push_rules"] = await self.push_rules_for_user( sync_config.user ) account_data_for_user = await sync_config.filter_collection.filter_account_data( [ {"type": account_data_type, "content": content} - for account_data_type, content in global_account_data.items() + for account_data_type, content in account_data.items() ] ) @@ -1497,31 +1460,18 @@ async def _generate_sync_entry_for_rooms( """Generates the rooms portion of the sync response. Populates the `sync_result_builder` with the result. - In the response that reaches the client, rooms are divided into four categories: - `invite`, `join`, `knock`, `leave`. These aren't the same as the four sets of - room ids returned by this function. - Args: sync_result_builder account_data_by_room: Dictionary of per room account data Returns: - Returns a 4-tuple describing rooms the user has joined or left, and users who've - joined or left rooms any rooms the user is in. This gets used later in - `_generate_sync_entry_for_device_list`. - - Its entries are: - - newly_joined_rooms - - newly_joined_or_invited_or_knocked_users - - newly_left_rooms - - newly_left_users + Returns a 4-tuple of + `(newly_joined_rooms, newly_joined_or_invited_users, + newly_left_rooms, newly_left_users)` """ - since_token = sync_result_builder.since_token - - # 1. Start by fetching all ephemeral events in rooms we've joined (if required). user_id = sync_result_builder.sync_config.user.to_string() block_all_room_ephemeral = ( - since_token is None + sync_result_builder.since_token is None and sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral() ) @@ -1535,8 +1485,9 @@ async def _generate_sync_entry_for_rooms( ) sync_result_builder.now_token = now_token - # 2. We check up front if anything has changed, if it hasn't then there is + # We check up front if anything has changed, if it hasn't then there is # no point in going further. + since_token = sync_result_builder.since_token if not sync_result_builder.full_state: if since_token and not ephemeral_by_room and not account_data_by_room: have_changed = await self._have_rooms_changed(sync_result_builder) @@ -1549,8 +1500,20 @@ async def _generate_sync_entry_for_rooms( logger.debug("no-oping sync") return set(), set(), set(), set() - # 3. Work out which rooms need reporting in the sync response. - ignored_users = await self._get_ignored_users(user_id) + ignored_account_data = ( + await self.store.get_global_account_data_by_type_for_user( + AccountDataTypes.IGNORED_USER_LIST, user_id=user_id + ) + ) + + # If there is ignored users account data and it matches the proper type, + # then use it. 
+ ignored_users: FrozenSet[str] = frozenset() + if ignored_account_data: + ignored_users_data = ignored_account_data.get("ignored_users", {}) + if isinstance(ignored_users_data, dict): + ignored_users = frozenset(ignored_users_data.keys()) + if since_token: room_changes = await self._get_rooms_changed( sync_result_builder, ignored_users @@ -1560,6 +1523,7 @@ async def _generate_sync_entry_for_rooms( ) else: room_changes = await self._get_all_rooms(sync_result_builder, ignored_users) + tags_by_room = await self.store.get_tags_for_user(user_id) log_kv({"rooms_changed": len(room_changes.room_entries)}) @@ -1570,8 +1534,6 @@ async def _generate_sync_entry_for_rooms( newly_joined_rooms = room_changes.newly_joined_rooms newly_left_rooms = room_changes.newly_left_rooms - # 4. We need to apply further processing to `room_entries` (rooms considered - # joined or archived). async def handle_room_entries(room_entry: "RoomSyncResultBuilder") -> None: logger.debug("Generating room entry for %s", room_entry.room_id) await self._generate_room_entry( @@ -1590,13 +1552,31 @@ async def handle_room_entries(room_entry: "RoomSyncResultBuilder") -> None: sync_result_builder.invited.extend(invited) sync_result_builder.knocked.extend(knocked) - # 5. Work out which users have joined or left rooms we're in. We use this - # to build the device_list part of the sync response in - # `_generate_sync_entry_for_device_list`. - ( - newly_joined_or_invited_or_knocked_users, - newly_left_users, - ) = sync_result_builder.calculate_user_changes() + # Now we want to get any newly joined, invited or knocking users + newly_joined_or_invited_or_knocked_users = set() + newly_left_users = set() + if since_token: + for joined_sync in sync_result_builder.joined: + it = itertools.chain( + joined_sync.timeline.events, joined_sync.state.values() + ) + for event in it: + if event.type == EventTypes.Member: + if ( + event.membership == Membership.JOIN + or event.membership == Membership.INVITE + or event.membership == Membership.KNOCK + ): + newly_joined_or_invited_or_knocked_users.add( + event.state_key + ) + else: + prev_content = event.unsigned.get("prev_content", {}) + prev_membership = prev_content.get("membership", None) + if prev_membership == Membership.JOIN: + newly_left_users.add(event.state_key) + + newly_left_users -= newly_joined_or_invited_or_knocked_users return ( set(newly_joined_rooms), @@ -1605,36 +1585,11 @@ async def handle_room_entries(room_entry: "RoomSyncResultBuilder") -> None: newly_left_users, ) - async def _get_ignored_users(self, user_id: str) -> FrozenSet[str]: - """Retrieve the users ignored by the given user from their global account_data. - - Returns an empty set if - - there is no global account_data entry for ignored_users - - there is such an entry, but it's not a JSON object. - """ - # TODO: Can we `SELECT ignored_user_id FROM ignored_users WHERE ignorer_user_id=?;` instead? - ignored_account_data = ( - await self.store.get_global_account_data_by_type_for_user( - AccountDataTypes.IGNORED_USER_LIST, user_id=user_id - ) - ) - - # If there is ignored users account data and it matches the proper type, - # then use it. 
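For reference, the account data parsed by both variants is the m.ignored_user_list event, whose content maps ignored user IDs to empty objects. A small standalone sketch of the defensive parse:

    from typing import FrozenSet, Optional

    def parse_ignored_users(content: Optional[dict]) -> FrozenSet[str]:
        # Typical content:
        #   {"ignored_users": {"@spam:example.com": {}, "@troll:example.com": {}}}
        # Only the keys matter; the per-user values are reserved for future use.
        if not content:
            return frozenset()
        ignored = content.get("ignored_users", {})
        if not isinstance(ignored, dict):
            # Account data is client-supplied, so tolerate malformed shapes.
            return frozenset()
        return frozenset(ignored.keys())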
- ignored_users: FrozenSet[str] = frozenset() - if ignored_account_data: - ignored_users_data = ignored_account_data.get("ignored_users", {}) - if isinstance(ignored_users_data, dict): - ignored_users = frozenset(ignored_users_data.keys()) - return ignored_users - async def _have_rooms_changed( self, sync_result_builder: "SyncResultBuilder" ) -> bool: """Returns whether there may be any new events that should be sent down the sync. Returns True if there are. - - Does not modify the `sync_result_builder`. """ user_id = sync_result_builder.sync_config.user.to_string() since_token = sync_result_builder.since_token @@ -1642,13 +1597,12 @@ async def _have_rooms_changed( assert since_token - # Get a list of membership change events that have happened to the user - # requesting the sync. - membership_changes = await self.store.get_membership_changes_for_user( + # Get a list of membership change events that have happened. + rooms_changed = await self.store.get_membership_changes_for_user( user_id, since_token.room_key, now_token.room_key ) - if membership_changes: + if rooms_changed: return True stream_id = since_token.room_key.stream @@ -1660,25 +1614,7 @@ async def _have_rooms_changed( async def _get_rooms_changed( self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str] ) -> _RoomChanges: - """Determine the changes in rooms to report to the user. - - Ideally, we want to report all events whose stream ordering `s` lies in the - range `since_token < s <= now_token`, where the two tokens are read from the - sync_result_builder. - - If there are too many events in that range to report, things get complicated. - In this situation we return a truncated list of the most recent events, and - indicate in the response that there is a "gap" of omitted events. Additionally: - - - we include a "state_delta", to describe the changes in state over the gap, - - we include all membership events applying to the user making the request, - even those in the gap. - - See the spec for the rationale: - https://spec.matrix.org/v1.1/client-server-api/#syncing - - The sync_result_builder is not modified by this function. - """ + """Gets the the changes that have happened since the last sync.""" user_id = sync_result_builder.sync_config.user.to_string() since_token = sync_result_builder.since_token now_token = sync_result_builder.now_token @@ -1686,36 +1622,21 @@ async def _get_rooms_changed( assert since_token - # The spec - # https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3sync - # notes that membership events need special consideration: - # - # > When a sync is limited, the server MUST return membership events for events - # > in the gap (between since and the start of the returned timeline), regardless - # > as to whether or not they are redundant. - # - # We fetch such events here, but we only seem to use them for categorising rooms - # as newly joined, newly left, invited or knocked. - # TODO: we've already called this function and ran this query in - # _have_rooms_changed. We could keep the results in memory to avoid a - # second query, at the cost of more complicated source code. - membership_change_events = await self.store.get_membership_changes_for_user( + # Get a list of membership change events that have happened. 
+ rooms_changed = await self.store.get_membership_changes_for_user( user_id, since_token.room_key, now_token.room_key ) mem_change_events_by_room_id: Dict[str, List[EventBase]] = {} - for event in membership_change_events: + for event in rooms_changed: mem_change_events_by_room_id.setdefault(event.room_id, []).append(event) - newly_joined_rooms: List[str] = [] - newly_left_rooms: List[str] = [] - room_entries: List[RoomSyncResultBuilder] = [] - invited: List[InvitedSyncResult] = [] - knocked: List[KnockedSyncResult] = [] + newly_joined_rooms = [] + newly_left_rooms = [] + room_entries = [] + invited = [] + knocked = [] for room_id, events in mem_change_events_by_room_id.items(): - # The body of this loop will add this room to at least one of the five lists - # above. Things get messy if you've e.g. joined, left, joined then left the - # room all in the same sync period. logger.debug( "Membership changes in %s: [%s]", room_id, @@ -1770,7 +1691,6 @@ async def _get_rooms_changed( if not non_joins: continue - last_non_join = non_joins[-1] # Check if we have left the room. This can either be because we were # joined before *or* that we since joined and then left. @@ -1792,18 +1712,18 @@ async def _get_rooms_changed( newly_left_rooms.append(room_id) # Only bother if we're still currently invited - should_invite = last_non_join.membership == Membership.INVITE + should_invite = non_joins[-1].membership == Membership.INVITE if should_invite: - if last_non_join.sender not in ignored_users: - invite_room_sync = InvitedSyncResult(room_id, invite=last_non_join) + if event.sender not in ignored_users: + invite_room_sync = InvitedSyncResult(room_id, invite=non_joins[-1]) if invite_room_sync: invited.append(invite_room_sync) # Only bother if our latest membership in the room is knock (and we haven't # been accepted/rejected in the meantime). - should_knock = last_non_join.membership == Membership.KNOCK + should_knock = non_joins[-1].membership == Membership.KNOCK if should_knock: - knock_room_sync = KnockedSyncResult(room_id, knock=last_non_join) + knock_room_sync = KnockedSyncResult(room_id, knock=non_joins[-1]) if knock_room_sync: knocked.append(knock_room_sync) @@ -1861,9 +1781,7 @@ async def _get_rooms_changed( timeline_limit = sync_config.filter_collection.timeline_limit() - # Get all events since the `from_key` in rooms we're currently joined to. - # If there are too many, we get the most recent events only. This leaves - # a "gap" in the timeline, as described by the spec for /sync. + # Get all events for rooms we're currently joined to. room_to_events = await self.store.get_room_events_stream_for_rooms( room_ids=sync_result_builder.joined_room_ids, from_key=since_token.room_key, @@ -1924,10 +1842,6 @@ async def _get_all_rooms( ) -> _RoomChanges: """Returns entries for all rooms for the user. - Like `_get_rooms_changed`, but assumes the `since_token` is `None`. - - This function does not modify the sync_result_builder. - Args: sync_result_builder ignored_users: Set of users ignored by user. 
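To see what the restored inline loop computes, here is a toy model of the per-room branching (deliberately simplified; the real code inspects timeline and state events rather than a flat membership list):

    from typing import List

    JOIN, INVITE, KNOCK = "join", "invite", "knock"

    def categorise_room(memberships_since_sync: List[str], currently_joined: bool) -> str:
        non_joins = [m for m in memberships_since_sync if m != JOIN]
        if not non_joins:
            return "joined"    # only join events: the room stays under `join`
        last = non_joins[-1]
        if currently_joined:
            return "joined"    # we left and rejoined within the sync window
        if last == INVITE:
            return "invited"   # still invited: reported under `invite`
        if last == KNOCK:
            return "knocked"   # still knocking: reported under `knock`
        return "archived"      # left or banned: reported under `leave`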
@@ -1939,9 +1853,16 @@ async def _get_all_rooms( now_token = sync_result_builder.now_token sync_config = sync_result_builder.sync_config + membership_list = ( + Membership.INVITE, + Membership.KNOCK, + Membership.JOIN, + Membership.LEAVE, + Membership.BAN, + ) + room_list = await self.store.get_rooms_for_local_user_where_membership_is( - user_id=user_id, - membership_list=Membership.LIST, + user_id=user_id, membership_list=membership_list ) room_entries = [] @@ -2291,7 +2212,8 @@ def _calculate_state( # to only include membership events for the senders in the timeline. # In practice, we can do this by removing them from the p_ids list, # which is the list of relevant state we know we have already sent to the client. - # see https://github.com/matrix-org/synapse/pull/2970/files/efcdacad7d1b7f52f879179701c7e0d9b763511f#r204732809 + # see https://github.com/matrix-org/synapse/pull/2970 + # /files/efcdacad7d1b7f52f879179701c7e0d9b763511f#r204732809 if lazy_load_members: p_ids.difference_update( @@ -2340,39 +2262,6 @@ class SyncResultBuilder: groups: Optional[GroupsSyncResult] = None to_device: List[JsonDict] = attr.Factory(list) - def calculate_user_changes(self) -> Tuple[Set[str], Set[str]]: - """Work out which other users have joined or left rooms we are joined to. - - This data only is only useful for an incremental sync. - - The SyncResultBuilder is not modified by this function. - """ - newly_joined_or_invited_or_knocked_users = set() - newly_left_users = set() - if self.since_token: - for joined_sync in self.joined: - it = itertools.chain( - joined_sync.timeline.events, joined_sync.state.values() - ) - for event in it: - if event.type == EventTypes.Member: - if ( - event.membership == Membership.JOIN - or event.membership == Membership.INVITE - or event.membership == Membership.KNOCK - ): - newly_joined_or_invited_or_knocked_users.add( - event.state_key - ) - else: - prev_content = event.unsigned.get("prev_content", {}) - prev_membership = prev_content.get("membership", None) - if prev_membership == Membership.JOIN: - newly_left_users.add(event.state_key) - - newly_left_users -= newly_joined_or_invited_or_knocked_users - return newly_joined_or_invited_or_knocked_users, newly_left_users - @attr.s(slots=True, auto_attribs=True) class RoomSyncResultBuilder: diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 6dd9b9ad0358..91ba93372c2c 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -79,35 +79,6 @@ def parse_integer( return parse_integer_from_args(args, name, default, required) -@overload -def parse_integer_from_args( - args: Mapping[bytes, Sequence[bytes]], - name: str, - default: Optional[int] = None, -) -> Optional[int]: - ... - - -@overload -def parse_integer_from_args( - args: Mapping[bytes, Sequence[bytes]], - name: str, - *, - required: Literal[True], -) -> int: - ... - - -@overload -def parse_integer_from_args( - args: Mapping[bytes, Sequence[bytes]], - name: str, - default: Optional[int] = None, - required: bool = False, -) -> Optional[int]: - ... 
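The overloads deleted just above existed so that a call such as parse_integer_from_args(args, "limit", required=True) type-checks as int rather than Optional[int]; the Literal[True] overload is what narrows the return type. A generic restatement of the trick:

    from typing import Literal, Optional, overload

    SETTINGS = {"limit": 10}

    @overload
    def get_setting(name: str, default: Optional[int] = None) -> Optional[int]: ...
    @overload
    def get_setting(name: str, *, required: Literal[True]) -> int: ...
    def get_setting(name, default=None, required=False):
        # Runtime implementation; the overloads above only guide the checker.
        value = SETTINGS.get(name, default)
        if required and value is None:
            raise KeyError(name)
        return value

    limit: int = get_setting("limit", required=True)  # narrowed to int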
- - def parse_integer_from_args( args: Mapping[bytes, Sequence[bytes]], name: str, diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 662e60bc3394..96d7a8f2a95b 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -24,7 +24,6 @@ List, Optional, Tuple, - TypeVar, Union, ) @@ -82,19 +81,10 @@ ) from synapse.http.servlet import parse_json_object_from_request from synapse.http.site import SynapseRequest -from synapse.logging.context import ( - defer_to_thread, - make_deferred_yieldable, - run_in_background, -) +from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.client.login import LoginResponse from synapse.storage import DataStore -from synapse.storage.background_updates import ( - DEFAULT_BATCH_SIZE_CALLBACK, - MIN_BATCH_SIZE_CALLBACK, - ON_UPDATE_CALLBACK, -) from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.roommember import ProfileInfo from synapse.storage.state import StateFilter @@ -108,16 +98,12 @@ create_requester, ) from synapse.util import Clock -from synapse.util.async_helpers import maybe_awaitable from synapse.util.caches.descriptors import cached if TYPE_CHECKING: from synapse.app.generic_worker import GenericWorkerSlavedStore from synapse.server import HomeServer - -T = TypeVar("T") - """ This package defines the 'stable' API which can be used by extension modules which are loaded into Synapse. @@ -321,25 +307,7 @@ def register_password_auth_provider_callbacks( auth_checkers=auth_checkers, ) - def register_background_update_controller_callbacks( - self, - on_update: ON_UPDATE_CALLBACK, - default_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None, - min_batch_size: Optional[MIN_BATCH_SIZE_CALLBACK] = None, - ) -> None: - """Registers background update controller callbacks. - - Added in Synapse v1.49.0. - """ - - for db in self._hs.get_datastores().databases: - db.updates.register_update_controller_callbacks( - on_update=on_update, - default_batch_size=default_batch_size, - min_batch_size=min_batch_size, - ) - - def register_web_resource(self, path: str, resource: Resource) -> None: + def register_web_resource(self, path: str, resource: Resource): """Registers a web resource to be served at the given path. This function should be called during initialisation of the module. @@ -464,7 +432,7 @@ def get_qualified_user_id(self, username: str) -> str: username: provided user id Returns: - qualified @user:id + str: qualified @user:id """ if username.startswith("@"): return username @@ -500,7 +468,7 @@ async def get_threepids_for_user(self, user_id: str) -> List[Dict[str, str]]: """ return await self._store.user_get_threepids(user_id) - def check_user_exists(self, user_id: str) -> "defer.Deferred[Optional[str]]": + def check_user_exists(self, user_id: str): """Check if user exists. Added in Synapse v0.25.0. @@ -509,18 +477,13 @@ def check_user_exists(self, user_id: str) -> "defer.Deferred[Optional[str]]": user_id: Complete @user:id Returns: - Canonical (case-corrected) user_id, or None + Deferred[str|None]: Canonical (case-corrected) user_id, or None if the user is not registered. 
""" return defer.ensureDeferred(self._auth_handler.check_user_exists(user_id)) @defer.inlineCallbacks - def register( - self, - localpart: str, - displayname: Optional[str] = None, - emails: Optional[List[str]] = None, - ) -> Generator["defer.Deferred[Any]", Any, Tuple[str, str]]: + def register(self, localpart, displayname=None, emails: Optional[List[str]] = None): """Registers a new user with given localpart and optional displayname, emails. Also returns an access token for the new user. @@ -532,12 +495,12 @@ def register( Added in Synapse v0.25.0. Args: - localpart: The localpart of the new user. - displayname: The displayname of the new user. - emails: Emails to bind to the new user. + localpart (str): The localpart of the new user. + displayname (str|None): The displayname of the new user. + emails (List[str]): Emails to bind to the new user. Returns: - a 2-tuple of (user_id, access_token) + Deferred[tuple[str, str]]: a 2-tuple of (user_id, access_token) """ logger.warning( "Using deprecated ModuleApi.register which creates a dummy user device." @@ -547,26 +510,23 @@ def register( return user_id, access_token def register_user( - self, - localpart: str, - displayname: Optional[str] = None, - emails: Optional[List[str]] = None, - ) -> "defer.Deferred[str]": + self, localpart, displayname=None, emails: Optional[List[str]] = None + ): """Registers a new user with given localpart and optional displayname, emails. Added in Synapse v1.2.0. Args: - localpart: The localpart of the new user. - displayname: The displayname of the new user. - emails: Emails to bind to the new user. + localpart (str): The localpart of the new user. + displayname (str|None): The displayname of the new user. + emails (List[str]): Emails to bind to the new user. Raises: SynapseError if there is an error performing the registration. Check the 'errcode' property for more information on the reason for failure Returns: - user_id + defer.Deferred[str]: user_id """ return defer.ensureDeferred( self._hs.get_registration_handler().register_user( @@ -576,25 +536,20 @@ def register_user( ) ) - def register_device( - self, - user_id: str, - device_id: Optional[str] = None, - initial_display_name: Optional[str] = None, - ) -> "defer.Deferred[Tuple[str, str, Optional[int], Optional[str]]]": + def register_device(self, user_id, device_id=None, initial_display_name=None): """Register a device for a user and generate an access token. Added in Synapse v1.2.0. Args: - user_id: full canonical @user:id - device_id: The device ID to check, or None to generate + user_id (str): full canonical @user:id + device_id (str|None): The device ID to check, or None to generate a new one. - initial_display_name: An optional display name for the + initial_display_name (str|None): An optional display name for the device. 
Returns: - Tuple of device ID, access token, access token expiration time and refresh token + defer.Deferred[tuple[str, str]]: Tuple of device ID and access token """ return defer.ensureDeferred( self._hs.get_registration_handler().register_device( @@ -627,7 +582,6 @@ def generate_short_term_login_token( user_id: str, duration_in_ms: int = (2 * 60 * 1000), auth_provider_id: str = "", - auth_provider_session_id: Optional[str] = None, ) -> str: """Generate a login token suitable for m.login.token authentication @@ -645,14 +599,11 @@ def generate_short_term_login_token( return self._hs.get_macaroon_generator().generate_short_term_login_token( user_id, auth_provider_id, - auth_provider_session_id, duration_in_ms, ) @defer.inlineCallbacks - def invalidate_access_token( - self, access_token: str - ) -> Generator["defer.Deferred[Any]", Any, None]: + def invalidate_access_token(self, access_token): """Invalidate an access token for a user Added in Synapse v0.25.0. @@ -684,20 +635,14 @@ def invalidate_access_token( self._auth_handler.delete_access_token(access_token) ) - def run_db_interaction( - self, - desc: str, - func: Callable[..., T], - *args: Any, - **kwargs: Any, - ) -> "defer.Deferred[T]": + def run_db_interaction(self, desc, func, *args, **kwargs): """Run a function with a database connection Added in Synapse v0.25.0. Args: - desc: description for the transaction, for metrics etc - func: function to be run. Passed a database cursor object + desc (str): description for the transaction, for metrics etc + func (func): function to be run. Passed a database cursor object as well as *args and **kwargs *args: positional args to be passed to func **kwargs: named args to be passed to func @@ -711,7 +656,7 @@ def run_db_interaction( def complete_sso_login( self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str - ) -> None: + ): """Complete a SSO login by redirecting the user to a page to confirm whether they want their access token sent to `client_redirect_url`, or redirect them to that URL with a token directly if the URL matches with one of the whitelisted clients. @@ -741,7 +686,7 @@ async def complete_sso_login_async( client_redirect_url: str, new_user: bool = False, auth_provider_id: str = "", - ) -> None: + ): """Complete a SSO login by redirecting the user to a page to confirm whether they want their access token sent to `client_redirect_url`, or redirect them to that URL with a token directly if the URL matches with one of the whitelisted clients. @@ -980,11 +925,11 @@ def looping_background_call( self, f: Callable, msec: float, - *args: object, + *args, desc: Optional[str] = None, run_on_all_instances: bool = False, - **kwargs: object, - ) -> None: + **kwargs, + ): """Wraps a function as a background process and calls it repeatedly. NOTE: Will only run on the instance that is configured to run @@ -1015,7 +960,9 @@ def looping_background_call( run_as_background_process, msec, desc, - lambda: maybe_awaitable(f(*args, **kwargs)), + f, + *args, + **kwargs, ) else: logger.warning( @@ -1023,18 +970,13 @@ def looping_background_call( f, ) - async def sleep(self, seconds: float) -> None: - """Sleeps for the given number of seconds.""" - - await self._clock.sleep(seconds) - async def send_mail( self, recipient: str, subject: str, html: str, text: str, - ) -> None: + ): """Send an email on behalf of the homeserver. Added in Synapse v1.39.0. 
@@ -1182,26 +1124,6 @@ async def get_room_state( return {key: state_events[event_id] for key, event_id in state_ids.items()} - async def defer_to_thread( - self, - f: Callable[..., T], - *args: Any, - **kwargs: Any, - ) -> T: - """Runs the given function in a separate thread from Synapse's thread pool. - - Added in Synapse v1.49.0. - - Args: - f: The function to run. - args: The function's arguments. - kwargs: The function's keyword arguments. - - Returns: - The return value of the function once ran in a thread. - """ - return await defer_to_thread(self._hs.get_reactor(), f, *args, **kwargs) - class PublicRoomListManager: """Contains methods for adding to, removing from and querying whether a room diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index 4f13c0418ab9..cf5abdfbda49 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -21,8 +21,6 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push import Pusher, PusherConfig, PusherConfigException, ThrottleParams from synapse.push.mailer import Mailer -from synapse.push.push_types import EmailReason -from synapse.storage.databases.main.event_push_actions import EmailPushAction from synapse.util.threepids import validate_email if TYPE_CHECKING: @@ -192,7 +190,7 @@ async def _unsafe_process(self) -> None: # we then consider all previously outstanding notifications # to be delivered. - reason: EmailReason = { + reason = { "room_id": push_action["room_id"], "now": self.clock.time_msec(), "received_at": received_at, @@ -277,7 +275,7 @@ def room_ready_to_notify_at(self, room_id: str) -> int: return may_send_at async def sent_notif_update_throttle( - self, room_id: str, notified_push_action: EmailPushAction + self, room_id: str, notified_push_action: dict ) -> None: # We have sent a notification, so update the throttle accordingly. # If the event that triggered the notif happened more than @@ -317,9 +315,7 @@ async def sent_notif_update_throttle( self.pusher_id, room_id, self.throttle_params[room_id] ) - async def send_notification( - self, push_actions: List[EmailPushAction], reason: EmailReason - ) -> None: + async def send_notification(self, push_actions: List[dict], reason: dict) -> None: logger.info("Sending notif email for user %r", self.user_id) await self.mailer.send_notification_mail( diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 3fa603ccb7f7..dbf4ad7f97ee 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -26,7 +26,6 @@ from synapse.logging import opentracing from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push import Pusher, PusherConfig, PusherConfigException -from synapse.storage.databases.main.event_push_actions import HttpPushAction from . 
import push_rule_evaluator, push_tools @@ -274,7 +273,7 @@ async def _unsafe_process(self) -> None: ) break - async def _process_one(self, push_action: HttpPushAction) -> bool: + async def _process_one(self, push_action: dict) -> bool: if "notify" not in push_action["actions"]: return True diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index ba4f866487ec..ce299ba3da16 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -14,7 +14,7 @@ import logging import urllib.parse -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, TypeVar import bleach import jinja2 @@ -28,14 +28,6 @@ descriptor_from_member_events, name_from_member_event, ) -from synapse.push.push_types import ( - EmailReason, - MessageVars, - NotifVars, - RoomVars, - TemplateVars, -) -from synapse.storage.databases.main.event_push_actions import EmailPushAction from synapse.storage.state import StateFilter from synapse.types import StateMap, UserID from synapse.util.async_helpers import concurrently_execute @@ -143,7 +135,7 @@ async def send_password_reset_mail( % urllib.parse.urlencode(params) ) - template_vars: TemplateVars = {"link": link} + template_vars = {"link": link} await self.send_email( email_address, @@ -173,7 +165,7 @@ async def send_registration_mail( % urllib.parse.urlencode(params) ) - template_vars: TemplateVars = {"link": link} + template_vars = {"link": link} await self.send_email( email_address, @@ -204,7 +196,7 @@ async def send_add_threepid_mail( % urllib.parse.urlencode(params) ) - template_vars: TemplateVars = {"link": link} + template_vars = {"link": link} await self.send_email( email_address, @@ -218,8 +210,8 @@ async def send_notification_mail( app_id: str, user_id: str, email_address: str, - push_actions: Iterable[EmailPushAction], - reason: EmailReason, + push_actions: Iterable[Dict[str, Any]], + reason: Dict[str, Any], ) -> None: """ Send email regarding a user's room notifications @@ -238,7 +230,7 @@ async def send_notification_mail( [pa["event_id"] for pa in push_actions] ) - notifs_by_room: Dict[str, List[EmailPushAction]] = {} + notifs_by_room: Dict[str, List[Dict[str, Any]]] = {} for pa in push_actions: notifs_by_room.setdefault(pa["room_id"], []).append(pa) @@ -266,7 +258,7 @@ async def _fetch_room_state(room_id: str) -> None: # actually sort our so-called rooms_in_order list, most recent room first rooms_in_order.sort(key=lambda r: -(notifs_by_room[r][-1]["received_ts"] or 0)) - rooms: List[RoomVars] = [] + rooms: List[Dict[str, Any]] = [] for r in rooms_in_order: roomvars = await self._get_room_vars( @@ -297,7 +289,7 @@ async def _fetch_room_state(room_id: str) -> None: notifs_by_room, state_by_room, notif_events, reason ) - template_vars: TemplateVars = { + template_vars = { "user_display_name": user_display_name, "unsubscribe_link": self._make_unsubscribe_link( user_id, app_id, email_address @@ -310,10 +302,10 @@ async def _fetch_room_state(room_id: str) -> None: await self.send_email(email_address, summary_text, template_vars) async def send_email( - self, email_address: str, subject: str, extra_template_vars: TemplateVars + self, email_address: str, subject: str, extra_template_vars: Dict[str, Any] ) -> None: """Send an email with the given information and template text""" - template_vars: TemplateVars = { + template_vars = { "app_name": self.app_name, "server_name": self.hs.config.server.server_name, } @@ -335,10 +327,10 @@ async def _get_room_vars( self, room_id: str, 
user_id: str, - notifs: Iterable[EmailPushAction], + notifs: Iterable[Dict[str, Any]], notif_events: Dict[str, EventBase], room_state_ids: StateMap[str], - ) -> RoomVars: + ) -> Dict[str, Any]: """ Generate the variables for notifications on a per-room basis. @@ -364,7 +356,7 @@ async def _get_room_vars( room_name = await calculate_room_name(self.store, room_state_ids, user_id) - room_vars: RoomVars = { + room_vars: Dict[str, Any] = { "title": room_name, "hash": string_ordinal_total(room_id), # See sender avatar hash "notifs": [], @@ -425,11 +417,11 @@ async def _get_room_avatar( async def _get_notif_vars( self, - notif: EmailPushAction, + notif: Dict[str, Any], user_id: str, notif_event: EventBase, room_state_ids: StateMap[str], - ) -> NotifVars: + ) -> Dict[str, Any]: """ Generate the variables for a single notification. @@ -450,7 +442,7 @@ async def _get_notif_vars( after_limit=CONTEXT_AFTER, ) - ret: NotifVars = { + ret = { "link": self._make_notif_link(notif), "ts": notif["received_ts"], "messages": [], @@ -469,8 +461,8 @@ async def _get_notif_vars( return ret async def _get_message_vars( - self, notif: EmailPushAction, event: EventBase, room_state_ids: StateMap[str] - ) -> Optional[MessageVars]: + self, notif: Dict[str, Any], event: EventBase, room_state_ids: StateMap[str] + ) -> Optional[Dict[str, Any]]: """ Generate the variables for a single event, if possible. @@ -502,9 +494,7 @@ async def _get_message_vars( if sender_state_event: sender_name = name_from_member_event(sender_state_event) - sender_avatar_url: Optional[str] = sender_state_event.content.get( - "avatar_url" - ) + sender_avatar_url = sender_state_event.content.get("avatar_url") else: # No state could be found, fallback to the MXID. sender_name = event.sender @@ -514,7 +504,7 @@ async def _get_message_vars( # sender_hash % the number of default images to choose from sender_hash = string_ordinal_total(event.sender) - ret: MessageVars = { + ret = { "event_type": event.type, "is_historical": event.event_id != notif["event_id"], "id": event.event_id, @@ -529,8 +519,6 @@ async def _get_message_vars( return ret msgtype = event.content.get("msgtype") - if not isinstance(msgtype, str): - msgtype = None ret["msgtype"] = msgtype @@ -545,7 +533,7 @@ async def _get_message_vars( return ret def _add_text_message_vars( - self, messagevars: MessageVars, event: EventBase + self, messagevars: Dict[str, Any], event: EventBase ) -> None: """ Potentially add a sanitised message body to the message variables. @@ -555,8 +543,8 @@ def _add_text_message_vars( event: The event under consideration. """ msgformat = event.content.get("format") - if not isinstance(msgformat, str): - msgformat = None + + messagevars["format"] = msgformat formatted_body = event.content.get("formatted_body") body = event.content.get("body") @@ -567,7 +555,7 @@ def _add_text_message_vars( messagevars["body_text_html"] = safe_text(body) def _add_image_message_vars( - self, messagevars: MessageVars, event: EventBase + self, messagevars: Dict[str, Any], event: EventBase ) -> None: """ Potentially add an image URL to the message variables. 
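These hunks trade the dedicated push TypedDicts (`MessageVars`, `NotifVars`, `RoomVars`, ...) for `Dict[str, Any]`, which still type-checks but no longer catches misspelled or mistyped keys. An illustrative comparison with a simplified stand-in (not Synapse's full definition); incidentally, the `RoomVars` docstring deleted further down itself misspells the field as `avator_url`, exactly the class of slip a TypedDict flags:

    from typing import Any, Dict
    from typing_extensions import TypedDict

    class RoomVars(TypedDict, total=False):  # simplified stand-in
        title: str
        avatar_url: str

    typed: RoomVars = {"title": "Lobby", "avator_url": "mxc://x/y"}  # mypy: unexpected key
    loose: Dict[str, Any] = {"title": "Lobby", "avator_url": "mxc://x/y"}  # accepted silently
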
@@ -582,7 +570,7 @@ def _add_image_message_vars( async def _make_summary_text_single_room( self, room_id: str, - notifs: List[EmailPushAction], + notifs: List[Dict[str, Any]], room_state_ids: StateMap[str], notif_events: Dict[str, EventBase], user_id: str, @@ -697,10 +685,10 @@ async def _make_summary_text_single_room( async def _make_summary_text( self, - notifs_by_room: Dict[str, List[EmailPushAction]], + notifs_by_room: Dict[str, List[Dict[str, Any]]], room_state_ids: Dict[str, StateMap[str]], notif_events: Dict[str, EventBase], - reason: EmailReason, + reason: Dict[str, Any], ) -> str: """ Make a summary text for the email when multiple rooms have notifications. @@ -730,7 +718,7 @@ async def _make_summary_text( async def _make_summary_text_from_member_events( self, room_id: str, - notifs: List[EmailPushAction], + notifs: List[Dict[str, Any]], room_state_ids: StateMap[str], notif_events: Dict[str, EventBase], ) -> str: @@ -817,7 +805,7 @@ def _make_room_link(self, room_id: str) -> str: base_url = "https://matrix.to/#" return "%s/%s" % (base_url, room_id) - def _make_notif_link(self, notif: EmailPushAction) -> str: + def _make_notif_link(self, notif: Dict[str, str]) -> str: """ Generate a link to open an event in the web client. diff --git a/synapse/push/push_types.py b/synapse/push/push_types.py deleted file mode 100644 index 8d16ab62cef6..000000000000 --- a/synapse/push/push_types.py +++ /dev/null @@ -1,136 +0,0 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from typing import List, Optional - -from typing_extensions import TypedDict - - -class EmailReason(TypedDict, total=False): - """ - Information on the event that triggered the email to be sent - - room_id: the ID of the room the event was sent in - now: timestamp in ms when the email is being sent out - room_name: a human-readable name for the room the event was sent in - received_at: the time in milliseconds at which the event was received - delay_before_mail_ms: the amount of time in milliseconds Synapse always waits - before ever emailing about a notification (to give the user a chance to respond - to other push or notice the window) - last_sent_ts: the time in milliseconds at which a notification was last sent - for an event in this room - throttle_ms: the minimum amount of time in milliseconds between two - notifications can be sent for this room - """ - - room_id: str - now: int - room_name: Optional[str] - received_at: int - delay_before_mail_ms: int - last_sent_ts: int - throttle_ms: int - - -class MessageVars(TypedDict, total=False): - """ - Details about a specific message to include in a notification - - event_type: the type of the event - is_historical: a boolean, which is `False` if the message is the one - that triggered the notification, `True` otherwise - id: the ID of the event - ts: the time in milliseconds at which the event was sent - sender_name: the display name for the event's sender - sender_avatar_url: the avatar URL (as a `mxc://` URL) for the event's - sender - sender_hash: a hash of the user ID of the sender - msgtype: the type of the message - body_text_html: html representation of the message - body_text_plain: plaintext representation of the message - image_url: mxc url of an image, when "msgtype" is "m.image" - """ - - event_type: str - is_historical: bool - id: str - ts: int - sender_name: str - sender_avatar_url: Optional[str] - sender_hash: int - msgtype: Optional[str] - body_text_html: str - body_text_plain: str - image_url: str - - -class NotifVars(TypedDict): - """ - Details about an event we are about to include in a notification - - link: a `matrix.to` link to the event - ts: the time in milliseconds at which the event was received - messages: a list of messages containing one message before the event, the - message in the event, and one message after the event. - """ - - link: str - ts: Optional[int] - messages: List[MessageVars] - - -class RoomVars(TypedDict): - """ - Represents a room containing events to include in the email. - - title: a human-readable name for the room - hash: a hash of the ID of the room - invite: a boolean, which is `True` if the room is an invite the user hasn't - accepted yet, `False` otherwise - notifs: a list of events, or an empty list if `invite` is `True`. - link: a `matrix.to` link to the room - avator_url: url to the room's avator - """ - - title: Optional[str] - hash: int - invite: bool - notifs: List[NotifVars] - link: str - avatar_url: Optional[str] - - -class TemplateVars(TypedDict, total=False): - """ - Generic structure for passing to the email sender, can hold all the fields used in email templates. - - app_name: name of the app/service this homeserver is associated with - server_name: name of our own homeserver - link: a link to include into the email to be sent - user_display_name: the display name for the user receiving the notification - unsubscribe_link: the link users can click to unsubscribe from email notifications - summary_text: a summary of the notification(s). 
The text used can be customised - by configuring the various settings in the `email.subjects` section of the - configuration file. - rooms: a list of rooms containing events to include in the email - reason: information on the event that triggered the email to be sent - """ - - app_name: str - server_name: str - link: str - user_display_name: str - unsubscribe_link: str - summary_text: str - rooms: List[RoomVars] - reason: EmailReason diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 7d26954244ea..154e5b7028e9 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -86,7 +86,7 @@ # We enforce that we have a `cryptography` version that bundles an `openssl` # with the latest security patches. "cryptography>=3.4.7", - "ijson>=3.1", + "ijson>=3.0", ] CONDITIONAL_REQUIREMENTS = { diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py index daacc34ceac4..0db419ea57fb 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py @@ -46,8 +46,6 @@ async def _serialize_payload( is_guest, is_appservice_ghost, should_issue_refresh_token, - auth_provider_id, - auth_provider_session_id, ): """ Args: @@ -65,8 +63,6 @@ async def _serialize_payload( "is_guest": is_guest, "is_appservice_ghost": is_appservice_ghost, "should_issue_refresh_token": should_issue_refresh_token, - "auth_provider_id": auth_provider_id, - "auth_provider_session_id": auth_provider_session_id, } async def _handle_request(self, request, user_id): @@ -77,8 +73,6 @@ async def _handle_request(self, request, user_id): is_guest = content["is_guest"] is_appservice_ghost = content["is_appservice_ghost"] should_issue_refresh_token = content["should_issue_refresh_token"] - auth_provider_id = content["auth_provider_id"] - auth_provider_session_id = content["auth_provider_session_id"] res = await self.registration_handler.register_device_inner( user_id, @@ -87,8 +81,6 @@ async def _handle_request(self, request, user_id): is_guest, is_appservice_ghost=is_appservice_ghost, should_issue_refresh_token=should_issue_refresh_token, - auth_provider_id=auth_provider_id, - auth_provider_session_id=auth_provider_session_id, ) return 200, res diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py index fa132d10b414..8c1bf9227ac6 100644 --- a/synapse/replication/slave/storage/_slaved_id_tracker.py +++ b/synapse/replication/slave/storage/_slaved_id_tracker.py @@ -14,18 +14,10 @@ from typing import List, Optional, Tuple from synapse.storage.database import LoggingDatabaseConnection -from synapse.storage.util.id_generators import AbstractStreamIdTracker, _load_current_id +from synapse.storage.util.id_generators import _load_current_id -class SlavedIdTracker(AbstractStreamIdTracker): - """Tracks the "current" stream ID of a stream with a single writer. - - See `AbstractStreamIdTracker` for more details. - - Note that this class does not work correctly when there are multiple - writers. - """ - +class SlavedIdTracker: def __init__( self, db_conn: LoggingDatabaseConnection, @@ -44,7 +36,17 @@ def advance(self, instance_name: Optional[str], new_id: int): self._current = (max if self.step > 0 else min)(self._current, new_id) def get_current_token(self) -> int: + """ + + Returns: + int + """ return self._current def get_current_token_for_writer(self, instance_name: str) -> int: + """Returns the position of the given writer. 
+ + For streams with single writers this is equivalent to + `get_current_token`. + """ return self.get_current_token() diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 7541e21de9dd..4d5f86286213 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import PushRulesStream from synapse.storage.databases.main.push_rule import PushRulesWorkerStore @@ -24,6 +25,9 @@ def get_max_push_rules_stream_id(self): return self._push_rules_stream_id_gen.get_current_token() def process_replication_rows(self, stream_name, instance_name, token, rows): + # We assert this for the benefit of mypy + assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker) + if stream_name == PushRulesStream.NAME: self._push_rules_stream_id_gen.advance(instance_name, token) for row in rows: diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index a390cfcb74d5..a030e9299edc 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -14,7 +14,7 @@ # limitations under the License. import heapq from collections.abc import Iterable -from typing import TYPE_CHECKING, Optional, Tuple, Type +from typing import TYPE_CHECKING, List, Optional, Tuple, Type import attr @@ -157,7 +157,7 @@ async def _update_function( # now we fetch up to that many rows from the events table - event_rows = await self._store.get_all_new_forward_event_rows( + event_rows: List[Tuple] = await self._store.get_all_new_forward_event_rows( instance_name, from_token, current_token, target_row_count ) @@ -191,7 +191,7 @@ async def _update_function( # finally, fetch the ex-outliers rows. We assume there are few enough of these # not to bother with the limit. 
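The `SlavedIdTracker` restored above is a minimal single-writer position tracker: `advance` clamps with `max` (or `min` when `step` is negative, for backwards-filling streams), so a stale replication row can never move the token backwards. A sketch of that behaviour, assuming `db_conn` is an open `LoggingDatabaseConnection` and that the stored position starts at or below 42 (table and column names are illustrative):

    from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker

    tracker = SlavedIdTracker(db_conn, "push_rules_stream", "stream_id")
    tracker.advance(None, 42)  # replication reports position 42
    tracker.advance(None, 37)  # out-of-order row: max() keeps the token at 42
    assert tracker.get_current_token() == 42
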
- ex_outliers_rows = await self._store.get_ex_outlier_stream_rows( + ex_outliers_rows: List[Tuple] = await self._store.get_ex_outlier_stream_rows( instance_name, from_token, upper_limit ) diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index c499afd4be57..ee4a5e481bee 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -17,7 +17,6 @@ import logging import platform -from http import HTTPStatus from typing import TYPE_CHECKING, Optional, Tuple import synapse @@ -40,10 +39,6 @@ EventReportDetailRestServlet, EventReportsRestServlet, ) -from synapse.rest.admin.federation import ( - DestinationsRestServlet, - ListDestinationsRestServlet, -) from synapse.rest.admin.groups import DeleteGroupAdminRestServlet from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo from synapse.rest.admin.registration_tokens import ( @@ -103,7 +98,7 @@ def __init__(self, hs: "HomeServer"): } def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - return HTTPStatus.OK, self.res + return 200, self.res class PurgeHistoryRestServlet(RestServlet): @@ -135,7 +130,7 @@ async def on_POST( event = await self.store.get_event(event_id) if event.room_id != room_id: - raise SynapseError(HTTPStatus.BAD_REQUEST, "Event is for wrong room.") + raise SynapseError(400, "Event is for wrong room.") # RoomStreamToken expects [int] not Optional[int] assert event.internal_metadata.stream_ordering is not None @@ -149,9 +144,7 @@ async def on_POST( ts = body["purge_up_to_ts"] if not isinstance(ts, int): raise SynapseError( - HTTPStatus.BAD_REQUEST, - "purge_up_to_ts must be an int", - errcode=Codes.BAD_JSON, + 400, "purge_up_to_ts must be an int", errcode=Codes.BAD_JSON ) stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts) @@ -167,9 +160,7 @@ async def on_POST( stream_ordering, ) raise SynapseError( - HTTPStatus.NOT_FOUND, - "there is no event to be purged", - errcode=Codes.NOT_FOUND, + 404, "there is no event to be purged", errcode=Codes.NOT_FOUND ) (stream, topo, _event_id) = r token = "t%d-%d" % (topo, stream) @@ -182,7 +173,7 @@ async def on_POST( ) else: raise SynapseError( - HTTPStatus.BAD_REQUEST, + 400, "must specify purge_up_to_event_id or purge_up_to_ts", errcode=Codes.BAD_JSON, ) @@ -191,7 +182,7 @@ async def on_POST( room_id, token, delete_local_events=delete_local_events ) - return HTTPStatus.OK, {"purge_id": purge_id} + return 200, {"purge_id": purge_id} class PurgeHistoryStatusRestServlet(RestServlet): @@ -210,7 +201,7 @@ async def on_GET( if purge_status is None: raise NotFoundError("purge id '%s' not found" % purge_id) - return HTTPStatus.OK, purge_status.asdict() + return 200, purge_status.asdict() ######################################################################################## @@ -265,8 +256,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ListRegistrationTokensRestServlet(hs).register(http_server) NewRegistrationTokenRestServlet(hs).register(http_server) RegistrationTokenRestServlet(hs).register(http_server) - DestinationsRestServlet(hs).register(http_server) - ListDestinationsRestServlet(hs).register(http_server) # Some servlets only get registered for the main process. if hs.config.worker.worker_app is None: diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py index 399b205aaf81..d9a2f6ca157f 100644 --- a/synapse/rest/admin/_base.py +++ b/synapse/rest/admin/_base.py @@ -13,7 +13,6 @@ # limitations under the License. 
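A note on the blanket `HTTPStatus.*` to bare-integer churn in the admin servlets above and below: it is behaviour-neutral, since `http.HTTPStatus` members are `int` subclasses and serialise identically on the wire. For the record:

    from http import HTTPStatus

    assert HTTPStatus.OK == 200
    assert HTTPStatus.BAD_REQUEST == 400
    assert HTTPStatus.FORBIDDEN == 403
    assert HTTPStatus.NOT_FOUND == 404
    assert HTTPStatus.CONFLICT == 409
    assert isinstance(HTTPStatus.OK, int)  # IntEnum: equal to, and usable as, 200
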
import re -from http import HTTPStatus from typing import Iterable, Pattern from synapse.api.auth import Auth @@ -63,4 +62,4 @@ async def assert_user_is_admin(auth: Auth, user_id: UserID) -> None: """ is_admin = await auth.is_server_admin(user_id) if not is_admin: - raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin") + raise AuthError(403, "You are not a server admin") diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py index 2e5a6600d337..80fbf32f17df 100644 --- a/synapse/rest/admin/devices.py +++ b/synapse/rest/admin/devices.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import NotFoundError, SynapseError @@ -54,7 +53,7 @@ async def on_GET( target_user = UserID.from_string(user_id) if not self.hs.is_mine(target_user): - raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") + raise SynapseError(400, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) if u is None: @@ -63,7 +62,7 @@ async def on_GET( device = await self.device_handler.get_device( target_user.to_string(), device_id ) - return HTTPStatus.OK, device + return 200, device async def on_DELETE( self, request: SynapseRequest, user_id: str, device_id: str @@ -72,14 +71,14 @@ async def on_DELETE( target_user = UserID.from_string(user_id) if not self.hs.is_mine(target_user): - raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") + raise SynapseError(400, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) if u is None: raise NotFoundError("Unknown user") await self.device_handler.delete_device(target_user.to_string(), device_id) - return HTTPStatus.OK, {} + return 200, {} async def on_PUT( self, request: SynapseRequest, user_id: str, device_id: str @@ -88,7 +87,7 @@ async def on_PUT( target_user = UserID.from_string(user_id) if not self.hs.is_mine(target_user): - raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") + raise SynapseError(400, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) if u is None: @@ -98,7 +97,7 @@ async def on_PUT( await self.device_handler.update_device( target_user.to_string(), device_id, body ) - return HTTPStatus.OK, {} + return 200, {} class DevicesRestServlet(RestServlet): @@ -125,14 +124,14 @@ async def on_GET( target_user = UserID.from_string(user_id) if not self.hs.is_mine(target_user): - raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") + raise SynapseError(400, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) if u is None: raise NotFoundError("Unknown user") devices = await self.device_handler.get_devices_by_user(target_user.to_string()) - return HTTPStatus.OK, {"devices": devices, "total": len(devices)} + return 200, {"devices": devices, "total": len(devices)} class DeleteDevicesRestServlet(RestServlet): @@ -156,7 +155,7 @@ async def on_POST( target_user = UserID.from_string(user_id) if not self.hs.is_mine(target_user): - raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") + raise SynapseError(400, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) if u is None: @@ -168,4 +167,4 @@ async def on_POST( await self.device_handler.delete_devices( target_user.to_string(), 
body["devices"] ) - return HTTPStatus.OK, {} + return 200, {} diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py index 5ee8b11110e0..bbfcaf723b7b 100644 --- a/synapse/rest/admin/event_reports.py +++ b/synapse/rest/admin/event_reports.py @@ -13,7 +13,6 @@ # limitations under the License. import logging -from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import Codes, NotFoundError, SynapseError @@ -67,23 +66,21 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if start < 0: raise SynapseError( - HTTPStatus.BAD_REQUEST, + 400, "The start parameter must be a positive integer.", errcode=Codes.INVALID_PARAM, ) if limit < 0: raise SynapseError( - HTTPStatus.BAD_REQUEST, + 400, "The limit parameter must be a positive integer.", errcode=Codes.INVALID_PARAM, ) if direction not in ("f", "b"): raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Unknown direction: %s" % (direction,), - errcode=Codes.INVALID_PARAM, + 400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM ) event_reports, total = await self.store.get_event_reports_paginate( @@ -93,7 +90,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if (start + limit) < total: ret["next_token"] = start + len(event_reports) - return HTTPStatus.OK, ret + return 200, ret class EventReportDetailRestServlet(RestServlet): @@ -130,17 +127,13 @@ async def on_GET( try: resolved_report_id = int(report_id) except ValueError: - raise SynapseError( - HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM - ) + raise SynapseError(400, message, errcode=Codes.INVALID_PARAM) if resolved_report_id < 0: - raise SynapseError( - HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM - ) + raise SynapseError(400, message, errcode=Codes.INVALID_PARAM) ret = await self.store.get_event_report(resolved_report_id) if not ret: raise NotFoundError("Event report not found") - return HTTPStatus.OK, ret + return 200, ret diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py deleted file mode 100644 index 744687be35fc..000000000000 --- a/synapse/rest/admin/federation.py +++ /dev/null @@ -1,135 +0,0 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -from http import HTTPStatus -from typing import TYPE_CHECKING, Tuple - -from synapse.api.errors import Codes, NotFoundError, SynapseError -from synapse.http.servlet import RestServlet, parse_integer, parse_string -from synapse.http.site import SynapseRequest -from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin -from synapse.storage.databases.main.transactions import DestinationSortOrder -from synapse.types import JsonDict - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -class ListDestinationsRestServlet(RestServlet): - """Get request to list all destinations. 
- This needs user to have administrator access in Synapse. - - GET /_synapse/admin/v1/federation/destinations?from=0&limit=10 - - returns: - 200 OK with list of destinations if success otherwise an error. - - The parameters `from` and `limit` are required only for pagination. - By default, a `limit` of 100 is used. - The parameter `destination` can be used to filter by destination. - The parameter `order_by` can be used to order the result. - """ - - PATTERNS = admin_patterns("/federation/destinations$") - - def __init__(self, hs: "HomeServer"): - self._auth = hs.get_auth() - self._store = hs.get_datastore() - - async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - await assert_requester_is_admin(self._auth, request) - - start = parse_integer(request, "from", default=0) - limit = parse_integer(request, "limit", default=100) - - if start < 0: - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Query parameter from must be a string representing a positive integer.", - errcode=Codes.INVALID_PARAM, - ) - - if limit < 0: - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Query parameter limit must be a string representing a positive integer.", - errcode=Codes.INVALID_PARAM, - ) - - destination = parse_string(request, "destination") - - order_by = parse_string( - request, - "order_by", - default=DestinationSortOrder.DESTINATION.value, - allowed_values=[dest.value for dest in DestinationSortOrder], - ) - - direction = parse_string(request, "dir", default="f", allowed_values=("f", "b")) - - destinations, total = await self._store.get_destinations_paginate( - start, limit, destination, order_by, direction - ) - response = {"destinations": destinations, "total": total} - if (start + limit) < total: - response["next_token"] = str(start + len(destinations)) - - return HTTPStatus.OK, response - - -class DestinationsRestServlet(RestServlet): - """Get details of a destination. - This needs user to have administrator access in Synapse. - - GET /_synapse/admin/v1/federation/destinations/ - - returns: - 200 OK with details of a destination if success otherwise an error. - """ - - PATTERNS = admin_patterns("/federation/destinations/(?P[^/]+)$") - - def __init__(self, hs: "HomeServer"): - self._auth = hs.get_auth() - self._store = hs.get_datastore() - - async def on_GET( - self, request: SynapseRequest, destination: str - ) -> Tuple[int, JsonDict]: - await assert_requester_is_admin(self._auth, request) - - destination_retry_timings = await self._store.get_destination_retry_timings( - destination - ) - - if not destination_retry_timings: - raise NotFoundError("Unknown destination") - - last_successful_stream_ordering = ( - await self._store.get_destination_last_successful_stream_ordering( - destination - ) - ) - - response = { - "destination": destination, - "failure_ts": destination_retry_timings.failure_ts, - "retry_last_ts": destination_retry_timings.retry_last_ts, - "retry_interval": destination_retry_timings.retry_interval, - "last_successful_stream_ordering": last_successful_stream_ordering, - } - - return HTTPStatus.OK, response diff --git a/synapse/rest/admin/groups.py b/synapse/rest/admin/groups.py index a27110388f4f..68a3ba3cb7ac 100644 --- a/synapse/rest/admin/groups.py +++ b/synapse/rest/admin/groups.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
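For reference, the destination-list endpoint deleted above was queried roughly like this (a sketch using `requests`; the homeserver URL and admin token are placeholders):

    import requests

    resp = requests.get(
        "https://synapse.example.org/_synapse/admin/v1/federation/destinations",
        headers={"Authorization": "Bearer <admin_access_token>"},
        params={"from": 0, "limit": 10},  # pagination, per the deleted docstring
        timeout=10,
    )
    body = resp.json()
    destinations, total = body["destinations"], body["total"]
    # body may also carry "next_token" when more pages remain
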
import logging -from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import SynapseError @@ -44,7 +43,7 @@ async def on_POST( await assert_user_is_admin(self.auth, requester.user) if not self.is_mine_id(group_id): - raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local groups") + raise SynapseError(400, "Can only delete local groups") await self.group_server.delete_group(group_id, requester.user.to_string()) - return HTTPStatus.OK, {} + return 200, {} diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index 9e23e2d8fc00..30a687d234e3 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -14,7 +14,6 @@ # limitations under the License. import logging -from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError @@ -63,7 +62,7 @@ async def on_POST( room_id, requester.user.to_string() ) - return HTTPStatus.OK, {"num_quarantined": num_quarantined} + return 200, {"num_quarantined": num_quarantined} class QuarantineMediaByUser(RestServlet): @@ -90,7 +89,7 @@ async def on_POST( user_id, requester.user.to_string() ) - return HTTPStatus.OK, {"num_quarantined": num_quarantined} + return 200, {"num_quarantined": num_quarantined} class QuarantineMediaByID(RestServlet): @@ -119,7 +118,7 @@ async def on_POST( server_name, media_id, requester.user.to_string() ) - return HTTPStatus.OK, {} + return 200, {} class UnquarantineMediaByID(RestServlet): @@ -148,7 +147,7 @@ async def on_POST( # Remove from quarantine this media id await self.store.quarantine_media_by_id(server_name, media_id, None) - return HTTPStatus.OK, {} + return 200, {} class ProtectMediaByID(RestServlet): @@ -171,7 +170,7 @@ async def on_POST( # Protect this media id await self.store.mark_local_media_as_safe(media_id, safe=True) - return HTTPStatus.OK, {} + return 200, {} class UnprotectMediaByID(RestServlet): @@ -194,7 +193,7 @@ async def on_POST( # Unprotect this media id await self.store.mark_local_media_as_safe(media_id, safe=False) - return HTTPStatus.OK, {} + return 200, {} class ListMediaInRoom(RestServlet): @@ -212,11 +211,11 @@ async def on_GET( requester = await self.auth.get_user_by_req(request) is_admin = await self.auth.is_server_admin(requester.user) if not is_admin: - raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin") + raise AuthError(403, "You are not a server admin") local_mxcs, remote_mxcs = await self.store.get_media_mxcs_in_room(room_id) - return HTTPStatus.OK, {"local": local_mxcs, "remote": remote_mxcs} + return 200, {"local": local_mxcs, "remote": remote_mxcs} class PurgeMediaCacheRestServlet(RestServlet): @@ -234,13 +233,13 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if before_ts < 0: raise SynapseError( - HTTPStatus.BAD_REQUEST, + 400, "Query parameter before_ts must be a positive integer.", errcode=Codes.INVALID_PARAM, ) elif before_ts < 30000000000: # Dec 1970 in milliseconds, Aug 2920 in seconds raise SynapseError( - HTTPStatus.BAD_REQUEST, + 400, "Query parameter before_ts you provided is from the year 1970. 
" + "Double check that you are providing a timestamp in milliseconds.", errcode=Codes.INVALID_PARAM, @@ -248,7 +247,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: ret = await self.media_repository.delete_old_remote_media(before_ts) - return HTTPStatus.OK, ret + return 200, ret class DeleteMediaByID(RestServlet): @@ -268,7 +267,7 @@ async def on_DELETE( await assert_requester_is_admin(self.auth, request) if self.server_name != server_name: - raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local media") + raise SynapseError(400, "Can only delete local media") if await self.store.get_local_media(media_id) is None: raise NotFoundError("Unknown media") @@ -278,7 +277,7 @@ async def on_DELETE( deleted_media, total = await self.media_repository.delete_local_media_ids( [media_id] ) - return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total} + return 200, {"deleted_media": deleted_media, "total": total} class DeleteMediaByDateSize(RestServlet): @@ -305,26 +304,26 @@ async def on_POST( if before_ts < 0: raise SynapseError( - HTTPStatus.BAD_REQUEST, + 400, "Query parameter before_ts must be a positive integer.", errcode=Codes.INVALID_PARAM, ) elif before_ts < 30000000000: # Dec 1970 in milliseconds, Aug 2920 in seconds raise SynapseError( - HTTPStatus.BAD_REQUEST, + 400, "Query parameter before_ts you provided is from the year 1970. " + "Double check that you are providing a timestamp in milliseconds.", errcode=Codes.INVALID_PARAM, ) if size_gt < 0: raise SynapseError( - HTTPStatus.BAD_REQUEST, + 400, "Query parameter size_gt must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) if self.server_name != server_name: - raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local media") + raise SynapseError(400, "Can only delete local media") logging.info( "Deleting local media by timestamp: %s, size larger than: %s, keep profile media: %s" @@ -334,7 +333,7 @@ async def on_POST( deleted_media, total = await self.media_repository.delete_old_local_media( before_ts, size_gt, keep_profiles ) - return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total} + return 200, {"deleted_media": deleted_media, "total": total} class UserMediaRestServlet(RestServlet): @@ -370,7 +369,7 @@ async def on_GET( await assert_requester_is_admin(self.auth, request) if not self.is_mine(UserID.from_string(user_id)): - raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users") + raise SynapseError(400, "Can only look up local users") user = await self.store.get_user_by_id(user_id) if user is None: @@ -381,14 +380,14 @@ async def on_GET( if start < 0: raise SynapseError( - HTTPStatus.BAD_REQUEST, + 400, "Query parameter from must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) if limit < 0: raise SynapseError( - HTTPStatus.BAD_REQUEST, + 400, "Query parameter limit must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) @@ -426,7 +425,7 @@ async def on_GET( if (start + limit) < total: ret["next_token"] = start + len(media) - return HTTPStatus.OK, ret + return 200, ret async def on_DELETE( self, request: SynapseRequest, user_id: str @@ -437,7 +436,7 @@ async def on_DELETE( await assert_requester_is_admin(self.auth, request) if not self.is_mine(UserID.from_string(user_id)): - raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users") + raise SynapseError(400, "Can only look up local users") user = await self.store.get_user_by_id(user_id) if user is 
None: @@ -448,14 +447,14 @@ async def on_DELETE( if start < 0: raise SynapseError( - HTTPStatus.BAD_REQUEST, + 400, "Query parameter from must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) if limit < 0: raise SynapseError( - HTTPStatus.BAD_REQUEST, + 400, "Query parameter limit must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) @@ -493,7 +492,7 @@ async def on_DELETE( ([row["media_id"] for row in media]) ) - return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total} + return 200, {"deleted_media": deleted_media, "total": total} def register_servlets_for_media_repo(hs: "HomeServer", http_server: HttpServer) -> None: diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py index 891b98c0888a..aba48f6e7bab 100644 --- a/synapse/rest/admin/registration_tokens.py +++ b/synapse/rest/admin/registration_tokens.py @@ -14,7 +14,6 @@ import logging import string -from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import Codes, NotFoundError, SynapseError @@ -78,7 +77,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) valid = parse_boolean(request, "valid") token_list = await self.store.get_registration_tokens(valid) - return HTTPStatus.OK, {"registration_tokens": token_list} + return 200, {"registration_tokens": token_list} class NewRegistrationTokenRestServlet(RestServlet): @@ -124,20 +123,16 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if "token" in body: token = body["token"] if not isinstance(token, str): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "token must be a string", - Codes.INVALID_PARAM, - ) + raise SynapseError(400, "token must be a string", Codes.INVALID_PARAM) if not (0 < len(token) <= 64): raise SynapseError( - HTTPStatus.BAD_REQUEST, + 400, "token must not be empty and must not be longer than 64 characters", Codes.INVALID_PARAM, ) if not set(token).issubset(self.allowed_chars_set): raise SynapseError( - HTTPStatus.BAD_REQUEST, + 400, "token must consist only of characters matched by the regex [A-Za-z0-9-_]", Codes.INVALID_PARAM, ) @@ -147,13 +142,11 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: length = body.get("length", 16) if not isinstance(length, int): raise SynapseError( - HTTPStatus.BAD_REQUEST, - "length must be an integer", - Codes.INVALID_PARAM, + 400, "length must be an integer", Codes.INVALID_PARAM ) if not (0 < length <= 64): raise SynapseError( - HTTPStatus.BAD_REQUEST, + 400, "length must be greater than zero and not greater than 64", Codes.INVALID_PARAM, ) @@ -169,7 +162,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: or (isinstance(uses_allowed, int) and uses_allowed >= 0) ): raise SynapseError( - HTTPStatus.BAD_REQUEST, + 400, "uses_allowed must be a non-negative integer or null", Codes.INVALID_PARAM, ) @@ -177,15 +170,11 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: expiry_time = body.get("expiry_time", None) if not isinstance(expiry_time, (int, type(None))): raise SynapseError( - HTTPStatus.BAD_REQUEST, - "expiry_time must be an integer or null", - Codes.INVALID_PARAM, + 400, "expiry_time must be an integer or null", Codes.INVALID_PARAM ) if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec(): raise SynapseError( - HTTPStatus.BAD_REQUEST, - "expiry_time must not be in the past", - 
Codes.INVALID_PARAM, + 400, "expiry_time must not be in the past", Codes.INVALID_PARAM ) created = await self.store.create_registration_token( @@ -193,9 +182,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: ) if not created: raise SynapseError( - HTTPStatus.BAD_REQUEST, - f"Token already exists: {token}", - Codes.INVALID_PARAM, + 400, f"Token already exists: {token}", Codes.INVALID_PARAM ) resp = { @@ -205,7 +192,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: "completed": 0, "expiry_time": expiry_time, } - return HTTPStatus.OK, resp + return 200, resp class RegistrationTokenRestServlet(RestServlet): @@ -274,7 +261,7 @@ async def on_GET(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDi if token_info is None: raise NotFoundError(f"No such registration token: {token}") - return HTTPStatus.OK, token_info + return 200, token_info async def on_PUT(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]: """Update a registration token.""" @@ -290,7 +277,7 @@ async def on_PUT(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDi or (isinstance(uses_allowed, int) and uses_allowed >= 0) ): raise SynapseError( - HTTPStatus.BAD_REQUEST, + 400, "uses_allowed must be a non-negative integer or null", Codes.INVALID_PARAM, ) @@ -300,15 +287,11 @@ async def on_PUT(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDi expiry_time = body["expiry_time"] if not isinstance(expiry_time, (int, type(None))): raise SynapseError( - HTTPStatus.BAD_REQUEST, - "expiry_time must be an integer or null", - Codes.INVALID_PARAM, + 400, "expiry_time must be an integer or null", Codes.INVALID_PARAM ) if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec(): raise SynapseError( - HTTPStatus.BAD_REQUEST, - "expiry_time must not be in the past", - Codes.INVALID_PARAM, + 400, "expiry_time must not be in the past", Codes.INVALID_PARAM ) new_attributes["expiry_time"] = expiry_time @@ -324,7 +307,7 @@ async def on_PUT(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDi if token_info is None: raise NotFoundError(f"No such registration token: {token}") - return HTTPStatus.OK, token_info + return 200, token_info async def on_DELETE( self, request: SynapseRequest, token: str @@ -333,6 +316,6 @@ async def on_DELETE( await assert_requester_is_admin(self.auth, request) if await self.store.delete_registration_token(token): - return HTTPStatus.OK, {} + return 200, {} raise NotFoundError(f"No such registration token: {token}") diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 829e86675aba..a89dda1ba5b2 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -102,9 +102,10 @@ async def on_DELETE( ) if not RoomID.is_valid(room_id): - raise SynapseError( - HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,) - ) + raise SynapseError(400, "%s is not a legal room ID" % (room_id,)) + + if not await self._store.get_room(room_id): + raise NotFoundError("Unknown room id %s" % (room_id,)) delete_id = self._pagination_handler.start_shutdown_and_purge_room( room_id=room_id, @@ -117,7 +118,7 @@ async def on_DELETE( force_purge=force_purge, ) - return HTTPStatus.OK, {"delete_id": delete_id} + return 200, {"delete_id": delete_id} class DeleteRoomStatusByRoomIdRestServlet(RestServlet): @@ -136,9 +137,7 @@ async def on_GET( await assert_requester_is_admin(self._auth, request) if not RoomID.is_valid(room_id): - raise SynapseError( - HTTPStatus.BAD_REQUEST, "%s is not a 
legal room ID" % (room_id,) - ) + raise SynapseError(400, "%s is not a legal room ID" % (room_id,)) delete_ids = self._pagination_handler.get_delete_ids_by_room(room_id) if delete_ids is None: @@ -154,7 +153,7 @@ async def on_GET( **delete.asdict(), } ] - return HTTPStatus.OK, {"results": cast(JsonDict, response)} + return 200, {"results": cast(JsonDict, response)} class DeleteRoomStatusByDeleteIdRestServlet(RestServlet): @@ -176,7 +175,7 @@ async def on_GET( if delete_status is None: raise NotFoundError("delete id '%s' not found" % delete_id) - return HTTPStatus.OK, cast(JsonDict, delete_status.asdict()) + return 200, cast(JsonDict, delete_status.asdict()) class ListRoomRestServlet(RestServlet): @@ -218,7 +217,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: RoomSortOrder.STATE_EVENTS.value, ): raise SynapseError( - HTTPStatus.BAD_REQUEST, + 400, "Unknown value for order_by: %s" % (order_by,), errcode=Codes.INVALID_PARAM, ) @@ -226,7 +225,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: search_term = parse_string(request, "search_term", encoding="utf-8") if search_term == "": raise SynapseError( - HTTPStatus.BAD_REQUEST, + 400, "search_term cannot be an empty string", errcode=Codes.INVALID_PARAM, ) @@ -234,9 +233,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: direction = parse_string(request, "dir", default="f") if direction not in ("f", "b"): raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Unknown direction: %s" % (direction,), - errcode=Codes.INVALID_PARAM, + 400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM ) reverse_order = True if direction == "b" else False @@ -268,7 +265,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: else: response["prev_batch"] = 0 - return HTTPStatus.OK, response + return 200, response class RoomRestServlet(RestServlet): @@ -313,7 +310,7 @@ async def on_GET( members = await self.store.get_users_in_room(room_id) ret["joined_local_devices"] = await self.store.count_devices_by_users(members) - return HTTPStatus.OK, ret + return 200, ret async def on_DELETE( self, request: SynapseRequest, room_id: str @@ -389,7 +386,7 @@ async def _delete_room( # See https://github.com/python/mypy/issues/4976#issuecomment-579883622 # for some discussion on why this is necessary. Either way, # `ret` is an opaque dictionary blob as far as the rest of the app cares. - return HTTPStatus.OK, cast(JsonDict, ret) + return 200, cast(JsonDict, ret) class RoomMembersRestServlet(RestServlet): @@ -416,7 +413,7 @@ async def on_GET( members = await self.store.get_users_in_room(room_id) ret = {"members": members, "total": len(members)} - return HTTPStatus.OK, ret + return 200, ret class RoomStateRestServlet(RestServlet): @@ -446,10 +443,16 @@ async def on_GET( event_ids = await self.store.get_current_state_ids(room_id) events = await self.store.get_events(event_ids.values()) now = self.clock.time_msec() - room_state = await self._event_serializer.serialize_events(events.values(), now) + room_state = await self._event_serializer.serialize_events( + events.values(), + now, + # We don't bother bundling aggregations in when asked for state + # events, as clients won't use them. 
+ bundle_relations=False, + ) ret = {"state": room_state} - return HTTPStatus.OK, ret + return 200, ret class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet): @@ -478,10 +481,7 @@ async def on_POST( target_user = UserID.from_string(content["user_id"]) if not self.hs.is_mine(target_user): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "This endpoint can only be used with local users", - ) + raise SynapseError(400, "This endpoint can only be used with local users") if not await self.admin_handler.get_user(target_user): raise NotFoundError("User not found") @@ -527,7 +527,7 @@ async def on_POST( ratelimit=False, ) - return HTTPStatus.OK, {"room_id": room_id} + return 200, {"room_id": room_id} class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): @@ -568,7 +568,7 @@ async def on_POST( # Figure out which local users currently have power in the room, if any. room_state = await self.state_handler.get_current_state(room_id) if not room_state: - raise SynapseError(HTTPStatus.BAD_REQUEST, "Server not in room") + raise SynapseError(400, "Server not in room") create_event = room_state[(EventTypes.Create, "")] power_levels = room_state.get((EventTypes.PowerLevels, "")) @@ -582,9 +582,7 @@ async def on_POST( admin_users.sort(key=lambda user: user_power[user]) if not admin_users: - raise SynapseError( - HTTPStatus.BAD_REQUEST, "No local admin user in room" - ) + raise SynapseError(400, "No local admin user in room") admin_user_id = None @@ -601,7 +599,7 @@ async def on_POST( if not admin_user_id: raise SynapseError( - HTTPStatus.BAD_REQUEST, + 400, "No local admin user in room", ) @@ -612,7 +610,7 @@ async def on_POST( admin_user_id = create_event.sender if not self.is_mine_id(admin_user_id): raise SynapseError( - HTTPStatus.BAD_REQUEST, + 400, "No local admin user in room", ) @@ -641,8 +639,7 @@ async def on_POST( except AuthError: # The admin user we found turned out not to have enough power. raise SynapseError( - HTTPStatus.BAD_REQUEST, - "No local admin user in room with power to update power levels.", + 400, "No local admin user in room with power to update power levels." 
) # Now we check if the user we're granting admin rights to is already in @@ -656,7 +653,7 @@ async def on_POST( ) if is_joined: - return HTTPStatus.OK, {} + return 200, {} join_rules = room_state.get((EventTypes.JoinRules, "")) is_public = False @@ -664,7 +661,7 @@ async def on_POST( is_public = join_rules.content.get("join_rule") == JoinRules.PUBLIC if is_public: - return HTTPStatus.OK, {} + return 200, {} await self.room_member_handler.update_membership( fake_requester, @@ -673,7 +670,7 @@ async def on_POST( action=Membership.INVITE, ) - return HTTPStatus.OK, {} + return 200, {} class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet): @@ -705,7 +702,7 @@ async def on_DELETE( room_id, _ = await self.resolve_room_id(room_identifier) deleted_count = await self.store.delete_forward_extremities_for_room(room_id) - return HTTPStatus.OK, {"deleted": deleted_count} + return 200, {"deleted": deleted_count} async def on_GET( self, request: SynapseRequest, room_identifier: str @@ -716,7 +713,7 @@ async def on_GET( room_id, _ = await self.resolve_room_id(room_identifier) extremities = await self.store.get_forward_extremities_for_room(room_id) - return HTTPStatus.OK, {"count": len(extremities), "results": extremities} + return 200, {"count": len(extremities), "results": extremities} class RoomEventContextServlet(RestServlet): @@ -765,9 +762,7 @@ async def on_GET( ) if not results: - raise SynapseError( - HTTPStatus.NOT_FOUND, "Event not found.", errcode=Codes.NOT_FOUND - ) + raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) time_now = self.clock.time_msec() results["events_before"] = await self._event_serializer.serialize_events( @@ -780,10 +775,13 @@ async def on_GET( results["events_after"], time_now ) results["state"] = await self._event_serializer.serialize_events( - results["state"], time_now + results["state"], + time_now, + # No need to bundle aggregations for state events + bundle_relations=False, ) - return HTTPStatus.OK, results + return 200, results class BlockRoomRestServlet(RestServlet): diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py index b295fb078bc7..19f84f33f22f 100644 --- a/synapse/rest/admin/server_notice_servlet.py +++ b/synapse/rest/admin/server_notice_servlet.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from http import HTTPStatus from typing import TYPE_CHECKING, Awaitable, Optional, Tuple from synapse.api.constants import EventTypes @@ -83,15 +82,11 @@ async def on_POST( # but worker processes still need to initialise SendServerNoticeServlet (as it is part of the # admin api). 
if not self.server_notices_manager.is_enabled(): - raise SynapseError( - HTTPStatus.BAD_REQUEST, "Server notices are not enabled on this server" - ) + raise SynapseError(400, "Server notices are not enabled on this server") target_user = UserID.from_string(body["user_id"]) if not self.hs.is_mine(target_user): - raise SynapseError( - HTTPStatus.BAD_REQUEST, "Server notices can only be sent to local users" - ) + raise SynapseError(400, "Server notices can only be sent to local users") if not await self.admin_handler.get_user(target_user): raise NotFoundError("User not found") @@ -104,7 +99,7 @@ async def on_POST( txn_id=txn_id, ) - return HTTPStatus.OK, {"event_id": event.event_id} + return 200, {"event_id": event.event_id} def on_PUT( self, request: SynapseRequest, txn_id: str diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py index ca41fd45f2bd..948de94ccd6a 100644 --- a/synapse/rest/admin/statistics.py +++ b/synapse/rest/admin/statistics.py @@ -13,7 +13,6 @@ # limitations under the License. import logging -from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import Codes, SynapseError @@ -54,7 +53,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: UserSortOrder.DISPLAYNAME.value, ): raise SynapseError( - HTTPStatus.BAD_REQUEST, + 400, "Unknown value for order_by: %s" % (order_by,), errcode=Codes.INVALID_PARAM, ) @@ -62,7 +61,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: start = parse_integer(request, "from", default=0) if start < 0: raise SynapseError( - HTTPStatus.BAD_REQUEST, + 400, "Query parameter from must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) @@ -70,7 +69,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: limit = parse_integer(request, "limit", default=100) if limit < 0: raise SynapseError( - HTTPStatus.BAD_REQUEST, + 400, "Query parameter limit must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) @@ -78,7 +77,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: from_ts = parse_integer(request, "from_ts", default=0) if from_ts < 0: raise SynapseError( - HTTPStatus.BAD_REQUEST, + 400, "Query parameter from_ts must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) @@ -87,13 +86,13 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if until_ts is not None: if until_ts < 0: raise SynapseError( - HTTPStatus.BAD_REQUEST, + 400, "Query parameter until_ts must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) if until_ts <= from_ts: raise SynapseError( - HTTPStatus.BAD_REQUEST, + 400, "Query parameter until_ts must be greater than from_ts.", errcode=Codes.INVALID_PARAM, ) @@ -101,7 +100,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: search_term = parse_string(request, "search_term") if search_term == "": raise SynapseError( - HTTPStatus.BAD_REQUEST, + 400, "Query parameter search_term cannot be an empty string.", errcode=Codes.INVALID_PARAM, ) @@ -109,9 +108,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: direction = parse_string(request, "dir", default="f") if direction not in ("f", "b"): raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Unknown direction: %s" % (direction,), - errcode=Codes.INVALID_PARAM, + 400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM ) users_media, total = 
await self.store.get_users_media_usage_paginate( @@ -121,4 +118,4 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if (start + limit) < total: ret["next_token"] = start + len(users_media) - return HTTPStatus.OK, ret + return 200, ret diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 2a60b602b1f8..ccd9a2a17580 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -79,14 +79,14 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if start < 0: raise SynapseError( - HTTPStatus.BAD_REQUEST, + 400, "Query parameter from must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) if limit < 0: raise SynapseError( - HTTPStatus.BAD_REQUEST, + 400, "Query parameter limit must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) @@ -122,7 +122,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if (start + limit) < total: ret["next_token"] = str(start + len(users)) - return HTTPStatus.OK, ret + return 200, ret class UserRestServletV2(RestServlet): @@ -172,14 +172,14 @@ async def on_GET( target_user = UserID.from_string(user_id) if not self.hs.is_mine(target_user): - raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users") + raise SynapseError(400, "Can only look up local users") ret = await self.admin_handler.get_user(target_user) if not ret: raise NotFoundError("User not found") - return HTTPStatus.OK, ret + return 200, ret async def on_PUT( self, request: SynapseRequest, user_id: str @@ -191,10 +191,7 @@ async def on_PUT( body = parse_json_object_from_request(request) if not self.hs.is_mine(target_user): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "This endpoint can only be used with local users", - ) + raise SynapseError(400, "This endpoint can only be used with local users") user = await self.admin_handler.get_user(target_user) user_id = target_user.to_string() @@ -213,7 +210,7 @@ async def on_PUT( user_type = body.get("user_type", None) if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES: - raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid user type") + raise SynapseError(400, "Invalid user type") set_admin_to = body.get("admin", False) if not isinstance(set_admin_to, bool): @@ -226,13 +223,11 @@ async def on_PUT( password = body.get("password", None) if password is not None: if not isinstance(password, str) or len(password) > 512: - raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid password") + raise SynapseError(400, "Invalid password") deactivate = body.get("deactivated", False) if not isinstance(deactivate, bool): - raise SynapseError( - HTTPStatus.BAD_REQUEST, "'deactivated' parameter is not of type boolean" - ) + raise SynapseError(400, "'deactivated' parameter is not of type boolean") # convert List[Dict[str, str]] into List[Tuple[str, str]] if external_ids is not None: @@ -287,9 +282,7 @@ async def on_PUT( user_id, ) except ExternalIDReuseException: - raise SynapseError( - HTTPStatus.CONFLICT, "External id is already in use." - ) + raise SynapseError(409, "External id is already in use.") if "avatar_url" in body and isinstance(body["avatar_url"], str): await self.profile_handler.set_avatar_url( @@ -300,9 +293,7 @@ async def on_PUT( if set_admin_to != user["admin"]: auth_user = requester.user if target_user == auth_user and not set_admin_to: - raise SynapseError( - HTTPStatus.BAD_REQUEST, "You may not demote yourself." 
- ) + raise SynapseError(400, "You may not demote yourself.") await self.store.set_server_admin(target_user, set_admin_to) @@ -328,8 +319,7 @@ async def on_PUT( and self.auth_handler.can_change_password() ): raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Must provide a password to re-activate an account.", + 400, "Must provide a password to re-activate an account." ) await self.deactivate_account_handler.activate_account( @@ -342,7 +332,7 @@ async def on_PUT( user = await self.admin_handler.get_user(target_user) assert user is not None - return HTTPStatus.OK, user + return 200, user else: # create user displayname = body.get("displayname", None) @@ -391,9 +381,7 @@ async def on_PUT( user_id, ) except ExternalIDReuseException: - raise SynapseError( - HTTPStatus.CONFLICT, "External id is already in use." - ) + raise SynapseError(409, "External id is already in use.") if "avatar_url" in body and isinstance(body["avatar_url"], str): await self.profile_handler.set_avatar_url( @@ -441,61 +429,51 @@ def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: nonce = secrets.token_hex(64) self.nonces[nonce] = int(self.reactor.seconds()) - return HTTPStatus.OK, {"nonce": nonce} + return 200, {"nonce": nonce} async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: self._clear_old_nonces() if not self.hs.config.registration.registration_shared_secret: - raise SynapseError( - HTTPStatus.BAD_REQUEST, "Shared secret registration is not enabled" - ) + raise SynapseError(400, "Shared secret registration is not enabled") body = parse_json_object_from_request(request) if "nonce" not in body: - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "nonce must be specified", - errcode=Codes.BAD_JSON, - ) + raise SynapseError(400, "nonce must be specified", errcode=Codes.BAD_JSON) nonce = body["nonce"] if nonce not in self.nonces: - raise SynapseError(HTTPStatus.BAD_REQUEST, "unrecognised nonce") + raise SynapseError(400, "unrecognised nonce") # Delete the nonce, so it can't be reused, even if it's invalid del self.nonces[nonce] if "username" not in body: raise SynapseError( - HTTPStatus.BAD_REQUEST, - "username must be specified", - errcode=Codes.BAD_JSON, + 400, "username must be specified", errcode=Codes.BAD_JSON ) else: if not isinstance(body["username"], str) or len(body["username"]) > 512: - raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid username") + raise SynapseError(400, "Invalid username") username = body["username"].encode("utf-8") if b"\x00" in username: - raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid username") + raise SynapseError(400, "Invalid username") if "password" not in body: raise SynapseError( - HTTPStatus.BAD_REQUEST, - "password must be specified", - errcode=Codes.BAD_JSON, + 400, "password must be specified", errcode=Codes.BAD_JSON ) else: password = body["password"] if not isinstance(password, str) or len(password) > 512: - raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid password") + raise SynapseError(400, "Invalid password") password_bytes = password.encode("utf-8") if b"\x00" in password_bytes: - raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid password") + raise SynapseError(400, "Invalid password") password_hash = await self.auth_handler.hash(password) @@ -504,12 +482,10 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: displayname = body.get("displayname", None) if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES: - raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid user type") + raise 
SynapseError(400, "Invalid user type") if "mac" not in body: - raise SynapseError( - HTTPStatus.BAD_REQUEST, "mac must be specified", errcode=Codes.BAD_JSON - ) + raise SynapseError(400, "mac must be specified", errcode=Codes.BAD_JSON) got_mac = body["mac"] @@ -531,7 +507,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: want_mac = want_mac_builder.hexdigest() if not hmac.compare_digest(want_mac.encode("ascii"), got_mac.encode("ascii")): - raise SynapseError(HTTPStatus.FORBIDDEN, "HMAC incorrect") + raise SynapseError(403, "HMAC incorrect") # Reuse the parts of RegisterRestServlet to reduce code duplication from synapse.rest.client.register import RegisterRestServlet @@ -548,7 +524,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: ) result = await register._create_registration_details(user_id, body) - return HTTPStatus.OK, result + return 200, result class WhoisRestServlet(RestServlet): @@ -576,11 +552,11 @@ async def on_GET( await assert_user_is_admin(self.auth, auth_user) if not self.hs.is_mine(target_user): - raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only whois a local user") + raise SynapseError(400, "Can only whois a local user") ret = await self.admin_handler.get_whois(target_user) - return HTTPStatus.OK, ret + return 200, ret class DeactivateAccountRestServlet(RestServlet): @@ -599,9 +575,7 @@ async def on_POST( await assert_user_is_admin(self.auth, requester.user) if not self.is_mine(UserID.from_string(target_user_id)): - raise SynapseError( - HTTPStatus.BAD_REQUEST, "Can only deactivate local users" - ) + raise SynapseError(400, "Can only deactivate local users") if not await self.store.get_user_by_id(target_user_id): raise NotFoundError("User not found") @@ -623,7 +597,7 @@ async def on_POST( else: id_server_unbind_result = "no-support" - return HTTPStatus.OK, {"id_server_unbind_result": id_server_unbind_result} + return 200, {"id_server_unbind_result": id_server_unbind_result} class AccountValidityRenewServlet(RestServlet): @@ -646,7 +620,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if "user_id" not in body: raise SynapseError( - HTTPStatus.BAD_REQUEST, + 400, "Missing property 'user_id' in the request body", ) @@ -657,7 +631,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: ) res = {"expiration_ts": expiration_ts} - return HTTPStatus.OK, res + return 200, res class ResetPasswordRestServlet(RestServlet): @@ -704,7 +678,7 @@ async def on_POST( await self._set_password_handler.set_password( target_user_id, new_password_hash, logout_devices, requester ) - return HTTPStatus.OK, {} + return 200, {} class SearchUsersRestServlet(RestServlet): @@ -738,16 +712,16 @@ async def on_GET( # To allow all users to get the users list # if not is_admin and target_user != auth_user: - # raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin") + # raise AuthError(403, "You are not a server admin") if not self.hs.is_mine(target_user): - raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only users a local user") + raise SynapseError(400, "Can only users a local user") term = parse_string(request, "term", required=True) logger.info("term: %s ", term) ret = await self.store.search_users(term) - return HTTPStatus.OK, ret + return 200, ret class UserAdminServlet(RestServlet): @@ -791,14 +765,11 @@ async def on_GET( target_user = UserID.from_string(user_id) if not self.hs.is_mine(target_user): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Only local users can be 
admins of this homeserver", - ) + raise SynapseError(400, "Only local users can be admins of this homeserver") is_admin = await self.store.is_server_admin(target_user) - return HTTPStatus.OK, {"admin": is_admin} + return 200, {"admin": is_admin} async def on_PUT( self, request: SynapseRequest, user_id: str @@ -814,19 +785,16 @@ async def on_PUT( assert_params_in_dict(body, ["admin"]) if not self.hs.is_mine(target_user): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Only local users can be admins of this homeserver", - ) + raise SynapseError(400, "Only local users can be admins of this homeserver") set_admin_to = bool(body["admin"]) if target_user == auth_user and not set_admin_to: - raise SynapseError(HTTPStatus.BAD_REQUEST, "You may not demote yourself.") + raise SynapseError(400, "You may not demote yourself.") await self.store.set_server_admin(target_user, set_admin_to) - return HTTPStatus.OK, {} + return 200, {} class UserMembershipRestServlet(RestServlet): @@ -848,7 +816,7 @@ async def on_GET( room_ids = await self.store.get_rooms_for_user(user_id) ret = {"joined_rooms": list(room_ids), "total": len(room_ids)} - return HTTPStatus.OK, ret + return 200, ret class PushersRestServlet(RestServlet): @@ -877,7 +845,7 @@ async def on_GET( await assert_requester_is_admin(self.auth, request) if not self.is_mine(UserID.from_string(user_id)): - raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users") + raise SynapseError(400, "Can only look up local users") if not await self.store.get_user_by_id(user_id): raise NotFoundError("User not found") @@ -886,10 +854,7 @@ async def on_GET( filtered_pushers = [p.as_dict() for p in pushers] - return HTTPStatus.OK, { - "pushers": filtered_pushers, - "total": len(filtered_pushers), - } + return 200, {"pushers": filtered_pushers, "total": len(filtered_pushers)} class UserTokenRestServlet(RestServlet): @@ -922,22 +887,16 @@ async def on_POST( auth_user = requester.user if not self.hs.is_mine_id(user_id): - raise SynapseError( - HTTPStatus.BAD_REQUEST, "Only local users can be logged in as" - ) + raise SynapseError(400, "Only local users can be logged in as") body = parse_json_object_from_request(request, allow_empty_body=True) valid_until_ms = body.get("valid_until_ms") if valid_until_ms and not isinstance(valid_until_ms, int): - raise SynapseError( - HTTPStatus.BAD_REQUEST, "'valid_until_ms' parameter must be an int" - ) + raise SynapseError(400, "'valid_until_ms' parameter must be an int") if auth_user.to_string() == user_id: - raise SynapseError( - HTTPStatus.BAD_REQUEST, "Cannot use admin API to login as self" - ) + raise SynapseError(400, "Cannot use admin API to login as self") token = await self.auth_handler.create_access_token_for_user_id( user_id=auth_user.to_string(), @@ -946,7 +905,7 @@ async def on_POST( puppets_user_id=user_id, ) - return HTTPStatus.OK, {"access_token": token} + return 200, {"access_token": token} class ShadowBanRestServlet(RestServlet): @@ -988,13 +947,11 @@ async def on_POST( await assert_requester_is_admin(self.auth, request) if not self.hs.is_mine_id(user_id): - raise SynapseError( - HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned" - ) + raise SynapseError(400, "Only local users can be shadow-banned") await self.store.set_shadow_banned(UserID.from_string(user_id), True) - return HTTPStatus.OK, {} + return 200, {} async def on_DELETE( self, request: SynapseRequest, user_id: str @@ -1002,13 +959,11 @@ async def on_DELETE( await assert_requester_is_admin(self.auth, request) if not 
self.hs.is_mine_id(user_id): - raise SynapseError( - HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned" - ) + raise SynapseError(400, "Only local users can be shadow-banned") await self.store.set_shadow_banned(UserID.from_string(user_id), False) - return HTTPStatus.OK, {} + return 200, {} class RateLimitRestServlet(RestServlet): @@ -1040,7 +995,7 @@ async def on_GET( await assert_requester_is_admin(self.auth, request) if not self.hs.is_mine_id(user_id): - raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users") + raise SynapseError(400, "Can only look up local users") if not await self.store.get_user_by_id(user_id): raise NotFoundError("User not found") @@ -1061,7 +1016,7 @@ async def on_GET( else: ret = {} - return HTTPStatus.OK, ret + return 200, ret async def on_POST( self, request: SynapseRequest, user_id: str @@ -1069,9 +1024,7 @@ async def on_POST( await assert_requester_is_admin(self.auth, request) if not self.hs.is_mine_id(user_id): - raise SynapseError( - HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited" - ) + raise SynapseError(400, "Only local users can be ratelimited") if not await self.store.get_user_by_id(user_id): raise NotFoundError("User not found") @@ -1083,14 +1036,14 @@ async def on_POST( if not isinstance(messages_per_second, int) or messages_per_second < 0: raise SynapseError( - HTTPStatus.BAD_REQUEST, + 400, "%r parameter must be a positive int" % (messages_per_second,), errcode=Codes.INVALID_PARAM, ) if not isinstance(burst_count, int) or burst_count < 0: raise SynapseError( - HTTPStatus.BAD_REQUEST, + 400, "%r parameter must be a positive int" % (burst_count,), errcode=Codes.INVALID_PARAM, ) @@ -1106,7 +1059,7 @@ async def on_POST( "burst_count": ratelimit.burst_count, } - return HTTPStatus.OK, ret + return 200, ret async def on_DELETE( self, request: SynapseRequest, user_id: str @@ -1114,13 +1067,11 @@ async def on_DELETE( await assert_requester_is_admin(self.auth, request) if not self.hs.is_mine_id(user_id): - raise SynapseError( - HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited" - ) + raise SynapseError(400, "Only local users can be ratelimited") if not await self.store.get_user_by_id(user_id): raise NotFoundError("User not found") await self.store.delete_ratelimit_for_user(user_id) - return HTTPStatus.OK, {} + return 200, {} diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index f9994658c4a3..67e03dca04c0 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -14,17 +14,7 @@ import logging import re -from typing import ( - TYPE_CHECKING, - Any, - Awaitable, - Callable, - Dict, - List, - Optional, - Tuple, - Union, -) +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Tuple from typing_extensions import TypedDict @@ -38,6 +28,7 @@ from synapse.http.servlet import ( RestServlet, assert_params_in_dict, + parse_boolean, parse_bytes_from_args, parse_json_object_from_request, parse_string, @@ -72,7 +63,7 @@ class LoginRestServlet(RestServlet): JWT_TYPE_DEPRECATED = "m.login.jwt" APPSERVICE_TYPE = "m.login.application_service" APPSERVICE_TYPE_UNSTABLE = "uk.half-shot.msc2778.login.application_service" - REFRESH_TOKEN_PARAM = "refresh_token" + REFRESH_TOKEN_PARAM = "org.matrix.msc2918.refresh_token" def __init__(self, hs: "HomeServer"): super().__init__() @@ -90,7 +81,7 @@ def __init__(self, hs: "HomeServer"): self.saml2_enabled = hs.config.saml2.saml2_enabled self.cas_enabled = hs.config.cas.cas_enabled self.oidc_enabled = 
hs.config.oidc.oidc_enabled - self._refresh_tokens_enabled = ( + self._msc2918_enabled = ( hs.config.registration.refreshable_access_token_lifetime is not None ) @@ -163,16 +154,14 @@ def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: async def on_POST(self, request: SynapseRequest) -> Tuple[int, LoginResponse]: login_submission = parse_json_object_from_request(request) - # Check to see if the client requested a refresh token. - client_requested_refresh_token = login_submission.get( - LoginRestServlet.REFRESH_TOKEN_PARAM, False - ) - if not isinstance(client_requested_refresh_token, bool): - raise SynapseError(400, "`refresh_token` should be true or false.") - - should_issue_refresh_token = ( - self._refresh_tokens_enabled and client_requested_refresh_token - ) + if self._msc2918_enabled: + # Check if this login should also issue a refresh token, as per + # MSC2918 + should_issue_refresh_token = parse_boolean( + request, name=LoginRestServlet.REFRESH_TOKEN_PARAM, default=False + ) + else: + should_issue_refresh_token = False try: if login_submission["type"] in ( @@ -302,7 +291,6 @@ async def _complete_login( ratelimit: bool = True, auth_provider_id: Optional[str] = None, should_issue_refresh_token: bool = False, - auth_provider_session_id: Optional[str] = None, ) -> LoginResponse: """Called when we've successfully authed the user and now need to actually login them in (e.g. create devices). This gets called on @@ -318,10 +306,10 @@ async def _complete_login( create_non_existent_users: Whether to create the user if they don't exist. Defaults to False. ratelimit: Whether to ratelimit the login request. - auth_provider_id: The SSO IdP the user used, if any. + auth_provider_id: The SSO IdP the user used, if any (just used for the + prometheus metrics). should_issue_refresh_token: True if this login should issue a refresh token alongside the access token. - auth_provider_session_id: The session ID got during login from the SSO IdP. Returns: result: Dictionary of account information after successful login. 
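[Editor's note] The `on_POST` hunk above restores the MSC2918-era opt-in, where the client asks for a refresh token via an experimental query parameter rather than a JSON body field. A sketch of the restored decision, using `parse_boolean` exactly as the hunk does; `msc2918_enabled` stands in for the config check on `refreshable_access_token_lifetime`:

    from synapse.http.servlet import parse_boolean

    def wants_refresh_token(request, msc2918_enabled: bool) -> bool:
        # Pre-revert: a boolean "refresh_token" field in the JSON body.
        # Post-revert: ?org.matrix.msc2918.refresh_token=true on the URL.
        if not msc2918_enabled:
            return False
        return parse_boolean(
            request, name="org.matrix.msc2918.refresh_token", default=False
        )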
@@ -354,7 +342,6 @@ async def _complete_login( initial_display_name, auth_provider_id=auth_provider_id, should_issue_refresh_token=should_issue_refresh_token, - auth_provider_session_id=auth_provider_session_id, ) result = LoginResponse( @@ -400,7 +387,6 @@ async def _do_token_login( self.auth_handler._sso_login_callback, auth_provider_id=res.auth_provider_id, should_issue_refresh_token=should_issue_refresh_token, - auth_provider_session_id=res.auth_provider_session_id, ) async def _do_jwt_login( @@ -462,7 +448,9 @@ def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict: class RefreshTokenServlet(RestServlet): - PATTERNS = (re.compile("^/_matrix/client/v1/refresh$"),) + PATTERNS = client_patterns( + "/org.matrix.msc2918.refresh_token/refresh$", releases=(), unstable=True + ) def __init__(self, hs: "HomeServer"): self._auth_handler = hs.get_auth_handler() @@ -470,7 +458,6 @@ def __init__(self, hs: "HomeServer"): self.refreshable_access_token_lifetime = ( hs.config.registration.refreshable_access_token_lifetime ) - self.refresh_token_lifetime = hs.config.registration.refresh_token_lifetime async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: refresh_submission = parse_json_object_from_request(request) @@ -480,32 +467,21 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if not isinstance(token, str): raise SynapseError(400, "Invalid param: refresh_token", Codes.INVALID_PARAM) - now = self._clock.time_msec() - access_valid_until_ms = None - if self.refreshable_access_token_lifetime is not None: - access_valid_until_ms = now + self.refreshable_access_token_lifetime - refresh_valid_until_ms = None - if self.refresh_token_lifetime is not None: - refresh_valid_until_ms = now + self.refresh_token_lifetime - - ( - access_token, - refresh_token, - actual_access_token_expiry, - ) = await self._auth_handler.refresh_token( - token, access_valid_until_ms, refresh_valid_until_ms + valid_until_ms = ( + self._clock.time_msec() + self.refreshable_access_token_lifetime + ) + access_token, refresh_token = await self._auth_handler.refresh_token( + token, valid_until_ms + ) + expires_in_ms = valid_until_ms - self._clock.time_msec() + return ( + 200, + { + "access_token": access_token, + "refresh_token": refresh_token, + "expires_in_ms": expires_in_ms, + }, ) - - response: Dict[str, Union[str, int]] = { - "access_token": access_token, - "refresh_token": refresh_token, - } - - # expires_in_ms is only present if the token expires - if actual_access_token_expiry is not None: - response["expires_in_ms"] = actual_access_token_expiry - now - - return 200, response class SsoRedirectServlet(RestServlet): @@ -513,7 +489,7 @@ class SsoRedirectServlet(RestServlet): re.compile( "^" + CLIENT_API_PREFIX - + "/(r0|v3)/login/sso/redirect/(?P[A-Za-z0-9_.~-]+)$" + + "/r0/login/sso/redirect/(?P[A-Za-z0-9_.~-]+)$" ) ] diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 8b56c76aed66..d2b11e39d972 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -41,6 +41,7 @@ from synapse.http.servlet import ( RestServlet, assert_params_in_dict, + parse_boolean, parse_json_object_from_request, parse_string, ) @@ -419,7 +420,7 @@ def __init__(self, hs: "HomeServer"): self.password_policy_handler = hs.get_password_policy_handler() self.clock = hs.get_clock() self._registration_enabled = self.hs.config.registration.enable_registration - self._refresh_tokens_enabled = ( + self._msc2918_enabled = ( 
hs.config.registration.refreshable_access_token_lifetime is not None ) @@ -445,15 +446,14 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: f"Do not understand membership kind: {kind}", ) - # Check if the clients wishes for this registration to issue a refresh - # token. - client_requested_refresh_tokens = body.get("refresh_token", False) - if not isinstance(client_requested_refresh_tokens, bool): - raise SynapseError(400, "`refresh_token` should be true or false.") - - should_issue_refresh_token = ( - self._refresh_tokens_enabled and client_requested_refresh_tokens - ) + if self._msc2918_enabled: + # Check if this registration should also issue a refresh token, as + # per MSC2918 + should_issue_refresh_token = parse_boolean( + request, name="org.matrix.msc2918.refresh_token", default=False + ) + else: + should_issue_refresh_token = False # Pull out the provided username and do basic sanity checks early since # the auth layer will store these in sessions. diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index fc4e6921c5e6..45e9f1dd9022 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -224,14 +224,18 @@ async def on_GET( ) now = self.clock.time_msec() - # Do not bundle aggregations when retrieving the original event because - # we want the content before relations are applied to it. + # We set bundle_relations to False when retrieving the original + # event because we want the content before relations were applied to + # it. original_event = await self._event_serializer.serialize_event( - event, now, bundle_aggregations=False + event, now, bundle_relations=False + ) + # Similarly, we don't allow relations to be applied to relations, so we + # return the original relations without any aggregations on top of them + # here. + serialized_events = await self._event_serializer.serialize_events( + events, now, bundle_relations=False ) - # The relations returned for the requested event do include their - # bundled aggregations. - serialized_events = await self._event_serializer.serialize_events(events, now) return_value = pagination_chunk.to_dict() return_value["chunk"] = serialized_events diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index f48e2e6ca248..955d4e8641fe 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -716,7 +716,10 @@ async def on_GET( results["events_after"], time_now ) results["state"] = await self._event_serializer.serialize_events( - results["state"], time_now + results["state"], + time_now, + # No need to bundle aggregations for state events + bundle_relations=False, ) return 200, results @@ -1067,62 +1070,6 @@ def register_txn_path( ) -class TimestampLookupRestServlet(RestServlet): - """ - API endpoint to fetch the `event_id` of the closest event to the given - timestamp (`ts` query parameter) in the given direction (`dir` query - parameter). - - Useful for cases like jump to date so you can start paginating messages from - a given date in the archive. - - `ts` is a timestamp in milliseconds where we will find the closest event in - the given direction. - - `dir` can be `f` or `b` to indicate forwards and backwards in time from the - given timestamp. - - GET /_matrix/client/unstable/org.matrix.msc3030/rooms//timestamp_to_event?ts=&dir= - { - "event_id": ... 
- } - """ - - PATTERNS = ( - re.compile( - "^/_matrix/client/unstable/org.matrix.msc3030" - "/rooms/(?P[^/]*)/timestamp_to_event$" - ), - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self._auth = hs.get_auth() - self._store = hs.get_datastore() - self.timestamp_lookup_handler = hs.get_timestamp_lookup_handler() - - async def on_GET( - self, request: SynapseRequest, room_id: str - ) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request) - await self._auth.check_user_in_room(room_id, requester.user.to_string()) - - timestamp = parse_integer(request, "ts", required=True) - direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"]) - - ( - event_id, - origin_server_ts, - ) = await self.timestamp_lookup_handler.get_event_for_timestamp( - requester, room_id, timestamp, direction - ) - - return 200, { - "event_id": event_id, - "origin_server_ts": origin_server_ts, - } - - class RoomSpaceSummaryRestServlet(RestServlet): PATTERNS = ( re.compile( @@ -1193,7 +1140,7 @@ async def on_POST( class RoomHierarchyRestServlet(RestServlet): PATTERNS = ( re.compile( - "^/_matrix/client/(v1|unstable/org.matrix.msc2946)" + "^/_matrix/client/unstable/org.matrix.msc2946" "/rooms/(?P[^/]*)/hierarchy$" ), ) @@ -1221,7 +1168,7 @@ async def on_GET( ) return 200, await self._room_summary_handler.get_room_hierarchy( - requester, + requester.user.to_string(), room_id, suggested_only=parse_boolean(request, "suggested_only", default=False), max_depth=max_depth, @@ -1292,8 +1239,6 @@ def register_servlets( RoomAliasListServlet(hs).register(http_server) SearchRestServlet(hs).register(http_server) RoomCreateRestServlet(hs).register(http_server) - if hs.config.experimental.msc3030_enabled: - TimestampLookupRestServlet(hs).register(http_server) # Some servlets only get registered for the main process. if not is_worker: diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 88e4f5e0630f..b6a24857320b 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -520,9 +520,9 @@ def serialize(events: Iterable[EventBase]) -> Awaitable[List[JsonDict]]: return self._event_serializer.serialize_events( events, time_now=time_now, - # Don't bother to bundle aggregations if the timeline is unlimited, - # as clients will have all the necessary information. - bundle_aggregations=room.timeline.limited, + # We don't bundle "live" events, as otherwise clients + # will end up double counting annotations. + bundle_relations=False, token_id=token_id, event_format=event_formatter, only_event_fields=only_fields, diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py index 1f6441c412d0..c0e15c65139d 100644 --- a/synapse/rest/media/v1/filepath.py +++ b/synapse/rest/media/v1/filepath.py @@ -43,75 +43,47 @@ def _wrapped(self: "MediaFilePaths", *args: Any, **kwargs: Any) -> str: ) -def _wrap_with_jail_check(relative: bool) -> Callable[[GetPathMethod], GetPathMethod]: +def _wrap_with_jail_check(func: GetPathMethod) -> GetPathMethod: """Wraps a path-returning method to check that the returned path(s) do not escape the media store directory. - The path-returning method may return either a single path, or a list of paths. - The check is not expected to ever fail, unless `func` is missing a call to `_validate_path_component`, or `_validate_path_component` is buggy. Args: - relative: A boolean indicating whether the wrapped method returns paths relative - to the media store directory. 
+ func: The `MediaFilePaths` method to wrap. The method may return either a single + path, or a list of paths. Returned paths may be either absolute or relative. Returns: - A method which will wrap a path-returning method, adding a check to ensure that - the returned path(s) lie within the media store directory. The check will raise - a `ValueError` if it fails. + The method, wrapped with a check to ensure that the returned path(s) lie within + the media store directory. Raises a `ValueError` if the check fails. """ - def _wrap_with_jail_check_inner(func: GetPathMethod) -> GetPathMethod: - @functools.wraps(func) - def _wrapped( - self: "MediaFilePaths", *args: Any, **kwargs: Any - ) -> Union[str, List[str]]: - path_or_paths = func(self, *args, **kwargs) - - if isinstance(path_or_paths, list): - paths_to_check = path_or_paths - else: - paths_to_check = [path_or_paths] - - for path in paths_to_check: - # Construct the path that will ultimately be used. - # We cannot guess whether `path` is relative to the media store - # directory, since the media store directory may itself be a relative - # path. - if relative: - path = os.path.join(self.base_path, path) - normalized_path = os.path.normpath(path) - - # Now that `normpath` has eliminated `../`s and `./`s from the path, - # `os.path.commonpath` can be used to check whether it lies within the - # media store directory. - if ( - os.path.commonpath([normalized_path, self.normalized_base_path]) - != self.normalized_base_path - ): - # The path resolves to outside the media store directory, - # or `self.base_path` is `.`, which is an unlikely configuration. - raise ValueError(f"Invalid media store path: {path!r}") - - # Note that `os.path.normpath`/`abspath` has a subtle caveat: - # `a/b/c/../c` will normalize to `a/b/c`, but the former refers to a - # different path if `a/b/c` is a symlink. That is, the check above is - # not perfect and may allow a certain restricted subset of untrustworthy - # paths through. Since the check above is secondary to the main - # `_validate_path_component` checks, it's less important for it to be - # perfect. - # - # As an alternative, `os.path.realpath` will resolve symlinks, but - # proves problematic if there are symlinks inside the media store. - # eg. if `url_store/` is symlinked to elsewhere, its canonical path - # won't match that of the main media store directory. - - return path_or_paths - - return cast(GetPathMethod, _wrapped) - - return _wrap_with_jail_check_inner + @functools.wraps(func) + def _wrapped( + self: "MediaFilePaths", *args: Any, **kwargs: Any + ) -> Union[str, List[str]]: + path_or_paths = func(self, *args, **kwargs) + + if isinstance(path_or_paths, list): + paths_to_check = path_or_paths + else: + paths_to_check = [path_or_paths] + + for path in paths_to_check: + # path may be an absolute or relative path, depending on the method being + # wrapped. When "appending" an absolute path, `os.path.join` discards the + # previous path, which is desired here. 
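[Editor's note] Both the removed and the restored versions of `_wrap_with_jail_check` reduce to the same containment test: join, normalise, then compare `os.path.commonpath` against the base directory. A self-contained sketch of that test (the `is_inside` helper is hypothetical, not part of either version):

    import os

    def is_inside(base: str, candidate: str) -> bool:
        base = os.path.normpath(base)
        # If `candidate` is already absolute, os.path.join discards `base`,
        # which is exactly the behaviour the wrapper relies on.
        normalized = os.path.normpath(os.path.join(base, candidate))
        # normpath has already collapsed any "../" components, so a prefix
        # comparison via commonpath is sound here (modulo the symlink caveat
        # discussed in the removed comment above).
        return os.path.commonpath([normalized, base]) == base

    assert is_inside("/media_store", "local_content/ab/cd/efgh")
    assert not is_inside("/media_store", "../../etc/passwd")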
+ normalized_path = os.path.normpath(os.path.join(self.real_base_path, path)) + if ( + os.path.commonpath([normalized_path, self.real_base_path]) + != self.real_base_path + ): + raise ValueError(f"Invalid media store path: {path!r}") + + return path_or_paths + + return cast(GetPathMethod, _wrapped) ALLOWED_CHARACTERS = set( @@ -155,7 +127,9 @@ class MediaFilePaths: def __init__(self, primary_base_path: str): self.base_path = primary_base_path - self.normalized_base_path = os.path.normpath(self.base_path) + + # The media store directory, with all symlinks resolved. + self.real_base_path = os.path.realpath(primary_base_path) # Refuse to initialize if paths cannot be validated correctly for the current # platform. @@ -166,7 +140,7 @@ def __init__(self, primary_base_path: str): # for certain homeservers there, since ":"s aren't allowed in paths. assert os.name == "posix" - @_wrap_with_jail_check(relative=True) + @_wrap_with_jail_check def local_media_filepath_rel(self, media_id: str) -> str: return os.path.join( "local_content", @@ -177,7 +151,7 @@ def local_media_filepath_rel(self, media_id: str) -> str: local_media_filepath = _wrap_in_base_path(local_media_filepath_rel) - @_wrap_with_jail_check(relative=True) + @_wrap_with_jail_check def local_media_thumbnail_rel( self, media_id: str, width: int, height: int, content_type: str, method: str ) -> str: @@ -193,7 +167,7 @@ def local_media_thumbnail_rel( local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel) - @_wrap_with_jail_check(relative=False) + @_wrap_with_jail_check def local_media_thumbnail_dir(self, media_id: str) -> str: """ Retrieve the local store path of thumbnails of a given media_id @@ -211,7 +185,7 @@ def local_media_thumbnail_dir(self, media_id: str) -> str: _validate_path_component(media_id[4:]), ) - @_wrap_with_jail_check(relative=True) + @_wrap_with_jail_check def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str: return os.path.join( "remote_content", @@ -223,7 +197,7 @@ def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str: remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel) - @_wrap_with_jail_check(relative=True) + @_wrap_with_jail_check def remote_media_thumbnail_rel( self, server_name: str, @@ -249,7 +223,7 @@ def remote_media_thumbnail_rel( # Legacy path that was used to store thumbnails previously. # Should be removed after some time, when most of the thumbnails are stored # using the new path. 
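[Editor's note] The pre-revert decorator was parameterised (`@_wrap_with_jail_check(relative=True)`), while the restored one is applied bare (`@_wrap_with_jail_check`). The structural difference is a decorator versus a decorator factory, sketched here with dummy names:

    import functools

    def plain(func):                      # used as: @plain
        @functools.wraps(func)
        def wrapped(*args, **kwargs):
            return func(*args, **kwargs)
        return wrapped

    def parameterised(relative: bool):    # used as: @parameterised(relative=True)
        def decorate(func):
            @functools.wraps(func)
            def wrapped(*args, **kwargs):
                # `relative` is captured by the closure and available here.
                return func(*args, **kwargs)
            return wrapped
        return decorate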
- @_wrap_with_jail_check(relative=True) + @_wrap_with_jail_check def remote_media_thumbnail_rel_legacy( self, server_name: str, file_id: str, width: int, height: int, content_type: str ) -> str: @@ -264,7 +238,6 @@ def remote_media_thumbnail_rel_legacy( _validate_path_component(file_name), ) - @_wrap_with_jail_check(relative=False) def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str: return os.path.join( self.base_path, @@ -275,7 +248,7 @@ def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str: _validate_path_component(file_id[4:]), ) - @_wrap_with_jail_check(relative=True) + @_wrap_with_jail_check def url_cache_filepath_rel(self, media_id: str) -> str: if NEW_FORMAT_ID_RE.match(media_id): # Media id is of the form @@ -295,7 +268,7 @@ def url_cache_filepath_rel(self, media_id: str) -> str: url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel) - @_wrap_with_jail_check(relative=False) + @_wrap_with_jail_check def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]: "The dirs to try and remove if we delete the media_id file" if NEW_FORMAT_ID_RE.match(media_id): @@ -317,7 +290,7 @@ def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]: ), ] - @_wrap_with_jail_check(relative=True) + @_wrap_with_jail_check def url_cache_thumbnail_rel( self, media_id: str, width: int, height: int, content_type: str, method: str ) -> str: @@ -345,7 +318,7 @@ def url_cache_thumbnail_rel( url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel) - @_wrap_with_jail_check(relative=True) + @_wrap_with_jail_check def url_cache_thumbnail_directory_rel(self, media_id: str) -> str: # Media id is of the form # E.g.: 2017-09-28-fsdRDt24DS234dsf @@ -368,7 +341,7 @@ def url_cache_thumbnail_directory_rel(self, media_id: str) -> str: url_cache_thumbnail_directory_rel ) - @_wrap_with_jail_check(relative=False) + @_wrap_with_jail_check def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]: "The dirs to try and remove if we delete the media_id thumbnails" # Media id is of the form diff --git a/synapse/server.py b/synapse/server.py index 185e40e4da0f..877eba6c0803 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -97,7 +97,6 @@ RoomContextHandler, RoomCreationHandler, RoomShutdownHandler, - TimestampLookupHandler, ) from synapse.handlers.room_batch import RoomBatchHandler from synapse.handlers.room_list import RoomListHandler @@ -729,10 +728,6 @@ def get_pagination_handler(self) -> PaginationHandler: def get_room_context_handler(self) -> RoomContextHandler: return RoomContextHandler(self) - @cache_in_self - def get_timestamp_lookup_handler(self) -> TimestampLookupHandler: - return TimestampLookupHandler(self) - @cache_in_self def get_registration_handler(self) -> RegistrationHandler: return RegistrationHandler(self) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 446204dbe52f..1605411b0087 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -764,7 +764,7 @@ class StateResolutionStore: store: "DataStore" def get_events( - self, event_ids: Collection[str], allow_rejected: bool = False + self, event_ids: Iterable[str], allow_rejected: bool = False ) -> Awaitable[Dict[str, EventBase]]: """Get events from the database diff --git a/synapse/state/v1.py b/synapse/state/v1.py index 499a32820185..6edadea550d2 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -17,7 +17,6 @@ from typing import ( Awaitable, Callable, - Collection, Dict, Iterable, List, @@ -45,7 
+44,7 @@ async def resolve_events_with_store( room_version: RoomVersion, state_sets: Sequence[StateMap[str]], event_map: Optional[Dict[str, EventBase]], - state_map_factory: Callable[[Collection[str]], Awaitable[Dict[str, EventBase]]], + state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]], ) -> StateMap[str]: """ Args: diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 3056e64ff570..0623da9aa196 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -21,7 +21,7 @@ from synapse.storage.database import make_in_list_sql_clause # noqa: F401 from synapse.storage.database import DatabasePool from synapse.storage.types import Connection -from synapse.types import get_domain_from_id +from synapse.types import StreamToken, get_domain_from_id from synapse.util import json_decoder if TYPE_CHECKING: @@ -48,7 +48,7 @@ def process_replication_rows( self, stream_name: str, instance_name: str, - token: int, + token: StreamToken, rows: Iterable[Any], ) -> None: pass diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index d64910aded33..bc8364400d2d 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -12,22 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import ( - TYPE_CHECKING, - AsyncContextManager, - Awaitable, - Callable, - Dict, - Iterable, - Optional, -) - -import attr +from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Iterable, Optional from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.types import Connection from synapse.types import JsonDict -from synapse.util import Clock, json_encoder +from synapse.util import json_encoder from . import engines @@ -38,45 +28,6 @@ logger = logging.getLogger(__name__) -ON_UPDATE_CALLBACK = Callable[[str, str, bool], AsyncContextManager[int]] -DEFAULT_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]] -MIN_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]] - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class _BackgroundUpdateHandler: - """A handler for a given background update. - - Attributes: - callback: The function to call to make progress on the background - update. - oneshot: Wether the update is likely to happen all in one go, ignoring - the supplied target duration, e.g. index creation. This is used by - the update controller to help correctly schedule the update. 
- """ - - callback: Callable[[JsonDict, int], Awaitable[int]] - oneshot: bool = False - - -class _BackgroundUpdateContextManager: - BACKGROUND_UPDATE_INTERVAL_MS = 1000 - BACKGROUND_UPDATE_DURATION_MS = 100 - - def __init__(self, sleep: bool, clock: Clock): - self._sleep = sleep - self._clock = clock - - async def __aenter__(self) -> int: - if self._sleep: - await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000) - - return self.BACKGROUND_UPDATE_DURATION_MS - - async def __aexit__(self, *exc) -> None: - pass - - class BackgroundUpdatePerformance: """Tracks the how long a background update is taking to update its items""" @@ -133,22 +84,20 @@ class BackgroundUpdater: MINIMUM_BACKGROUND_BATCH_SIZE = 1 DEFAULT_BACKGROUND_BATCH_SIZE = 100 + BACKGROUND_UPDATE_INTERVAL_MS = 1000 + BACKGROUND_UPDATE_DURATION_MS = 100 def __init__(self, hs: "HomeServer", database: "DatabasePool"): self._clock = hs.get_clock() self.db_pool = database - self._database_name = database.name() - # if a background update is currently running, its name. self._current_background_update: Optional[str] = None - self._on_update_callback: Optional[ON_UPDATE_CALLBACK] = None - self._default_batch_size_callback: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None - self._min_batch_size_callback: Optional[MIN_BATCH_SIZE_CALLBACK] = None - self._background_update_performance: Dict[str, BackgroundUpdatePerformance] = {} - self._background_update_handlers: Dict[str, _BackgroundUpdateHandler] = {} + self._background_update_handlers: Dict[ + str, Callable[[JsonDict, int], Awaitable[int]] + ] = {} self._all_done = False # Whether we're currently running updates @@ -158,83 +107,6 @@ def __init__(self, hs: "HomeServer", database: "DatabasePool"): # enable/disable background updates via the admin API. self.enabled = True - def register_update_controller_callbacks( - self, - on_update: ON_UPDATE_CALLBACK, - default_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None, - min_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None, - ) -> None: - """Register callbacks from a module for each hook.""" - if self._on_update_callback is not None: - logger.warning( - "More than one module tried to register callbacks for controlling" - " background updates. Only the callbacks registered by the first module" - " (in order of appearance in Synapse's configuration file) that tried to" - " do so will be called." - ) - - return - - self._on_update_callback = on_update - - if default_batch_size is not None: - self._default_batch_size_callback = default_batch_size - - if min_batch_size is not None: - self._min_batch_size_callback = min_batch_size - - def _get_context_manager_for_update( - self, - sleep: bool, - update_name: str, - database_name: str, - oneshot: bool, - ) -> AsyncContextManager[int]: - """Get a context manager to run a background update with. - - If a module has registered a `update_handler` callback, use the context manager - it returns. - - Otherwise, returns a context manager that will return a default value, optionally - sleeping if needed. - - Args: - sleep: Whether we can sleep between updates. - update_name: The name of the update. - database_name: The name of the database the update is being run on. - oneshot: Whether the update will complete all in one go, e.g. index creation. - In such cases the returned target duration is ignored. - - Returns: - The target duration in milliseconds that the background update should run for. - - Note: this is a *target*, and an iteration may take substantially longer or - shorter. 
- """ - if self._on_update_callback is not None: - return self._on_update_callback(update_name, database_name, oneshot) - - return _BackgroundUpdateContextManager(sleep, self._clock) - - async def _default_batch_size(self, update_name: str, database_name: str) -> int: - """The batch size to use for the first iteration of a new background - update. - """ - if self._default_batch_size_callback is not None: - return await self._default_batch_size_callback(update_name, database_name) - - return self.DEFAULT_BACKGROUND_BATCH_SIZE - - async def _min_batch_size(self, update_name: str, database_name: str) -> int: - """A lower bound on the batch size of a new background update. - - Used to ensure that progress is always made. Must be greater than 0. - """ - if self._min_batch_size_callback is not None: - return await self._min_batch_size_callback(update_name, database_name) - - return self.MINIMUM_BACKGROUND_BATCH_SIZE - def get_current_update(self) -> Optional[BackgroundUpdatePerformance]: """Returns the current background update, if any.""" @@ -263,8 +135,13 @@ async def run_background_updates(self, sleep: bool = True) -> None: try: logger.info("Starting background schema updates") while self.enabled: + if sleep: + await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0) + try: - result = await self.do_next_background_update(sleep) + result = await self.do_next_background_update( + self.BACKGROUND_UPDATE_DURATION_MS + ) except Exception: logger.exception("Error doing update") else: @@ -326,15 +203,13 @@ async def has_completed_background_update(self, update_name: str) -> bool: return not update_exists - async def do_next_background_update(self, sleep: bool = True) -> bool: + async def do_next_background_update(self, desired_duration_ms: float) -> bool: """Does some amount of work on the next queued background update Returns once some amount of work is done. Args: - sleep: Whether to limit how quickly we run background updates or - not. - + desired_duration_ms: How long we want to spend updating. Returns: True if we have finished running all the background updates, otherwise False """ @@ -377,19 +252,7 @@ def get_background_updates_txn(txn): self._current_background_update = upd["update_name"] - # We have a background update to run, otherwise we would have returned - # early. 
- assert self._current_background_update is not None - update_info = self._background_update_handlers[self._current_background_update] - - async with self._get_context_manager_for_update( - sleep=sleep, - update_name=self._current_background_update, - database_name=self._database_name, - oneshot=update_info.oneshot, - ) as desired_duration_ms: - await self._do_background_update(desired_duration_ms) - + await self._do_background_update(desired_duration_ms) return False async def _do_background_update(self, desired_duration_ms: float) -> int: @@ -397,7 +260,7 @@ async def _do_background_update(self, desired_duration_ms: float) -> int: update_name = self._current_background_update logger.info("Starting update batch on background update '%s'", update_name) - update_handler = self._background_update_handlers[update_name].callback + update_handler = self._background_update_handlers[update_name] performance = self._background_update_performance.get(update_name) @@ -410,14 +273,9 @@ async def _do_background_update(self, desired_duration_ms: float) -> int: if items_per_ms is not None: batch_size = int(desired_duration_ms * items_per_ms) # Clamp the batch size so that we always make progress - batch_size = max( - batch_size, - await self._min_batch_size(update_name, self._database_name), - ) + batch_size = max(batch_size, self.MINIMUM_BACKGROUND_BATCH_SIZE) else: - batch_size = await self._default_batch_size( - update_name, self._database_name - ) + batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE progress_json = await self.db_pool.simple_select_one_onecol( "background_updates", @@ -436,8 +294,6 @@ async def _do_background_update(self, desired_duration_ms: float) -> int: duration_ms = time_stop - time_start - performance.update(items_updated, duration_ms) - logger.info( "Running background update %r. Processed %r items in %rms." " (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)", @@ -450,6 +306,8 @@ async def _do_background_update(self, desired_duration_ms: float) -> int: batch_size, ) + performance.update(items_updated, duration_ms) + return len(self._background_update_performance) def register_background_update_handler( @@ -473,9 +331,7 @@ def register_background_update_handler( update_name: The name of the update that this code handles. update_handler: The function that does the update. """ - self._background_update_handlers[update_name] = _BackgroundUpdateHandler( - update_handler - ) + self._background_update_handlers[update_name] = update_handler def register_noop_background_update(self, update_name: str) -> None: """Register a noop handler for a background update. @@ -597,9 +453,7 @@ async def updater(progress, batch_size): await self._end_background_update(update_name) return 1 - self._background_update_handlers[update_name] = _BackgroundUpdateHandler( - updater, oneshot=True - ) + self.register_background_update_handler(update_name, updater) async def _end_background_update(self, update_name: str) -> None: """Removes a completed background update task from the queue. diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 4a883dc16647..baec35ee27b2 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -143,7 +143,7 @@ async def get_appservices_by_state( A list of ApplicationServices, which may be empty. 
""" results = await self.db_pool.simple_select_list( - "application_services_state", {"state": state.value}, ["as_id"] + "application_services_state", {"state": state}, ["as_id"] ) # NB: This assumes this class is linked with ApplicationServiceStore as_list = self.get_app_services() @@ -173,7 +173,7 @@ async def get_appservice_state( desc="get_appservice_state", ) if result: - return ApplicationServiceState(result.get("state")) + return result.get("state") return None async def set_appservice_state( @@ -186,7 +186,7 @@ async def set_appservice_state( state: The connectivity state to apply. """ await self.db_pool.simple_upsert( - "application_services_state", {"as_id": service.id}, {"state": state.value} + "application_services_state", {"as_id": service.id}, {"state": state} ) async def create_appservice_txn( diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index d5a4a661cd1a..9ccc66e589a8 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -139,27 +139,6 @@ async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]: return {d["device_id"]: d for d in devices} - async def get_devices_by_auth_provider_session_id( - self, auth_provider_id: str, auth_provider_session_id: str - ) -> List[Dict[str, Any]]: - """Retrieve the list of devices associated with a SSO IdP session ID. - - Args: - auth_provider_id: The SSO IdP ID as defined in the server config - auth_provider_session_id: The session ID within the IdP - Returns: - A list of dicts containing the device_id and the user_id of each device - """ - return await self.db_pool.simple_select_list( - table="device_auth_providers", - keyvalues={ - "auth_provider_id": auth_provider_id, - "auth_provider_session_id": auth_provider_session_id, - }, - retcols=("user_id", "device_id"), - desc="get_devices_by_auth_provider_session_id", - ) - @trace async def get_device_updates_by_remote( self, destination: str, from_stream_id: int, limit: int @@ -1091,12 +1070,7 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): ) async def store_device( - self, - user_id: str, - device_id: str, - initial_device_display_name: Optional[str], - auth_provider_id: Optional[str] = None, - auth_provider_session_id: Optional[str] = None, + self, user_id: str, device_id: str, initial_device_display_name: Optional[str] ) -> bool: """Ensure the given device is known; add it to the store if not @@ -1105,8 +1079,6 @@ async def store_device( device_id: id of device initial_device_display_name: initial displayname of the device. Ignored if device exists. - auth_provider_id: The SSO IdP the user used, if any. - auth_provider_session_id: The session ID (sid) got from a OIDC login. Returns: Whether the device was inserted or an existing device existed with that ID. 
@@ -1143,18 +1115,6 @@ async def store_device( if hidden: raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN) - if auth_provider_id and auth_provider_session_id: - await self.db_pool.simple_insert( - "device_auth_providers", - values={ - "user_id": user_id, - "device_id": device_id, - "auth_provider_id": auth_provider_id, - "auth_provider_session_id": auth_provider_session_id, - }, - desc="store_device_auth_provider", - ) - self.device_id_exists_cache.set(key, True) return inserted except StoreError: @@ -1208,14 +1168,6 @@ def _delete_devices_txn(txn: LoggingTransaction) -> None: keyvalues={"user_id": user_id}, ) - self.db_pool.simple_delete_many_txn( - txn, - table="device_auth_providers", - column="device_id", - values=device_ids, - keyvalues={"user_id": user_id}, - ) - await self.db_pool.runInteraction("delete_devices", _delete_devices_txn) for device_id in device_ids: self.device_id_exists_cache.invalidate((user_id, device_id)) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 9580a4078538..ef5d1ef01e48 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -1552,9 +1552,9 @@ def delete_event_auth(txn): DELETE FROM event_auth WHERE event_id IN ( SELECT event_id FROM events - LEFT JOIN state_events AS se USING (room_id, event_id) + LEFT JOIN state_events USING (room_id, event_id) WHERE ? <= stream_ordering AND stream_ordering < ? - AND se.state_key IS null + AND state_key IS null ) """ diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 3efdd0c920f6..d957e770dcd8 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -16,7 +16,6 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import attr -from typing_extensions import TypedDict from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json @@ -38,20 +37,6 @@ ] -class BasePushAction(TypedDict): - event_id: str - actions: List[Union[dict, str]] - - -class HttpPushAction(BasePushAction): - room_id: str - stream_ordering: int - - -class EmailPushAction(HttpPushAction): - received_ts: Optional[int] - - def _serialize_action(actions, is_highlight): """Custom serializer for actions. This allows us to "compress" common actions. @@ -236,7 +221,7 @@ async def get_unread_push_actions_for_user_in_range_for_http( min_stream_ordering: int, max_stream_ordering: int, limit: int = 20, - ) -> List[HttpPushAction]: + ) -> List[dict]: """Get a list of the most recent unread push actions for a given user, within the given stream ordering range. Called by the httppusher. @@ -341,7 +326,7 @@ async def get_unread_push_actions_for_user_in_range_for_email( min_stream_ordering: int, max_stream_ordering: int, limit: int = 20, - ) -> List[EmailPushAction]: + ) -> List[dict]: """Get a list of the most recent unread push actions for a given user, within the given stream ordering range. Called by the emailpusher diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 4e528612eab7..06832221adc2 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -15,7 +15,7 @@ # limitations under the License. 
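[Editor's note] The `BasePushAction`/`HttpPushAction`/`EmailPushAction` classes removed above show TypedDict inheritance: each subclass widens the set of required keys while remaining a plain dict at runtime. A compressed standalone sketch:

    from typing import List, Union
    from typing_extensions import TypedDict

    class BasePushAction(TypedDict):
        event_id: str
        actions: List[Union[dict, str]]

    class HttpPushAction(BasePushAction):   # inherits the event_id/actions keys
        room_id: str
        stream_ordering: int

    # Still an ordinary dict at runtime; the class only informs the type checker.
    action: HttpPushAction = {
        "event_id": "$ev",
        "actions": ["notify"],
        "room_id": "!room",
        "stream_ordering": 7,
    }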
import itertools import logging -from collections import OrderedDict +from collections import OrderedDict, namedtuple from typing import ( TYPE_CHECKING, Any, @@ -41,10 +41,9 @@ from synapse.logging.utils import log_function from synapse.storage._base import db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool, LoggingTransaction -from synapse.storage.databases.main.events_worker import EventCacheEntry from synapse.storage.databases.main.search import SearchEntry from synapse.storage.types import Connection -from synapse.storage.util.id_generators import AbstractStreamIdGenerator +from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.storage.util.sequence import SequenceGenerator from synapse.types import StateMap, get_domain_from_id from synapse.util import json_encoder @@ -65,6 +64,9 @@ ) +_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) + + @attr.s(slots=True) class DeltaState: """Deltas to use to update the `current_state_events` table. @@ -106,30 +108,23 @@ def __init__( self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages self.is_mine_id = hs.is_mine_id + # Ideally we'd move these ID gens here, unfortunately some other ID + # generators are chained off them so doing so is a bit of a PITA. + self._backfill_id_gen: MultiWriterIdGenerator = self.store._backfill_id_gen + self._stream_id_gen: MultiWriterIdGenerator = self.store._stream_id_gen + # This should only exist on instances that are configured to write assert ( hs.get_instance_name() in hs.config.worker.writers.events ), "Can only instantiate EventsStore on master" - # Since we have been configured to write, we ought to have id generators, - # rather than id trackers. - assert isinstance(self.store._backfill_id_gen, AbstractStreamIdGenerator) - assert isinstance(self.store._stream_id_gen, AbstractStreamIdGenerator) - - # Ideally we'd move these ID gens here, unfortunately some other ID - # generators are chained off them so doing so is a bit of a PITA. - self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen - self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen - async def _persist_events_and_state_updates( self, events_and_contexts: List[Tuple[EventBase, EventContext]], - *, current_state_for_room: Dict[str, StateMap[str]], state_delta_for_room: Dict[str, DeltaState], new_forward_extremeties: Dict[str, List[str]], - use_negative_stream_ordering: bool = False, - inhibit_local_membership_updates: bool = False, + backfilled: bool = False, ) -> None: """Persist a set of events alongside updates to the current state and forward extremities tables. @@ -142,14 +137,7 @@ async def _persist_events_and_state_updates( room state new_forward_extremities: Map from room_id to list of event IDs that are the new forward extremities of the room. - use_negative_stream_ordering: Whether to start stream_ordering on - the negative side and decrement. This should be set as True - for backfilled events because backfilled events get a negative - stream ordering so they don't come down incremental `/sync`. - inhibit_local_membership_updates: Stop the local_current_membership - from being updated by these events. This should be set to True - for backfilled events because backfilled events in the past do - not affect the current local state. 
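
The deleted docstring above spells out the two behaviours that the single `backfilled` flag re-conflates: backfilled events take negative stream orderings (so they never surface in an incremental `/sync` window) and must not move `local_current_membership`. A toy illustration of the ordering half, standing in for the real `MultiWriterIdGenerator`:

    class ToyStreamIdAllocator:
        """Counts up for live events and down for backfilled ones."""

        def __init__(self) -> None:
            self._forward = 0
            self._backfill = 0

        def get_next(self, backfilled: bool) -> int:
            if backfilled:
                self._backfill -= 1
                return self._backfill
            self._forward += 1
            return self._forward

    gen = ToyStreamIdAllocator()
    assert gen.get_next(backfilled=False) == 1
    assert gen.get_next(backfilled=True) == -1   # behind every /sync token
    assert gen.get_next(backfilled=False) == 2
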
+ backfilled Returns: Resolves when the events have been persisted @@ -171,7 +159,7 @@ async def _persist_events_and_state_updates( # # Note: Multiple instances of this function cannot be in flight at # the same time for the same room. - if use_negative_stream_ordering: + if backfilled: stream_ordering_manager = self._backfill_id_gen.get_next_mult( len(events_and_contexts) ) @@ -188,13 +176,13 @@ async def _persist_events_and_state_updates( "persist_events", self._persist_events_txn, events_and_contexts=events_and_contexts, - inhibit_local_membership_updates=inhibit_local_membership_updates, + backfilled=backfilled, state_delta_for_room=state_delta_for_room, new_forward_extremeties=new_forward_extremeties, ) persist_event_counter.inc(len(events_and_contexts)) - if stream < 0: + if not backfilled: # backfilled events have negative stream orderings, so we don't # want to set the event_persisted_position to that. synapse.metrics.event_persisted_position.set( @@ -328,9 +316,8 @@ def _get_prevs_before_rejected_txn(txn, batch): def _persist_events_txn( self, txn: LoggingTransaction, - *, events_and_contexts: List[Tuple[EventBase, EventContext]], - inhibit_local_membership_updates: bool = False, + backfilled: bool, state_delta_for_room: Optional[Dict[str, DeltaState]] = None, new_forward_extremeties: Optional[Dict[str, List[str]]] = None, ): @@ -343,10 +330,7 @@ def _persist_events_txn( Args: txn events_and_contexts: events to persist - inhibit_local_membership_updates: Stop the local_current_membership - from being updated by these events. This should be set to True - for backfilled events because backfilled events in the past do - not affect the current local state. + backfilled: True if the events were backfilled delete_existing True to purge existing table rows for the events from the database. This is useful when retrying due to IntegrityError. @@ -379,7 +363,9 @@ def _persist_events_txn( events_and_contexts ) - self._update_room_depths_txn(txn, events_and_contexts=events_and_contexts) + self._update_room_depths_txn( + txn, events_and_contexts=events_and_contexts, backfilled=backfilled + ) # _update_outliers_txn filters out any events which have already been # persisted, and returns the filtered list. @@ -412,7 +398,7 @@ def _persist_events_txn( txn, events_and_contexts=events_and_contexts, all_events_and_contexts=all_events_and_contexts, - inhibit_local_membership_updates=inhibit_local_membership_updates, + backfilled=backfilled, ) # We call this last as it assumes we've inserted the events into @@ -575,9 +561,9 @@ def _add_chain_cover_index( # fetch their auth event info. 
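
The SQL in the hunk below (and in several later hunks) drops the `AS se` alias on `state_events`, undoing the schema-66 disambiguation. What the alias guarded against, as a runnable sketch over a hypothetical schema in which both joined tables carry a `state_key` column:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.executescript(
        """
        CREATE TABLE events (event_id TEXT, state_key TEXT);
        CREATE TABLE state_events (event_id TEXT, state_key TEXT);
        """
    )

    try:
        # With state_key in both tables, the bare reference cannot compile.
        conn.execute(
            "SELECT event_id, state_key FROM events "
            "INNER JOIN state_events USING (event_id)"
        )
    except sqlite3.OperationalError as exc:
        print(exc)  # ambiguous column name: state_key

    # The aliased form that this revert removes names its source explicitly:
    conn.execute(
        "SELECT event_id, se.state_key FROM events "
        "INNER JOIN state_events AS se USING (event_id)"
    )
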
while missing_auth_chains: sql = """ - SELECT event_id, events.type, se.state_key, chain_id, sequence_number + SELECT event_id, events.type, state_key, chain_id, sequence_number FROM events - INNER JOIN state_events AS se USING (event_id) + INNER JOIN state_events USING (event_id) LEFT JOIN event_auth_chains USING (event_id) WHERE """ @@ -1214,6 +1200,7 @@ def _update_room_depths_txn( self, txn, events_and_contexts: List[Tuple[EventBase, EventContext]], + backfilled: bool, ): """Update min_depth for each room @@ -1221,18 +1208,13 @@ def _update_room_depths_txn( txn (twisted.enterprise.adbapi.Connection): db connection events_and_contexts (list[(EventBase, EventContext)]): events we are persisting + backfilled (bool): True if the events were backfilled """ depth_updates: Dict[str, int] = {} for event, context in events_and_contexts: # Remove the any existing cache entries for the event_ids txn.call_after(self.store._invalidate_get_event_cache, event.event_id) - # Then update the `stream_ordering` position to mark the latest - # event as the front of the room. This should not be done for - # backfilled events because backfilled events have negative - # stream_ordering and happened in the past so we know that we don't - # need to update the stream_ordering tip/front for the room. - assert event.internal_metadata.stream_ordering is not None - if event.internal_metadata.stream_ordering >= 0: + if not backfilled: txn.call_after( self.store._events_stream_cache.entity_has_changed, event.room_id, @@ -1445,12 +1427,7 @@ def _store_rejected_events_txn(self, txn, events_and_contexts): return [ec for ec in events_and_contexts if ec[0] not in to_remove] def _update_metadata_tables_txn( - self, - txn, - *, - events_and_contexts, - all_events_and_contexts, - inhibit_local_membership_updates: bool = False, + self, txn, events_and_contexts, all_events_and_contexts, backfilled ): """Update all the miscellaneous tables for new events @@ -1462,10 +1439,7 @@ def _update_metadata_tables_txn( events that we were going to persist. This includes events we've already persisted, etc, that wouldn't appear in events_and_context. - inhibit_local_membership_updates: Stop the local_current_membership - from being updated by these events. This should be set to True - for backfilled events because backfilled events in the past do - not affect the current local state. + backfilled (bool): True if the events were backfilled """ # Insert all the push actions into the event_push_actions table. @@ -1539,7 +1513,7 @@ def _update_metadata_tables_txn( for event, _ in events_and_contexts if event.type == EventTypes.Member ], - inhibit_local_membership_updates=inhibit_local_membership_updates, + backfilled=backfilled, ) # Insert event_reference_hashes table. 
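
The next hunk returns `_add_to_cache` to the namedtuple `_EventCacheEntry` reintroduced near the top of this file, which is why the prefill key reverts to the positional `cache_entry[0]`. The two shapes side by side, with plain strings standing in for `EventBase`:

    from collections import namedtuple
    from typing import Optional

    import attr

    # Restored form: a tuple, so both index and attribute access work and
    # cache_entry[0] is the event.
    _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))

    # Form being reverted away: attrs, attribute access only.
    @attr.s(slots=True, auto_attribs=True)
    class EventCacheEntry:
        event: str
        redacted_event: Optional[str]

    old = _EventCacheEntry(event="$ev:example.com", redacted_event=None)
    assert old[0] == old.event
    new = EventCacheEntry(event="$ev:example.com", redacted_event=None)
    assert new.event == old.event
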
@@ -1579,13 +1553,11 @@ def _add_to_cache(self, txn, events_and_contexts): for row in rows: event = ev_map[row["event_id"]] if not row["rejects"] and not row["redacts"]: - to_prefill.append(EventCacheEntry(event=event, redacted_event=None)) + to_prefill.append(_EventCacheEntry(event=event, redacted_event=None)) def prefill(): for cache_entry in to_prefill: - self.store._get_event_cache.set( - (cache_entry.event.event_id,), cache_entry - ) + self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry) txn.call_after(prefill) @@ -1666,19 +1638,8 @@ def _store_event_reference_hashes_txn(self, txn, events): txn, table="event_reference_hashes", values=vals ) - def _store_room_members_txn( - self, txn, events, *, inhibit_local_membership_updates: bool = False - ): - """ - Store a room member in the database. - Args: - txn: The transaction to use. - events: List of events to store. - inhibit_local_membership_updates: Stop the local_current_membership - from being updated by these events. This should be set to True - for backfilled events because backfilled events in the past do - not affect the current local state. - """ + def _store_room_members_txn(self, txn, events, backfilled): + """Store a room member in the database.""" def non_null_str_or_none(val: Any) -> Optional[str]: return val if isinstance(val, str) and "\u0000" not in val else None @@ -1721,7 +1682,7 @@ def non_null_str_or_none(val: Any) -> Optional[str]: # band membership", like a remote invite or a rejection of a remote invite. if ( self.is_mine_id(event.state_key) - and not inhibit_local_membership_updates + and not backfilled and event.internal_metadata.is_outlier() and event.internal_metadata.is_out_of_band_membership() ): diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index c7b660ac5a6f..c6bf316d5bf3 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -15,18 +15,14 @@ import logging import threading from typing import ( - TYPE_CHECKING, - Any, Collection, Container, Dict, Iterable, List, - NoReturn, Optional, Set, Tuple, - cast, overload, ) @@ -42,7 +38,6 @@ from synapse.api.room_versions import ( KNOWN_ROOM_VERSIONS, EventFormatVersions, - RoomVersion, RoomVersions, ) from synapse.events import EventBase, make_event_from_dict @@ -61,18 +56,10 @@ from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import ( - DatabasePool, - LoggingDatabaseConnection, - LoggingTransaction, -) +from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.engines import PostgresEngine -from synapse.storage.types import Cursor -from synapse.storage.util.id_generators import ( - AbstractStreamIdTracker, - MultiWriterIdGenerator, - StreamIdGenerator, -) +from synapse.storage.types import Connection +from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.storage.util.sequence import build_sequence_generator from synapse.types import JsonDict, get_domain_from_id from synapse.util import unwrapFirstError @@ -82,13 +69,10 @@ from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure -if TYPE_CHECKING: - from synapse.server import HomeServer - logger = logging.getLogger(__name__) -# These values are used in the `enqueue_event` 
and `_fetch_loop` methods to +# These values are used in the `enqueus_event` and `_do_fetch` methods to # control how we batch/bulk fetch events from the database. # The values are plucked out of thing air to make initial sync run faster # on jki.re @@ -105,7 +89,7 @@ @attr.s(slots=True, auto_attribs=True) -class EventCacheEntry: +class _EventCacheEntry: event: EventBase redacted_event: Optional[EventBase] @@ -145,7 +129,7 @@ class _EventRow: json: str internal_metadata: str format_version: Optional[int] - room_version_id: Optional[str] + room_version_id: Optional[int] rejected_reason: Optional[str] redactions: List[str] outlier: bool @@ -169,16 +153,9 @@ class EventsWorkerStore(SQLBaseStore): # options controlling this. USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True - def __init__( - self, - database: DatabasePool, - db_conn: LoggingDatabaseConnection, - hs: "HomeServer", - ): + def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) - self._stream_id_gen: AbstractStreamIdTracker - self._backfill_id_gen: AbstractStreamIdTracker if isinstance(database.engine, PostgresEngine): # If we're using Postgres than we can use `MultiWriterIdGenerator` # regardless of whether this process writes to the streams or not. @@ -237,7 +214,7 @@ def __init__( 5 * 60 * 1000, ) - self._get_event_cache: LruCache[Tuple[str], EventCacheEntry] = LruCache( + self._get_event_cache = LruCache( cache_name="*getEvent*", max_size=hs.config.caches.event_cache_size, ) @@ -246,21 +223,19 @@ def __init__( # ID to cache entry. Note that the returned dict may not have the # requested event in it if the event isn't in the DB. self._current_event_fetches: Dict[ - str, ObservableDeferred[Dict[str, EventCacheEntry]] + str, ObservableDeferred[Dict[str, _EventCacheEntry]] ] = {} self._event_fetch_lock = threading.Condition() - self._event_fetch_list: List[ - Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"] - ] = [] + self._event_fetch_list = [] self._event_fetch_ongoing = 0 event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) # We define this sequence here so that it can be referenced from both # the DataStore and PersistEventStore. - def get_chain_id_txn(txn: Cursor) -> int: + def get_chain_id_txn(txn): txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains") - return cast(Tuple[int], txn.fetchone())[0] + return txn.fetchone()[0] self.event_chain_id_gen = build_sequence_generator( db_conn, @@ -271,13 +246,7 @@ def get_chain_id_txn(txn: Cursor) -> int: id_column="chain_id", ) - def process_replication_rows( - self, - stream_name: str, - instance_name: str, - token: int, - rows: Iterable[Any], - ) -> None: + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == EventsStream.NAME: self._stream_id_gen.advance(instance_name, token) elif stream_name == BackfillStream.NAME: @@ -311,10 +280,10 @@ async def get_event( self, event_id: str, redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, - get_prev_content: bool = ..., - allow_rejected: bool = ..., - allow_none: Literal[False] = ..., - check_room_id: Optional[str] = ..., + get_prev_content: bool = False, + allow_rejected: bool = False, + allow_none: Literal[False] = False, + check_room_id: Optional[str] = None, ) -> EventBase: ... 
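
This overload hunk and the one below swap the `...` placeholder defaults back for concrete values. Both spellings type-check: in an `@overload` stub, `...` records only that the parameter has some default, leaving the actual value to the implementation. A freestanding sketch of the pattern around a hypothetical `lookup` function:

    from typing import Literal, Optional, overload

    @overload
    def lookup(key: str, allow_none: Literal[False] = ...) -> str:
        ...

    @overload
    def lookup(key: str, allow_none: Literal[True] = ...) -> Optional[str]:
        ...

    def lookup(key: str, allow_none: bool = False) -> Optional[str]:
        # Only this default is real; the overloads' `...` merely tell the
        # type checker the argument may be omitted.
        value = {"a": "1"}.get(key)
        if value is None and not allow_none:
            raise KeyError(key)
        return value

    assert lookup("a") == "1"
    assert lookup("b", allow_none=True) is None
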
@@ -323,10 +292,10 @@ async def get_event( self, event_id: str, redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, - get_prev_content: bool = ..., - allow_rejected: bool = ..., - allow_none: Literal[True] = ..., - check_room_id: Optional[str] = ..., + get_prev_content: bool = False, + allow_rejected: bool = False, + allow_none: Literal[True] = False, + check_room_id: Optional[str] = None, ) -> Optional[EventBase]: ... @@ -388,7 +357,7 @@ async def get_event( async def get_events( self, - event_ids: Collection[str], + event_ids: Iterable[str], redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, get_prev_content: bool = False, allow_rejected: bool = False, @@ -575,7 +544,7 @@ async def get_events_as_list( async def _get_events_from_cache_or_db( self, event_ids: Iterable[str], allow_rejected: bool = False - ) -> Dict[str, EventCacheEntry]: + ) -> Dict[str, _EventCacheEntry]: """Fetch a bunch of events from the cache or the database. If events are pulled from the database, they will be cached for future lookups. @@ -609,7 +578,7 @@ async def _get_events_from_cache_or_db( # same dict into itself N times). already_fetching_ids: Set[str] = set() already_fetching_deferreds: Set[ - ObservableDeferred[Dict[str, EventCacheEntry]] + ObservableDeferred[Dict[str, _EventCacheEntry]] ] = set() for event_id in missing_events_ids: @@ -632,8 +601,8 @@ async def _get_events_from_cache_or_db( # function returning more events than requested, but that can happen # already due to `_get_events_from_db`). fetching_deferred: ObservableDeferred[ - Dict[str, EventCacheEntry] - ] = ObservableDeferred(defer.Deferred(), consumeErrors=True) + Dict[str, _EventCacheEntry] + ] = ObservableDeferred(defer.Deferred()) for event_id in missing_events_ids: self._current_event_fetches[event_id] = fetching_deferred @@ -689,12 +658,12 @@ async def _get_events_from_cache_or_db( return event_entry_map - def _invalidate_get_event_cache(self, event_id: str) -> None: + def _invalidate_get_event_cache(self, event_id): self._get_event_cache.invalidate((event_id,)) def _get_events_from_cache( self, events: Iterable[str], update_metrics: bool = True - ) -> Dict[str, EventCacheEntry]: + ) -> Dict[str, _EventCacheEntry]: """Fetch events from the caches. May return rejected events. @@ -767,123 +736,38 @@ async def get_stripped_room_state_from_event_context( for e in state_to_include.values() ] - def _maybe_start_fetch_thread(self) -> None: - """Starts an event fetch thread if we are not yet at the maximum number.""" - with self._event_fetch_lock: - if ( - self._event_fetch_list - and self._event_fetch_ongoing < EVENT_QUEUE_THREADS - ): - self._event_fetch_ongoing += 1 - event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) - # `_event_fetch_ongoing` is decremented in `_fetch_thread`. - should_start = True - else: - should_start = False - - if should_start: - run_as_background_process("fetch_events", self._fetch_thread) - - async def _fetch_thread(self) -> None: - """Services requests for events from `_event_fetch_list`.""" - exc = None - try: - await self.db_pool.runWithConnection(self._fetch_loop) - except BaseException as e: - exc = e - raise - finally: - should_restart = False - event_fetches_to_fail = [] - with self._event_fetch_lock: - self._event_fetch_ongoing -= 1 - event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) - - # There may still be work remaining in `_event_fetch_list` if we - # failed, or it was added in between us deciding to exit and - # decrementing `_event_fetch_ongoing`. 
- if self._event_fetch_list: - if exc is None: - # We decided to exit, but then some more work was added - # before `_event_fetch_ongoing` was decremented. - # If a new event fetch thread was not started, we should - # restart ourselves since the remaining event fetch threads - # may take a while to get around to the new work. - # - # Unfortunately it is not possible to tell whether a new - # event fetch thread was started, so we restart - # unconditionally. If we are unlucky, we will end up with - # an idle fetch thread, but it will time out after - # `EVENT_QUEUE_ITERATIONS * EVENT_QUEUE_TIMEOUT_S` seconds - # in any case. - # - # Note that multiple fetch threads may run down this path at - # the same time. - should_restart = True - elif isinstance(exc, Exception): - if self._event_fetch_ongoing == 0: - # We were the last remaining fetcher and failed. - # Fail any outstanding fetches since no one else will - # handle them. - event_fetches_to_fail = self._event_fetch_list - self._event_fetch_list = [] - else: - # We weren't the last remaining fetcher, so another - # fetcher will pick up the work. This will either happen - # after their existing work, however long that takes, - # or after at most `EVENT_QUEUE_TIMEOUT_S` seconds if - # they are idle. - pass - else: - # The exception is a `SystemExit`, `KeyboardInterrupt` or - # `GeneratorExit`. Don't try to do anything clever here. - pass - - if should_restart: - # We exited cleanly but noticed more work. - self._maybe_start_fetch_thread() - - if event_fetches_to_fail: - # We were the last remaining fetcher and failed. - # Fail any outstanding fetches since no one else will handle them. - assert exc is not None - with PreserveLoggingContext(): - for _, deferred in event_fetches_to_fail: - deferred.errback(exc) - - def _fetch_loop(self, conn: LoggingDatabaseConnection) -> None: + def _do_fetch(self, conn: Connection) -> None: """Takes a database connection and waits for requests for events from the _event_fetch_list queue. """ - i = 0 - while True: - with self._event_fetch_lock: - event_list = self._event_fetch_list - self._event_fetch_list = [] - - if not event_list: - # There are no requests waiting. If we haven't yet reached the - # maximum iteration limit, wait for some more requests to turn up. - # Otherwise, bail out. 
- single_threaded = self.database_engine.single_threaded - if ( - not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING - or single_threaded - or i > EVENT_QUEUE_ITERATIONS - ): - return - - self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S) - i += 1 - continue - i = 0 + try: + i = 0 + while True: + with self._event_fetch_lock: + event_list = self._event_fetch_list + self._event_fetch_list = [] + + if not event_list: + single_threaded = self.database_engine.single_threaded + if ( + not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING + or single_threaded + or i > EVENT_QUEUE_ITERATIONS + ): + break + else: + self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S) + i += 1 + continue + i = 0 - self._fetch_event_list(conn, event_list) + self._fetch_event_list(conn, event_list) + finally: + self._event_fetch_ongoing -= 1 + event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) def _fetch_event_list( - self, - conn: LoggingDatabaseConnection, - event_list: List[Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]], + self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]] ) -> None: """Handle a load of requests from the _event_fetch_list queue @@ -910,7 +794,7 @@ def _fetch_event_list( ) # We only want to resolve deferreds from the main thread - def fire() -> None: + def fire(): for _, d in event_list: d.callback(row_dict) @@ -920,16 +804,18 @@ def fire() -> None: logger.exception("do_fetch") # We only want to resolve deferreds from the main thread - def fire_errback(exc: Exception) -> None: - for _, d in event_list: - d.errback(exc) + def fire(evs, exc): + for _, d in evs: + if not d.called: + with PreserveLoggingContext(): + d.errback(exc) with PreserveLoggingContext(): - self.hs.get_reactor().callFromThread(fire_errback, e) + self.hs.get_reactor().callFromThread(fire, event_list, e) async def _get_events_from_db( - self, event_ids: Collection[str] - ) -> Dict[str, EventCacheEntry]: + self, event_ids: Iterable[str] + ) -> Dict[str, _EventCacheEntry]: """Fetch a bunch of events from the database. May return rejected events. @@ -945,29 +831,29 @@ async def _get_events_from_db( map from event id to result. May return extra events which weren't asked for. """ - fetched_event_ids: Set[str] = set() - fetched_events: Dict[str, _EventRow] = {} + fetched_events = {} events_to_fetch = event_ids while events_to_fetch: row_map = await self._enqueue_events(events_to_fetch) # we need to recursively fetch any redactions of those events - redaction_ids: Set[str] = set() + redaction_ids = set() for event_id in events_to_fetch: row = row_map.get(event_id) - fetched_event_ids.add(event_id) + fetched_events[event_id] = row if row: - fetched_events[event_id] = row redaction_ids.update(row.redactions) - events_to_fetch = redaction_ids.difference(fetched_event_ids) + events_to_fetch = redaction_ids.difference(fetched_events.keys()) if events_to_fetch: logger.debug("Also fetching redaction events %s", events_to_fetch) # build a map from event_id to EventBase - event_map: Dict[str, EventBase] = {} + event_map = {} for event_id, row in fetched_events.items(): + if not row: + continue assert row.event_id == event_id rejected_reason = row.rejected_reason @@ -995,7 +881,6 @@ async def _get_events_from_db( room_version_id = row.room_version_id - room_version: Optional[RoomVersion] if not room_version_id: # this should only happen for out-of-band membership events which # arrived before #6983 landed. 
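
Taken together, the hunks above replace `_maybe_start_fetch_thread`/`_fetch_thread` with the older `_do_fetch`. Both variants share one shape: up to `EVENT_QUEUE_THREADS` database threads drain a shared request list under a `threading.Condition`, idling out after `EVENT_QUEUE_ITERATIONS` empty waits. A standalone miniature of that consumer loop (all names hypothetical):

    import threading
    from typing import Callable, List

    EVENT_QUEUE_TIMEOUT_S = 0.1
    EVENT_QUEUE_ITERATIONS = 3

    _work: List[Callable[[], None]] = []
    _cond = threading.Condition()

    def consumer_loop() -> None:
        """Drain batches of queued work; exit after a few idle wait cycles."""
        idle = 0
        while True:
            with _cond:
                batch = list(_work)
                _work.clear()
                if not batch:
                    if idle > EVENT_QUEUE_ITERATIONS:
                        return
                    _cond.wait(EVENT_QUEUE_TIMEOUT_S)
                    idle += 1
                    continue
            idle = 0
            for job in batch:
                job()  # in Synapse, _fetch_event_list on a DB connection

    def submit(job: Callable[[], None]) -> None:
        with _cond:
            _work.append(job)
            _cond.notify()

    submit(lambda: print("fetched"))
    consumer_loop()
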
For all other events, we should have @@ -1066,14 +951,14 @@ async def _get_events_from_db( # finally, we can decide whether each one needs redacting, and build # the cache entries. - result_map: Dict[str, EventCacheEntry] = {} + result_map = {} for event_id, original_ev in event_map.items(): redactions = fetched_events[event_id].redactions redacted_event = self._maybe_redact_event_row( original_ev, redactions, event_map ) - cache_entry = EventCacheEntry( + cache_entry = _EventCacheEntry( event=original_ev, redacted_event=redacted_event ) @@ -1082,7 +967,7 @@ async def _get_events_from_db( return result_map - async def _enqueue_events(self, events: Collection[str]) -> Dict[str, _EventRow]: + async def _enqueue_events(self, events: Iterable[str]) -> Dict[str, _EventRow]: """Fetches events from the database using the _event_fetch_list. This allows batch and bulk fetching of events - it allows us to fetch events without having to create a new transaction for each request for events. @@ -1095,12 +980,23 @@ async def _enqueue_events(self, events: Collection[str]) -> Dict[str, _EventRow] that weren't requested. """ - events_d: "defer.Deferred[Dict[str, _EventRow]]" = defer.Deferred() + events_d = defer.Deferred() with self._event_fetch_lock: self._event_fetch_list.append((events, events_d)) + self._event_fetch_lock.notify() - self._maybe_start_fetch_thread() + if self._event_fetch_ongoing < EVENT_QUEUE_THREADS: + self._event_fetch_ongoing += 1 + event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) + should_start = True + else: + should_start = False + + if should_start: + run_as_background_process( + "fetch_events", self.db_pool.runWithConnection, self._do_fetch + ) logger.debug("Loading %d events: %s", len(events), events) with PreserveLoggingContext(): @@ -1250,7 +1146,7 @@ def _maybe_redact_event_row( # no valid redaction found for this event return None - async def have_events_in_timeline(self, event_ids: Iterable[str]) -> Set[str]: + async def have_events_in_timeline(self, event_ids): """Given a list of event ids, check if we have already processed and stored them as non outliers. """ @@ -1279,7 +1175,7 @@ async def have_seen_events( event_ids: events we are looking for Returns: - The set of events we have already seen. + set[str]: The events we have already seen. """ res = await self._have_seen_events_dict( (room_id, event_id) for event_id in event_ids @@ -1302,9 +1198,7 @@ async def _have_seen_events_dict( } results = {x: True for x in cache_results} - def have_seen_events_txn( - txn: LoggingTransaction, chunk: Tuple[Tuple[str, str], ...] - ) -> None: + def have_seen_events_txn(txn, chunk: Tuple[Tuple[str, str], ...]): # we deliberately do *not* query the database for room_id, to make the # query an index-only lookup on `events_event_id_key`. # @@ -1330,14 +1224,12 @@ def have_seen_events_txn( return results @cached(max_entries=100000, tree=True) - async def have_seen_event(self, room_id: str, event_id: str) -> NoReturn: + async def have_seen_event(self, room_id: str, event_id: str): # this only exists for the benefit of the @cachedList descriptor on # _have_seen_events_dict raise NotImplementedError() - def _get_current_state_event_counts_txn( - self, txn: LoggingTransaction, room_id: str - ) -> int: + def _get_current_state_event_counts_txn(self, txn, room_id): """ See get_current_state_event_counts. 
""" @@ -1362,7 +1254,7 @@ async def get_current_state_event_counts(self, room_id: str) -> int: room_id, ) - async def get_room_complexity(self, room_id: str) -> Dict[str, float]: + async def get_room_complexity(self, room_id): """ Get a rough approximation of the complexity of the room. This is used by remote servers to decide whether they wish to join the room or not. @@ -1370,10 +1262,10 @@ async def get_room_complexity(self, room_id: str) -> Dict[str, float]: more resources. Args: - room_id: The room ID to query. + room_id (str) Returns: - dict[str:float] of complexity version to complexity. + dict[str:int] of complexity version to complexity. """ state_events = await self.get_current_state_event_counts(room_id) @@ -1383,13 +1275,13 @@ async def get_room_complexity(self, room_id: str) -> Dict[str, float]: return {"v1": complexity_v1} - def get_current_events_token(self) -> int: + def get_current_events_token(self): """The current maximum token that events have reached""" return self._stream_id_gen.get_current_token() async def get_all_new_forward_event_rows( self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: + ) -> List[Tuple]: """Returns new events, for the Events replication stream Args: @@ -1403,15 +1295,13 @@ async def get_all_new_forward_event_rows( EventsStreamRow. """ - def get_all_new_forward_event_rows( - txn: LoggingTransaction, - ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: + def get_all_new_forward_event_rows(txn): sql = ( "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," - " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" + " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" " FROM events AS e" " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events AS se USING (event_id)" + " LEFT JOIN state_events USING (event_id)" " LEFT JOIN event_relations USING (event_id)" " LEFT JOIN room_memberships USING (event_id)" " LEFT JOIN rejections USING (event_id)" @@ -1421,9 +1311,7 @@ def get_all_new_forward_event_rows( " LIMIT ?" ) txn.execute(sql, (last_id, current_id, instance_name, limit)) - return cast( - List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall() - ) + return txn.fetchall() return await self.db_pool.runInteraction( "get_all_new_forward_event_rows", get_all_new_forward_event_rows @@ -1431,7 +1319,7 @@ def get_all_new_forward_event_rows( async def get_ex_outlier_stream_rows( self, instance_name: str, last_id: int, current_id: int - ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: + ) -> List[Tuple]: """Returns de-outliered events, for the Events replication stream Args: @@ -1444,16 +1332,14 @@ async def get_ex_outlier_stream_rows( EventsStreamRow. 
""" - def get_ex_outlier_stream_rows_txn( - txn: LoggingTransaction, - ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: + def get_ex_outlier_stream_rows_txn(txn): sql = ( "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," - " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" + " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" " FROM events AS e" " INNER JOIN ex_outlier_stream AS out USING (event_id)" " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events AS se USING (event_id)" + " LEFT JOIN state_events USING (event_id)" " LEFT JOIN event_relations USING (event_id)" " LEFT JOIN room_memberships USING (event_id)" " LEFT JOIN rejections USING (event_id)" @@ -1464,9 +1350,7 @@ def get_ex_outlier_stream_rows_txn( ) txn.execute(sql, (last_id, current_id, instance_name)) - return cast( - List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall() - ) + return txn.fetchall() return await self.db_pool.runInteraction( "get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn @@ -1474,7 +1358,7 @@ def get_ex_outlier_stream_rows_txn( async def get_all_new_backfill_event_rows( self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]: + ) -> Tuple[List[Tuple[int, list]], int, bool]: """Get updates for backfill replication stream, including all new backfilled events and events that have gone from being outliers to not. @@ -1502,15 +1386,13 @@ async def get_all_new_backfill_event_rows( if last_id == current_id: return [], current_id, False - def get_all_new_backfill_event_rows( - txn: LoggingTransaction, - ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]: + def get_all_new_backfill_event_rows(txn): sql = ( "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," - " se.state_key, redacts, relates_to_id" + " state_key, redacts, relates_to_id" " FROM events AS e" " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events AS se USING (event_id)" + " LEFT JOIN state_events USING (event_id)" " LEFT JOIN event_relations USING (event_id)" " WHERE ? > stream_ordering AND stream_ordering >= ?" " AND instance_name = ?" @@ -1518,15 +1400,7 @@ def get_all_new_backfill_event_rows( " LIMIT ?" ) txn.execute(sql, (-last_id, -current_id, instance_name, limit)) - new_event_updates: List[ - Tuple[int, Tuple[str, str, str, str, str, str]] - ] = [] - row: Tuple[int, str, str, str, str, str, str] - # Type safety: iterating over `txn` yields `Tuple`, i.e. - # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a - # variadic tuple to a fixed length tuple and flags it up as an error. - for row in txn: # type: ignore[assignment] - new_event_updates.append((row[0], row[1:])) + new_event_updates = [(row[0], row[1:]) for row in txn] limited = False if len(new_event_updates) == limit: @@ -1537,11 +1411,11 @@ def get_all_new_backfill_event_rows( sql = ( "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," - " se.state_key, redacts, relates_to_id" + " state_key, redacts, relates_to_id" " FROM events AS e" " INNER JOIN ex_outlier_stream AS out USING (event_id)" " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events AS se USING (event_id)" + " LEFT JOIN state_events USING (event_id)" " LEFT JOIN event_relations USING (event_id)" " WHERE ? > event_stream_ordering" " AND event_stream_ordering >= ?" 
@@ -1549,11 +1423,7 @@ def get_all_new_backfill_event_rows( " ORDER BY event_stream_ordering DESC" ) txn.execute(sql, (-last_id, -upper_bound, instance_name)) - # Type safety: iterating over `txn` yields `Tuple`, i.e. - # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a - # variadic tuple to a fixed length tuple and flags it up as an error. - for row in txn: # type: ignore[assignment] - new_event_updates.append((row[0], row[1:])) + new_event_updates.extend((row[0], row[1:]) for row in txn) if len(new_event_updates) >= limit: upper_bound = new_event_updates[-1][0] @@ -1567,7 +1437,7 @@ def get_all_new_backfill_event_rows( async def get_all_updated_current_state_deltas( self, instance_name: str, from_token: int, to_token: int, target_row_count: int - ) -> Tuple[List[Tuple[int, str, str, str, str]], int, bool]: + ) -> Tuple[List[Tuple], int, bool]: """Fetch updates from current_state_delta_stream Args: @@ -1587,9 +1457,7 @@ async def get_all_updated_current_state_deltas( * `limited` is whether there are more updates to fetch. """ - def get_all_updated_current_state_deltas_txn( - txn: LoggingTransaction, - ) -> List[Tuple[int, str, str, str, str]]: + def get_all_updated_current_state_deltas_txn(txn): sql = """ SELECT stream_id, room_id, type, state_key, event_id FROM current_state_delta_stream @@ -1598,23 +1466,21 @@ def get_all_updated_current_state_deltas_txn( ORDER BY stream_id ASC LIMIT ? """ txn.execute(sql, (from_token, to_token, instance_name, target_row_count)) - return cast(List[Tuple[int, str, str, str, str]], txn.fetchall()) + return txn.fetchall() - def get_deltas_for_stream_id_txn( - txn: LoggingTransaction, stream_id: int - ) -> List[Tuple[int, str, str, str, str]]: + def get_deltas_for_stream_id_txn(txn, stream_id): sql = """ SELECT stream_id, room_id, type, state_key, event_id FROM current_state_delta_stream WHERE stream_id = ? """ txn.execute(sql, [stream_id]) - return cast(List[Tuple[int, str, str, str, str]], txn.fetchall()) + return txn.fetchall() # we need to make sure that, for every stream id in the results, we get *all* # the rows with that stream id. - rows: List[Tuple[int, str, str, str, str]] = await self.db_pool.runInteraction( + rows: List[Tuple] = await self.db_pool.runInteraction( "get_all_updated_current_state_deltas", get_all_updated_current_state_deltas_txn, ) @@ -1643,14 +1509,14 @@ def get_deltas_for_stream_id_txn( return rows, to_token, True - async def is_event_after(self, event_id1: str, event_id2: str) -> bool: + async def is_event_after(self, event_id1, event_id2): """Returns True if event_id1 is after event_id2 in the stream""" to_1, so_1 = await self.get_event_ordering(event_id1) to_2, so_2 = await self.get_event_ordering(event_id2) return (to_1, so_1) > (to_2, so_2) @cached(max_entries=5000) - async def get_event_ordering(self, event_id: str) -> Tuple[int, int]: + async def get_event_ordering(self, event_id): res = await self.db_pool.simple_select_one( table="events", retcols=["topological_ordering", "stream_ordering"], @@ -1673,9 +1539,7 @@ async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]: None otherwise. 
""" - def get_next_event_to_expire_txn( - txn: LoggingTransaction, - ) -> Optional[Tuple[str, int]]: + def get_next_event_to_expire_txn(txn): txn.execute( """ SELECT event_id, expiry_ts FROM event_expiry @@ -1683,7 +1547,7 @@ def get_next_event_to_expire_txn( """ ) - return cast(Optional[Tuple[str, int]], txn.fetchone()) + return txn.fetchone() return await self.db_pool.runInteraction( desc="get_next_event_to_expire", func=get_next_event_to_expire_txn @@ -1747,10 +1611,10 @@ async def get_already_persisted_events( return mapping @wrap_as_background_process("_cleanup_old_transaction_ids") - async def _cleanup_old_transaction_ids(self) -> None: + async def _cleanup_old_transaction_ids(self): """Cleans out transaction id mappings older than 24hrs.""" - def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None: + def _cleanup_old_transaction_ids_txn(txn): sql = """ DELETE FROM event_txn_id WHERE inserted_ts < ? @@ -1762,198 +1626,3 @@ def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None: "_cleanup_old_transaction_ids", _cleanup_old_transaction_ids_txn, ) - - async def is_event_next_to_backward_gap(self, event: EventBase) -> bool: - """Check if the given event is next to a backward gap of missing events. - A(False)--->B(False)--->C(True)---> - - Args: - room_id: room where the event lives - event_id: event to check - - Returns: - Boolean indicating whether it's an extremity - """ - - def is_event_next_to_backward_gap_txn(txn: LoggingTransaction) -> bool: - # If the event in question has any of its prev_events listed as a - # backward extremity, it's next to a gap. - # - # We can't just check the backward edges in `event_edges` because - # when we persist events, we will also record the prev_events as - # edges to the event in question regardless of whether we have those - # prev_events yet. We need to check whether those prev_events are - # backward extremities, also known as gaps, that need to be - # backfilled. - backward_extremity_query = """ - SELECT 1 FROM event_backward_extremities - WHERE - room_id = ? - AND %s - LIMIT 1 - """ - - # If the event in question is a backward extremity or has any of its - # prev_events listed as a backward extremity, it's next to a - # backward gap. - clause, args = make_in_list_sql_clause( - self.database_engine, - "event_id", - [event.event_id] + list(event.prev_event_ids()), - ) - - txn.execute(backward_extremity_query % (clause,), [event.room_id] + args) - backward_extremities = txn.fetchall() - - # We consider any backward extremity as a backward gap - if len(backward_extremities): - return True - - return False - - return await self.db_pool.runInteraction( - "is_event_next_to_backward_gap_txn", - is_event_next_to_backward_gap_txn, - ) - - async def is_event_next_to_forward_gap(self, event: EventBase) -> bool: - """Check if the given event is next to a forward gap of missing events. - The gap in front of the latest events is not considered a gap. - A(False)--->B(False)--->C(False)---> - A(False)--->B(False)---> --->D(True)--->E(False) - - Args: - room_id: room where the event lives - event_id: event to check - - Returns: - Boolean indicating whether it's an extremity - """ - - def is_event_next_to_gap_txn(txn: LoggingTransaction) -> bool: - # If the event in question is a forward extremity, we will just - # consider any potential forward gap as not a gap since it's one of - # the latest events in the room. 
- # - # `event_forward_extremities` does not include backfilled or outlier - # events so we can't rely on it to find forward gaps. We can only - # use it to determine whether a message is the latest in the room. - # - # We can't combine this query with the `forward_edge_query` below - # because if the event in question has no forward edges (isn't - # referenced by any other event's prev_events) but is in - # `event_forward_extremities`, we don't want to return 0 rows and - # say it's next to a gap. - forward_extremity_query = """ - SELECT 1 FROM event_forward_extremities - WHERE - room_id = ? - AND event_id = ? - LIMIT 1 - """ - - # Check to see whether the event in question is already referenced - # by another event. If we don't see any edges, we're next to a - # forward gap. - forward_edge_query = """ - SELECT 1 FROM event_edges - /* Check to make sure the event referencing our event in question is not rejected */ - LEFT JOIN rejections ON event_edges.event_id == rejections.event_id - WHERE - event_edges.room_id = ? - AND event_edges.prev_event_id = ? - /* It's not a valid edge if the event referencing our event in - * question is rejected. - */ - AND rejections.event_id IS NULL - LIMIT 1 - """ - - # We consider any forward extremity as the latest in the room and - # not a forward gap. - # - # To expand, even though there is technically a gap at the front of - # the room where the forward extremities are, we consider those the - # latest messages in the room so asking other homeservers for more - # is useless. The new latest messages will just be federated as - # usual. - txn.execute(forward_extremity_query, (event.room_id, event.event_id)) - forward_extremities = txn.fetchall() - if len(forward_extremities): - return False - - # If there are no forward edges to the event in question (another - # event hasn't referenced this event in their prev_events), then we - # assume there is a forward gap in the history. - txn.execute(forward_edge_query, (event.room_id, event.event_id)) - forward_edges = txn.fetchall() - if not len(forward_edges): - return True - - return False - - return await self.db_pool.runInteraction( - "is_event_next_to_gap_txn", - is_event_next_to_gap_txn, - ) - - async def get_event_id_for_timestamp( - self, room_id: str, timestamp: int, direction: str - ) -> Optional[str]: - """Find the closest event to the given timestamp in the given direction. - - Args: - room_id: Room to fetch the event from - timestamp: The point in time (inclusive) we should navigate from in - the given direction to find the closest event. - direction: ["f"|"b"] to indicate whether we should navigate forward - or backward from the given timestamp to find the closest event. - - Returns: - The closest event_id otherwise None if we can't find any event in - the given direction. - """ - - sql_template = """ - SELECT event_id FROM events - LEFT JOIN rejections USING (event_id) - WHERE - origin_server_ts %s ? - AND room_id = ? - /* Make sure event is not rejected */ - AND rejections.event_id IS NULL - ORDER BY origin_server_ts %s - LIMIT 1; - """ - - def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]: - if direction == "b": - # Find closest event *before* a given timestamp. We use descending - # (which gives values largest to smallest) because we want the - # largest possible timestamp *before* the given timestamp. - comparison_operator = "<=" - order = "DESC" - else: - # Find closest event *after* a given timestamp. 
We use ascending - # (which gives values smallest to largest) because we want the - # closest possible timestamp *after* the given timestamp. - comparison_operator = ">=" - order = "ASC" - - txn.execute( - sql_template % (comparison_operator, order), (timestamp, room_id) - ) - row = txn.fetchone() - if row: - (event_id,) = row - return event_id - - return None - - if direction not in ("f", "b"): - raise ValueError("Unknown direction: %s" % (direction,)) - - return await self.db_pool.runInteraction( - "get_event_id_for_timestamp_txn", - get_event_id_for_timestamp_txn, - ) diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index 91b0576b8568..3eb30944bf97 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -118,7 +118,7 @@ def _purge_history_txn( logger.info("[purge] looking for events to delete") - should_delete_expr = "state_events.state_key IS NULL" + should_delete_expr = "state_key IS NULL" should_delete_params: Tuple[Any, ...] = () if not delete_local_events: should_delete_expr += " AND event_id NOT LIKE ?" diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 3b63267395c3..fa782023d4ee 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -28,10 +28,7 @@ from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException -from synapse.storage.util.id_generators import ( - AbstractStreamIdTracker, - StreamIdGenerator, -) +from synapse.storage.util.id_generators import StreamIdGenerator from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -85,9 +82,9 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) if hs.config.worker.worker_app is None: - self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator( - db_conn, "push_rules_stream", "stream_id" - ) + self._push_rules_stream_id_gen: Union[ + StreamIdGenerator, SlavedIdTracker + ] = StreamIdGenerator(db_conn, "push_rules_stream", "stream_id") else: self._push_rules_stream_id_gen = SlavedIdTracker( db_conn, "push_rules_stream", "stream_id" diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index e1ddf0691646..0e8c16866760 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -106,15 +106,6 @@ class RefreshTokenLookupResult: has_next_access_token_been_used: bool """True if the next access token was already used at least once.""" - expiry_ts: Optional[int] - """The time at which the refresh token expires and can not be used. - If None, the refresh token doesn't expire.""" - - ultimate_session_expiry_ts: Optional[int] - """The time at which the session comes to an end and can no longer be - refreshed. 
- If None, the session can be refreshed indefinitely.""" - class RegistrationWorkerStore(CacheInvalidationWorkerStore): def __init__( @@ -1635,10 +1626,8 @@ def _lookup_refresh_token_txn(txn) -> Optional[RefreshTokenLookupResult]: rt.user_id, rt.device_id, rt.next_token_id, - (nrt.next_token_id IS NOT NULL) AS has_next_refresh_token_been_refreshed, - at.used AS has_next_access_token_been_used, - rt.expiry_ts, - rt.ultimate_session_expiry_ts + (nrt.next_token_id IS NOT NULL) has_next_refresh_token_been_refreshed, + at.used has_next_access_token_been_used FROM refresh_tokens rt LEFT JOIN refresh_tokens nrt ON rt.next_token_id = nrt.id LEFT JOIN access_tokens at ON at.refresh_token_id = nrt.id @@ -1659,8 +1648,6 @@ def _lookup_refresh_token_txn(txn) -> Optional[RefreshTokenLookupResult]: has_next_refresh_token_been_refreshed=row[4], # This column is nullable, ensure it's a boolean has_next_access_token_been_used=(row[5] or False), - expiry_ts=row[6], - ultimate_session_expiry_ts=row[7], ) return await self.db_pool.runInteraction( @@ -1928,8 +1915,6 @@ async def add_refresh_token_to_user( user_id: str, token: str, device_id: Optional[str], - expiry_ts: Optional[int], - ultimate_session_expiry_ts: Optional[int], ) -> int: """Adds a refresh token for the given user. @@ -1937,13 +1922,6 @@ async def add_refresh_token_to_user( user_id: The user ID. token: The new access token to add. device_id: ID of the device to associate with the refresh token. - expiry_ts (milliseconds since the epoch): Time after which the - refresh token cannot be used. - If None, the refresh token never expires until it has been used. - ultimate_session_expiry_ts (milliseconds since the epoch): - Time at which the session will end and can not be extended any - further. - If None, the session can be refreshed indefinitely. Raises: StoreError if there was a problem adding this. Returns: @@ -1959,8 +1937,6 @@ async def add_refresh_token_to_user( "device_id": device_id, "token": token, "next_token_id": None, - "expiry_ts": expiry_ts, - "ultimate_session_expiry_ts": ultimate_session_expiry_ts, }, desc="add_refresh_token_to_user", ) diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 6b2a8d06a67c..033a9831d664 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -476,7 +476,7 @@ def _get_rooms_for_user_with_stream_ordering_txn( INNER JOIN events AS e USING (room_id, event_id) WHERE c.type = 'm.room.member' - AND c.state_key = ? + AND state_key = ? AND c.membership = ? """ else: @@ -487,7 +487,7 @@ def _get_rooms_for_user_with_stream_ordering_txn( INNER JOIN events AS e USING (room_id, event_id) WHERE c.type = 'm.room.member' - AND c.state_key = ? + AND state_key = ? AND m.membership = ? """ diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 57aab5525937..42dc807d17ff 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -497,7 +497,7 @@ async def get_room_events_stream_for_room( oldest `limit` events. Returns: - The list of events (in ascending stream order) and the token from the start + The list of events (in ascending order) and the token from the start of the chunk of events returned. 
""" if from_key == to_key: @@ -510,7 +510,7 @@ async def get_room_events_stream_for_room( if not has_changed: return [], from_key - def f(txn: LoggingTransaction) -> List[_EventDictReturn]: + def f(txn): # To handle tokens with a non-empty instance_map we fetch more # results than necessary and then filter down min_from_id = from_key.stream @@ -565,13 +565,6 @@ def f(txn: LoggingTransaction) -> List[_EventDictReturn]: async def get_membership_changes_for_user( self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken ) -> List[EventBase]: - """Fetch membership events for a given user. - - All such events whose stream ordering `s` lies in the range - `from_key < s <= to_key` are returned. Events are ordered by ascending stream - order. - """ - # Start by ruling out cases where a DB query is not necessary. if from_key == to_key: return [] @@ -582,7 +575,7 @@ async def get_membership_changes_for_user( if not has_changed: return [] - def f(txn: LoggingTransaction) -> List[_EventDictReturn]: + def f(txn): # To handle tokens with a non-empty instance_map we fetch more # results than necessary and then filter down min_from_id = from_key.stream @@ -641,7 +634,7 @@ async def get_recent_events_for_room( Returns: A list of events and a token pointing to the start of the returned - events. The events returned are in ascending topological order. + events. The events returned are in ascending order. """ rows, token = await self.get_recent_event_ids_for_room( diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 162282255232..d7dc1f73ac16 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -14,7 +14,6 @@ import logging from collections import namedtuple -from enum import Enum from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple import attr @@ -45,16 +44,6 @@ ) -class DestinationSortOrder(Enum): - """Enum to define the sorting method used when returning destinations.""" - - DESTINATION = "destination" - RETRY_LAST_TS = "retry_last_ts" - RETTRY_INTERVAL = "retry_interval" - FAILURE_TS = "failure_ts" - LAST_SUCCESSFUL_STREAM_ORDERING = "last_successful_stream_ordering" - - @attr.s(slots=True, frozen=True, auto_attribs=True) class DestinationRetryTimings: """The current destination retry timing info for a remote server.""" @@ -491,62 +480,3 @@ def _get_catch_up_outstanding_destinations_txn( destinations = [row[0] for row in txn] return destinations - - async def get_destinations_paginate( - self, - start: int, - limit: int, - destination: Optional[str] = None, - order_by: str = DestinationSortOrder.DESTINATION.value, - direction: str = "f", - ) -> Tuple[List[JsonDict], int]: - """Function to retrieve a paginated list of destinations. - This will return a json list of destinations and the - total number of destinations matching the filter criteria. - - Args: - start: start number to begin the query from - limit: number of rows to retrieve - destination: search string in destination - order_by: the sort order of the returned list - direction: sort ascending or descending - Returns: - A tuple of a list of mappings from destination to information - and a count of total destinations. 
- """ - - def get_destinations_paginate_txn( - txn: LoggingTransaction, - ) -> Tuple[List[JsonDict], int]: - order_by_column = DestinationSortOrder(order_by).value - - if direction == "b": - order = "DESC" - else: - order = "ASC" - - args = [] - where_statement = "" - if destination: - args.extend(["%" + destination.lower() + "%"]) - where_statement = "WHERE LOWER(destination) LIKE ?" - - sql_base = f"FROM destinations {where_statement} " - sql = f"SELECT COUNT(*) as total_destinations {sql_base}" - txn.execute(sql, args) - count = txn.fetchone()[0] - - sql = f""" - SELECT destination, retry_last_ts, retry_interval, failure_ts, - last_successful_stream_ordering - {sql_base} - ORDER BY {order_by_column} {order}, destination ASC - LIMIT ? OFFSET ? - """ - txn.execute(sql, args + [limit, start]) - destinations = self.db_pool.cursor_to_dict(txn) - return destinations, count - - return await self.db_pool.runInteraction( - "get_destinations_paginate_txn", get_destinations_paginate_txn - ) diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 428d66a617b1..402f134d894b 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -583,8 +583,7 @@ async def _persist_event_batch( current_state_for_room=current_state_for_room, state_delta_for_room=state_delta_for_room, new_forward_extremeties=new_forward_extremeties, - use_negative_stream_ordering=backfilled, - inhibit_local_membership_updates=backfilled, + backfilled=backfilled, ) await self._handle_potentially_left_users(potentially_left_users) diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 50d08094d52c..3a00ed683545 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -SCHEMA_VERSION = 66 # remember to update the list below when updating +SCHEMA_VERSION = 65 # remember to update the list below when updating """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the @@ -46,10 +46,6 @@ - MSC2716: Remove unique event_id constraint from insertion_event_edges because an insertion event can have multiple edges. - Remove unused tables `user_stats_historical` and `room_stats_historical`. - -Changes in SCHEMA_VERSION = 66: - - Queries on state_key columns are now disambiguated (ie, the codebase can handle - the `events` table having a `state_key` column). """ diff --git a/synapse/storage/schema/main/delta/65/10_expirable_refresh_tokens.sql b/synapse/storage/schema/main/delta/65/10_expirable_refresh_tokens.sql deleted file mode 100644 index bdc491c8174b..000000000000 --- a/synapse/storage/schema/main/delta/65/10_expirable_refresh_tokens.sql +++ /dev/null @@ -1,28 +0,0 @@ -/* Copyright 2021 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - - -ALTER TABLE refresh_tokens - -- We add an expiry_ts column (in milliseconds since the Epoch) to refresh tokens. - -- They may not be used after they have expired. - -- If null, then the refresh token's lifetime is unlimited. - ADD COLUMN expiry_ts BIGINT DEFAULT NULL; - -ALTER TABLE refresh_tokens - -- We also add an ultimate session expiry time (in milliseconds since the Epoch). - -- No matter how much the access and refresh tokens are refreshed, they cannot - -- be extended past this time. - -- If null, then the session length is unlimited. - ADD COLUMN ultimate_session_expiry_ts BIGINT DEFAULT NULL; diff --git a/synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql b/synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql deleted file mode 100644 index a65bfb520d88..000000000000 --- a/synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql +++ /dev/null @@ -1,27 +0,0 @@ -/* Copyright 2021 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Track the auth provider used by each login as well as the session ID -CREATE TABLE device_auth_providers ( - user_id TEXT NOT NULL, - device_id TEXT NOT NULL, - auth_provider_id TEXT NOT NULL, - auth_provider_session_id TEXT NOT NULL -); - -CREATE INDEX device_auth_providers_devices - ON device_auth_providers (user_id, device_id); -CREATE INDEX device_auth_providers_sessions - ON device_auth_providers (auth_provider_id, auth_provider_session_id); diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 4ff3013908a7..ac56bc9a050f 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -89,77 +89,31 @@ def _load_current_id( return (max if step > 0 else min)(current_id, step) -class AbstractStreamIdTracker(metaclass=abc.ABCMeta): - """Tracks the "current" stream ID of a stream that may have multiple writers. - - Stream IDs are monotonically increasing or decreasing integers representing write - transactions. The "current" stream ID is the stream ID such that all transactions - with equal or smaller stream IDs have completed. Since transactions may complete out - of order, this is not the same as the stream ID of the last completed transaction. - - Completed transactions include both committed transactions and transactions that - have been rolled back. - """ - - @abc.abstractmethod - def advance(self, instance_name: str, new_id: int) -> None: - """Advance the position of the named writer to the given ID, if greater - than existing entry. - """ - raise NotImplementedError() - +class AbstractStreamIdGenerator(metaclass=abc.ABCMeta): @abc.abstractmethod - def get_current_token(self) -> int: - """Returns the maximum stream id such that all stream ids less than or - equal to it have been successfully persisted. - - Returns: - The maximum stream id. 
- """ + def get_next(self) -> AsyncContextManager[int]: raise NotImplementedError() @abc.abstractmethod - def get_current_token_for_writer(self, instance_name: str) -> int: - """Returns the position of the given writer. - - For streams with single writers this is equivalent to `get_current_token`. - """ + def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]: raise NotImplementedError() - -class AbstractStreamIdGenerator(AbstractStreamIdTracker): - """Generates stream IDs for a stream that may have multiple writers. - - Each stream ID represents a write transaction, whose completion is tracked - so that the "current" stream ID of the stream can be determined. - - See `AbstractStreamIdTracker` for more details. - """ - @abc.abstractmethod - def get_next(self) -> AsyncContextManager[int]: - """ - Usage: - async with stream_id_gen.get_next() as stream_id: - # ... persist event ... - """ + def get_current_token(self) -> int: raise NotImplementedError() @abc.abstractmethod - def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]: - """ - Usage: - async with stream_id_gen.get_next(n) as stream_ids: - # ... persist events ... - """ + def get_current_token_for_writer(self, instance_name: str) -> int: raise NotImplementedError() class StreamIdGenerator(AbstractStreamIdGenerator): - """Generates and tracks stream IDs for a stream with a single writer. + """Used to generate new stream ids when persisting events while keeping + track of which transactions have been completed. - This class must only be used when the current Synapse process is the sole - writer for a stream. + This allows us to get the "current" stream id, i.e. the stream id such that + all ids less than or equal to it have completed. This handles the fact that + persistence of events can complete out of order. Args: db_conn(connection): A database connection to use to fetch the @@ -203,12 +157,12 @@ def __init__( # The key and values are the same, but we never look at the values. self._unfinished_ids: OrderedDict[int, int] = OrderedDict() - def advance(self, instance_name: str, new_id: int) -> None: - # `StreamIdGenerator` should only be used when there is a single writer, - # so replication should never happen. - raise Exception("Replication is not supported by StreamIdGenerator") - def get_next(self) -> AsyncContextManager[int]: + """ + Usage: + async with stream_id_gen.get_next() as stream_id: + # ... persist event ... + """ with self._lock: self._current += self._step next_id = self._current @@ -226,6 +180,11 @@ def manager() -> Generator[int, None, None]: return _AsyncCtxManagerWrapper(manager()) def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]: + """ + Usage: + async with stream_id_gen.get_next(n) as stream_ids: + # ... persist events ... + """ with self._lock: next_ids = range( self._current + self._step, @@ -249,6 +208,12 @@ def manager() -> Generator[Sequence[int], None, None]: return _AsyncCtxManagerWrapper(manager()) def get_current_token(self) -> int: + """Returns the maximum stream id such that all stream ids less than or + equal to it have been successfully persisted. + + Returns: + The maximum stream id. + """ with self._lock: if self._unfinished_ids: return next(iter(self._unfinished_ids)) - self._step @@ -256,11 +221,16 @@ def get_current_token(self) -> int: return self._current def get_current_token_for_writer(self, instance_name: str) -> int: + """Returns the position of the given writer. 
+ + For streams with single writers this is equivalent to + `get_current_token`. + """ return self.get_current_token() class MultiWriterIdGenerator(AbstractStreamIdGenerator): - """Generates and tracks stream IDs for a stream with multiple writers. + """An ID generator that tracks a stream that can have multiple writers. Uses a Postgres sequence to coordinate ID assignment, but positions of other writers will only get updated when `advance` is called (by replication). @@ -505,6 +475,12 @@ def _load_next_mult_id_txn(self, txn: Cursor, n: int) -> List[int]: return stream_ids def get_next(self) -> AsyncContextManager[int]: + """ + Usage: + async with stream_id_gen.get_next() as stream_id: + # ... persist event ... + """ + # If we have a list of instances that are allowed to write to this # stream, make sure we're in it. if self._writers and self._instance_name not in self._writers: @@ -516,6 +492,12 @@ def get_next(self) -> AsyncContextManager[int]: return cast(AsyncContextManager[int], _MultiWriterCtxManager(self)) def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]: + """ + Usage: + async with stream_id_gen.get_next_mult(5) as stream_ids: + # ... persist events ... + """ + # If we have a list of instances that are allowed to write to this # stream, make sure we're in it. if self._writers and self._instance_name not in self._writers: @@ -615,9 +597,15 @@ def _mark_id_as_finished(self, next_id: int) -> None: self._add_persisted_position(next_id) def get_current_token(self) -> int: + """Returns the maximum stream id such that all stream ids less than or + equal to it have been successfully persisted. + """ + return self.get_persisted_upto_position() def get_current_token_for_writer(self, instance_name: str) -> int: + """Returns the position of the given writer.""" + # If we don't have an entry for the given instance name, we assume it's a # new writer. # @@ -643,6 +631,10 @@ def get_positions(self) -> Dict[str, int]: } def advance(self, instance_name: str, new_id: int) -> None: + """Advance the position of the named writer to the given ID, if greater + than existing entry. 
+ """ + new_id *= self._return_factor with self._lock: diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index 377c9a282a69..3c4cc093aff3 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -22,7 +22,6 @@ Iterable, MutableMapping, Optional, - Sized, TypeVar, Union, cast, @@ -105,13 +104,7 @@ def metrics_cb() -> None: max_size=max_entries, cache_name=name, cache_type=cache_type, - size_callback=( - (lambda d: len(cast(Sized, d)) or 1) - # Argument 1 to "len" has incompatible type "VT"; expected "Sized" - # We trust that `VT` is `Sized` when `iterable` is `True` - if iterable - else None - ), + size_callback=(lambda d: len(d) or 1) if iterable else None, metrics_collection_callback=metrics_cb, apply_cache_factor_from_config=apply_cache_factor_from_config, prune_unread_entries=prune_unread_entries, diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index eb96f7e665e6..a0a7a9de3299 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -15,15 +15,14 @@ import logging import threading import weakref -from enum import Enum from functools import wraps from typing import ( TYPE_CHECKING, Any, Callable, Collection, - Dict, Generic, + Iterable, List, Optional, Type, @@ -191,7 +190,7 @@ def __init__( root: "ListNode[_Node]", key: KT, value: VT, - cache: "weakref.ReferenceType[LruCache[KT, VT]]", + cache: "weakref.ReferenceType[LruCache]", clock: Clock, callbacks: Collection[Callable[[], None]] = (), prune_unread_entries: bool = True, @@ -271,10 +270,7 @@ def drop_from_cache(self) -> None: removed from all lists. """ cache = self._cache() - if ( - cache is None - or cache.pop(self.key, _Sentinel.sentinel) is _Sentinel.sentinel - ): + if not cache or not cache.pop(self.key, None): # `cache.pop` should call `drop_from_lists()`, unless this Node had # already been removed from the cache. self.drop_from_lists() @@ -294,12 +290,6 @@ def move_to_front(self, clock: Clock, cache_list_root: ListNode) -> None: self._global_list_node.update_last_access(clock) -class _Sentinel(Enum): - # defining a sentinel in this way allows mypy to correctly handle the - # type of a dictionary lookup. - sentinel = object() - - class LruCache(Generic[KT, VT]): """ Least-recently-used cache, supporting prometheus metrics and invalidation callbacks. @@ -312,7 +302,7 @@ def __init__( max_size: int, cache_name: Optional[str] = None, cache_type: Type[Union[dict, TreeCache]] = dict, - size_callback: Optional[Callable[[VT], int]] = None, + size_callback: Optional[Callable] = None, metrics_collection_callback: Optional[Callable[[], None]] = None, apply_cache_factor_from_config: bool = True, clock: Optional[Clock] = None, @@ -349,7 +339,7 @@ def __init__( else: real_clock = clock - cache: Union[Dict[KT, _Node[KT, VT]], TreeCache] = cache_type() + cache = cache_type() self.cache = cache # Used for introspection. self.apply_cache_factor_from_config = apply_cache_factor_from_config @@ -384,7 +374,7 @@ def __init__( # creating more each time we create a `_Node`. 
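The comment above is the heart of the design: nodes reference their owning cache only weakly. A minimal standalone sketch of the pattern (demo names, not Synapse code):

import weakref
from typing import Dict

class DemoCache:
    def __init__(self) -> None:
        self.data: Dict[str, "DemoNode"] = {}

class DemoNode:
    def __init__(self, cache: DemoCache, key: str) -> None:
        # Weak reference: the node never keeps a dropped cache alive, and a
        # single ref object can be shared by every node of the same cache.
        self._cache = weakref.ref(cache)
        self.key = key

    def drop_from_cache(self) -> None:
        cache = self._cache()  # None once the cache has been garbage collected
        if cache is not None:
            cache.data.pop(self.key, None)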
weak_ref_to_self = weakref.ref(self) - list_root = ListNode[_Node[KT, VT]].create_root_node() + list_root = ListNode[_Node].create_root_node() lock = threading.Lock() @@ -432,7 +422,7 @@ def cache_len() -> int: def add_node( key: KT, value: VT, callbacks: Collection[Callable[[], None]] = () ) -> None: - node: _Node[KT, VT] = _Node( + node = _Node( list_root, key, value, @@ -449,10 +439,10 @@ def add_node( if caches.TRACK_MEMORY_USAGE and metrics: metrics.inc_memory_usage(node.memory) - def move_node_to_front(node: _Node[KT, VT]) -> None: + def move_node_to_front(node: _Node) -> None: node.move_to_front(real_clock, list_root) - def delete_node(node: _Node[KT, VT]) -> int: + def delete_node(node: _Node) -> int: node.drop_from_lists() deleted_len = 1 @@ -506,7 +496,7 @@ def cache_get( @synchronized def cache_set( - key: KT, value: VT, callbacks: Collection[Callable[[], None]] = () + key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = () ) -> None: node = cache.get(key, None) if node is not None: @@ -600,6 +590,8 @@ def cache_clear() -> None: def cache_contains(key: KT) -> bool: return key in cache + self.sentinel = object() + # make sure that we clear out any excess entries after we get resized. self._on_resize = evict @@ -616,18 +608,18 @@ def cache_contains(key: KT) -> bool: self.clear = cache_clear def __getitem__(self, key: KT) -> VT: - result = self.get(key, _Sentinel.sentinel) - if result is _Sentinel.sentinel: + result = self.get(key, self.sentinel) + if result is self.sentinel: raise KeyError() else: - return result + return cast(VT, result) def __setitem__(self, key: KT, value: VT) -> None: self.set(key, value) def __delitem__(self, key: KT, value: VT) -> None: - result = self.pop(key, _Sentinel.sentinel) - if result is _Sentinel.sentinel: + result = self.pop(key, self.sentinel) + if result is self.sentinel: raise KeyError() def __len__(self) -> int: diff --git a/synapse/util/linked_list.py b/synapse/util/linked_list.py index 8efbf061aaae..9f4be757baa5 100644 --- a/synapse/util/linked_list.py +++ b/synapse/util/linked_list.py @@ -84,7 +84,7 @@ def remove_from_list(self) -> None: # immediately rather than at the next GC. self.cache_entry = None - def move_after(self, node: "ListNode[P]") -> None: + def move_after(self, node: "ListNode") -> None: """Move this node from its current location in the list to after the given node. """ @@ -122,7 +122,7 @@ def _refs_remove_node_from_list(self) -> None: self.prev_node = None self.next_node = None - def _refs_insert_after(self, node: "ListNode[P]") -> None: + def _refs_insert_after(self, node: "ListNode") -> None: """Internal method to insert the node after the given node.""" # This method should only be called when we're not already in the list. diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py index c144ff62c1fa..899ee0adc803 100644 --- a/synapse/util/versionstring.py +++ b/synapse/util/versionstring.py @@ -1,5 +1,4 @@ # Copyright 2016 OpenMarket Ltd -# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -30,11 +29,10 @@ def get_version_string(module: ModuleType) -> str: If called on a module not in a git checkout will return `__version__`. Args: - module: The module to check the version of. Must declare a __version__ - attribute. + module (module) Returns: - The module version (as a string). 
+ str """ cached_version = version_cache.get(module) @@ -46,37 +44,71 @@ def get_version_string(module: ModuleType) -> str: version_string = module.__version__ # type: ignore[attr-defined] try: + null = open(os.devnull, "w") cwd = os.path.dirname(os.path.abspath(module.__file__)) - def _run_git_command(prefix: str, *params: str) -> str: - try: - result = ( - subprocess.check_output( - ["git", *params], stderr=subprocess.DEVNULL, cwd=cwd - ) - .strip() - .decode("ascii") + try: + git_branch = ( + subprocess.check_output( + ["git", "rev-parse", "--abbrev-ref", "HEAD"], stderr=null, cwd=cwd ) - return prefix + result - except (subprocess.CalledProcessError, FileNotFoundError): - return "" - - git_branch = _run_git_command("b=", "rev-parse", "--abbrev-ref", "HEAD") - git_tag = _run_git_command("t=", "describe", "--exact-match") - git_commit = _run_git_command("", "rev-parse", "--short", "HEAD") + .strip() + .decode("ascii") + ) + git_branch = "b=" + git_branch + except (subprocess.CalledProcessError, FileNotFoundError): + # FileNotFoundError can arise when git is not installed + git_branch = "" + + try: + git_tag = ( + subprocess.check_output( + ["git", "describe", "--exact-match"], stderr=null, cwd=cwd + ) + .strip() + .decode("ascii") + ) + git_tag = "t=" + git_tag + except (subprocess.CalledProcessError, FileNotFoundError): + git_tag = "" + + try: + git_commit = ( + subprocess.check_output( + ["git", "rev-parse", "--short", "HEAD"], stderr=null, cwd=cwd + ) + .strip() + .decode("ascii") + ) + except (subprocess.CalledProcessError, FileNotFoundError): + git_commit = "" + + try: + dirty_string = "-this_is_a_dirty_checkout" + is_dirty = ( + subprocess.check_output( + ["git", "describe", "--dirty=" + dirty_string], stderr=null, cwd=cwd + ) + .strip() + .decode("ascii") + .endswith(dirty_string) + ) - dirty_string = "-this_is_a_dirty_checkout" - is_dirty = _run_git_command("", "describe", "--dirty=" + dirty_string).endswith( - dirty_string - ) - git_dirty = "dirty" if is_dirty else "" + git_dirty = "dirty" if is_dirty else "" + except (subprocess.CalledProcessError, FileNotFoundError): + git_dirty = "" if git_branch or git_tag or git_commit or git_dirty: git_version = ",".join( s for s in (git_branch, git_tag, git_commit, git_dirty) if s ) - version_string = f"{version_string} ({git_version})" + version_string = "%s (%s)" % ( + # If the __version__ attribute doesn't exist, we'll have failed + # loudly above. + module.__version__, # type: ignore[attr-defined] + git_version, + ) except Exception as e: logger.info("Failed to check for git repository: %s", e) diff --git a/synctl b/synctl index 0e54f4847bbd..90559ded62e6 100755 --- a/synctl +++ b/synctl @@ -24,7 +24,7 @@ import signal import subprocess import sys import time -from typing import Iterable, Optional +from typing import Iterable import yaml @@ -41,24 +41,11 @@ NORMAL = "\x1b[m" def pid_running(pid): try: os.kill(pid, 0) + return True except OSError as err: if err.errno == errno.EPERM: - pass # process exists - else: - return False - - # When running in a container, orphan processes may not get reaped and their - # PIDs may remain valid. Try to work around the issue. - try: - with open(f"/proc/{pid}/status") as status_file: - if "zombie" in status_file.read(): - return False - except Exception: - # This isn't Linux or `/proc/` is unavailable. - # Assume that the process is still running. 
-        pass
-
-    return True
+            return True
+        return False


 def write(message, colour=NORMAL, stream=sys.stdout):
@@ -122,14 +109,15 @@ def start(pidfile: str, app: str, config_files: Iterable[str], daemonize: bool)
         return False


-def stop(pidfile: str, app: str) -> Optional[int]:
+def stop(pidfile: str, app: str) -> bool:
     """Attempts to kill a synapse worker from the pidfile.
     Args:
         pidfile: path to file containing worker's pid
         app: name of the worker's appservice
     Returns:
-        process id, or None if the process was not running
+        True if the process stopped successfully
+        False if process was already stopped or an error occurred
     """

     if os.path.exists(pidfile):
@@ -137,7 +125,7 @@ def stop(pidfile: str, app: str) -> bool:
         try:
             os.kill(pid, signal.SIGTERM)
             write("stopped %s" % (app,), colour=GREEN)
-            return pid
+            return True
         except OSError as err:
             if err.errno == errno.ESRCH:
                 write("%s not running" % (app,), colour=YELLOW)
@@ -145,13 +133,14 @@ def stop(pidfile: str, app: str) -> bool:
                 abort("Cannot stop %s: Operation not permitted" % (app,))
             else:
                 abort("Cannot stop %s: Unknown error" % (app,))
+        return False
     else:
         write(
             "No running worker of %s found (from %s)\nThe process might be managed by another controller (e.g. systemd)"
             % (app, pidfile),
             colour=YELLOW,
         )
-    return None
+    return False


 Worker = collections.namedtuple(
@@ -299,23 +288,32 @@ def main():
     action = options.action

     if action == "stop" or action == "restart":
-        running_pids = []
+        has_stopped = True
         for worker in workers:
-            pid = stop(worker.pidfile, worker.app)
-            if pid is not None:
-                running_pids.append(pid)
+            if not stop(worker.pidfile, worker.app):
+                # A worker could not be stopped.
+                has_stopped = False

         if start_stop_synapse:
-            pid = stop(pidfile, MAIN_PROCESS)
-            if pid is not None:
-                running_pids.append(pid)
+            if not stop(pidfile, MAIN_PROCESS):
+                has_stopped = False
+        if not has_stopped and action == "stop":
+            sys.exit(1)

+    # Wait for synapse to actually shutdown before starting it again
+    if action == "restart":
+        running_pids = []
+        if start_stop_synapse and os.path.exists(pidfile):
+            running_pids.append(int(open(pidfile).read()))
+        for worker in workers:
+            if os.path.exists(worker.pidfile):
+                running_pids.append(int(open(worker.pidfile).read()))
         if len(running_pids) > 0:
-            write("Waiting for processes to exit...")
+            write("Waiting for process to exit before restarting...")
             for running_pid in running_pids:
                 while pid_running(running_pid):
                     time.sleep(0.2)
-            write("All processes exited")
+            write("All processes exited; now restarting...")

     if action == "start" or action == "restart":
         error = False
diff --git a/tests/app/test_homeserver_start.py b/tests/app/test_homeserver_start.py
deleted file mode 100644
index cbcada04517e..000000000000
--- a/tests/app/test_homeserver_start.py
+++ /dev/null
@@ -1,31 +0,0 @@
-# Copyright 2021 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
- -import synapse.app.homeserver -from synapse.config._base import ConfigError - -from tests.config.utils import ConfigFileTestCase - - -class HomeserverAppStartTestCase(ConfigFileTestCase): - def test_wrong_start_caught(self): - # Generate a config with a worker_app - self.generate_config() - # Add a blank line as otherwise the next addition ends up on a line with a comment - self.add_lines_to_config([" "]) - self.add_lines_to_config(["worker_app: test_worker_app"]) - - # Ensure that starting master process with worker config raises an exception - with self.assertRaises(ConfigError): - synapse.app.homeserver.setup(["-c", self.config_file]) diff --git a/tests/config/test_registration_config.py b/tests/config/test_registration_config.py deleted file mode 100644 index 17a84d20d811..000000000000 --- a/tests/config/test_registration_config.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from synapse.config import ConfigError -from synapse.config.homeserver import HomeServerConfig - -from tests.unittest import TestCase -from tests.utils import default_config - - -class RegistrationConfigTestCase(TestCase): - def test_session_lifetime_must_not_be_exceeded_by_smaller_lifetimes(self): - """ - session_lifetime should logically be larger than, or at least as large as, - all the different token lifetimes. - Test that the user is faced with configuration errors if they make it - smaller, as that configuration doesn't make sense. 
- """ - config_dict = default_config("test") - - # First test all the error conditions - with self.assertRaises(ConfigError): - HomeServerConfig().parse_config_dict( - { - "session_lifetime": "30m", - "nonrefreshable_access_token_lifetime": "31m", - **config_dict, - } - ) - - with self.assertRaises(ConfigError): - HomeServerConfig().parse_config_dict( - { - "session_lifetime": "30m", - "refreshable_access_token_lifetime": "31m", - **config_dict, - } - ) - - with self.assertRaises(ConfigError): - HomeServerConfig().parse_config_dict( - { - "session_lifetime": "30m", - "refresh_token_lifetime": "31m", - **config_dict, - } - ) - - # Then test all the fine conditions - HomeServerConfig().parse_config_dict( - { - "session_lifetime": "31m", - "nonrefreshable_access_token_lifetime": "31m", - **config_dict, - } - ) - - HomeServerConfig().parse_config_dict( - { - "session_lifetime": "31m", - "refreshable_access_token_lifetime": "31m", - **config_dict, - } - ) - - HomeServerConfig().parse_config_dict( - {"session_lifetime": "31m", "refresh_token_lifetime": "31m", **config_dict} - ) diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 17a9fb63a176..4d1e154578c1 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -22,7 +22,6 @@ from nacl.signing import SigningKey from signedjson.key import encode_verify_key_base64, get_verify_key -from twisted.internet import defer from twisted.internet.defer import Deferred, ensureDeferred from synapse.api.errors import SynapseError @@ -578,76 +577,6 @@ def test_get_keys_from_perspectives(self): bytes(res["key_json"]), canonicaljson.encode_canonical_json(response) ) - def test_get_multiple_keys_from_perspectives(self): - """Check that we can correctly request multiple keys for the same server""" - - fetcher = PerspectivesKeyFetcher(self.hs) - - SERVER_NAME = "server2" - - testkey1 = signedjson.key.generate_signing_key("ver1") - testverifykey1 = signedjson.key.get_verify_key(testkey1) - testverifykey1_id = "ed25519:ver1" - - testkey2 = signedjson.key.generate_signing_key("ver2") - testverifykey2 = signedjson.key.get_verify_key(testkey2) - testverifykey2_id = "ed25519:ver2" - - VALID_UNTIL_TS = 200 * 1000 - - response1 = self.build_perspectives_response( - SERVER_NAME, - testkey1, - VALID_UNTIL_TS, - ) - response2 = self.build_perspectives_response( - SERVER_NAME, - testkey2, - VALID_UNTIL_TS, - ) - - async def post_json(destination, path, data, **kwargs): - self.assertEqual(destination, self.mock_perspective_server.server_name) - self.assertEqual(path, "/_matrix/key/v2/query") - - # check that the request is for the expected keys - q = data["server_keys"] - - self.assertEqual( - list(q[SERVER_NAME].keys()), [testverifykey1_id, testverifykey2_id] - ) - return {"server_keys": [response1, response2]} - - self.http_client.post_json.side_effect = post_json - - # fire off two separate requests; they should get merged together into a - # single HTTP hit. 
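The merging asserted below comes down to issuing one /_matrix/key/v2/query body that covers every pending key ID per server. A rough sketch of that request shape (the minimum_valid_until_ts values are illustrative; the real fetcher tracks them per request):

def build_query_body(pending):
    """pending maps server_name -> iterable of key IDs awaiting fetch."""
    return {
        "server_keys": {
            server_name: {
                key_id: {"minimum_valid_until_ts": 0} for key_id in key_ids
            }
            for server_name, key_ids in pending.items()
        }
    }

# build_query_body({"server2": ["ed25519:ver1", "ed25519:ver2"]}) yields a single
# POST body naming both keys, which is what the q[SERVER_NAME] assertion checks.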
- request1_d = defer.ensureDeferred( - fetcher.get_keys(SERVER_NAME, [testverifykey1_id], 0) - ) - request2_d = defer.ensureDeferred( - fetcher.get_keys(SERVER_NAME, [testverifykey2_id], 0) - ) - - keys1 = self.get_success(request1_d) - self.assertIn(testverifykey1_id, keys1) - k = keys1[testverifykey1_id] - self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) - self.assertEqual(k.verify_key, testverifykey1) - self.assertEqual(k.verify_key.alg, "ed25519") - self.assertEqual(k.verify_key.version, "ver1") - - keys2 = self.get_success(request2_d) - self.assertIn(testverifykey2_id, keys2) - k = keys2[testverifykey2_id] - self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) - self.assertEqual(k.verify_key, testverifykey2) - self.assertEqual(k.verify_key.alg, "ed25519") - self.assertEqual(k.verify_key.version, "ver2") - - # finally, ensure that only one request was sent - self.assertEqual(self.http_client.post_json.call_count, 1) - def test_get_perspectives_own_key(self): """Check that we can get the perspectives server's own keys diff --git a/tests/federation/transport/test_client.py b/tests/federation/transport/test_client.py deleted file mode 100644 index a7031a55f28c..000000000000 --- a/tests/federation/transport/test_client.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json - -from synapse.api.room_versions import RoomVersions -from synapse.federation.transport.client import SendJoinParser - -from tests.unittest import TestCase - - -class SendJoinParserTestCase(TestCase): - def test_two_writes(self) -> None: - """Test that the parser can sensibly deserialise an input given in two slices.""" - parser = SendJoinParser(RoomVersions.V1, True) - parent_event = { - "content": { - "see_room_version_spec": "The event format changes depending on the room version." - }, - "event_id": "$authparent", - "room_id": "!somewhere:example.org", - "type": "m.room.minimal_pdu", - } - state = { - "content": { - "see_room_version_spec": "The event format changes depending on the room version." - }, - "event_id": "$DoNotThinkAboutTheEvent", - "room_id": "!somewhere:example.org", - "type": "m.room.minimal_pdu", - } - response = [ - 200, - { - "auth_chain": [parent_event], - "origin": "matrix.org", - "state": [state], - }, - ] - serialised_response = json.dumps(response).encode() - - # Send data to the parser - parser.write(serialised_response[:100]) - parser.write(serialised_response[100:]) - - # Retrieve the parsed SendJoinResponse - parsed_response = parser.finish() - - # Sanity check the parsing gave us sensible data. 
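The two-slice write pattern this test drives is push-mode JSON parsing; a minimal sketch of the same idea with the ijson package (which the parser under test builds on), using only documented ijson calls:

import ijson

state_events = ijson.sendable_list()
coro = ijson.items_coro(state_events, "state.item")

payload = b'{"origin": "matrix.org", "state": [{"event_id": "$a"}, {"event_id": "$b"}]}'
coro.send(payload[:40])  # bytes may arrive in arbitrary slices...
coro.send(payload[40:])  # ...and items are emitted as each one completes
coro.close()

assert state_events == [{"event_id": "$a"}, {"event_id": "$b"}]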
- self.assertEqual(len(parsed_response.auth_events), 1, parsed_response) - self.assertEqual(len(parsed_response.state), 1, parsed_response) - self.assertEqual(parsed_response.event_dict, {}, parsed_response) - self.assertIsNone(parsed_response.event, parsed_response) diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 03b8b8615c62..72e176da7543 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -71,7 +71,7 @@ def verify_guest(caveat): def test_short_term_login_token_gives_user_id(self): token = self.macaroon_generator.generate_short_term_login_token( - self.user1, "", duration_in_ms=5000 + self.user1, "", 5000 ) res = self.get_success(self.auth_handler.validate_short_term_login_token(token)) self.assertEqual(self.user1, res.user_id) @@ -94,7 +94,7 @@ def test_short_term_login_token_gives_auth_provider(self): def test_short_term_login_token_cannot_replace_user_id(self): token = self.macaroon_generator.generate_short_term_login_token( - self.user1, "", duration_in_ms=5000 + self.user1, "", 5000 ) macaroon = pymacaroons.Macaroon.deserialize(token) @@ -213,6 +213,6 @@ def test_mau_limits_not_exceeded(self): def _get_macaroon(self): token = self.macaroon_generator.generate_short_term_login_token( - self.user1, "", duration_in_ms=5000 + self.user1, "", 5000 ) return pymacaroons.Macaroon.deserialize(token) diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py index 8705ff894343..b625995d1253 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py @@ -66,13 +66,7 @@ def test_map_cas_user_to_user(self): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", - "cas", - request, - "redirect_uri", - None, - new_user=True, - auth_provider_session_id=None, + "@test_user:test", "cas", request, "redirect_uri", None, new_user=True ) def test_map_cas_user_to_existing_user(self): @@ -95,13 +89,7 @@ def test_map_cas_user_to_existing_user(self): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", - "cas", - request, - "redirect_uri", - None, - new_user=False, - auth_provider_session_id=None, + "@test_user:test", "cas", request, "redirect_uri", None, new_user=False ) # Subsequent calls should map to the same mxid. 
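The long runs of mechanical changes in this and the following test files have a single cause: Mock.assert_called_once_with compares the complete argument list, so removing the auth_provider_session_id keyword from complete_sso_login (as this revert does) requires touching every expectation. A small self-contained illustration:

from unittest.mock import Mock

auth_handler = Mock()
auth_handler.complete_sso_login(
    "@test_user:test", "cas", "request", "redirect_uri", None, new_user=True
)

# Succeeds only when every positional and keyword argument matches exactly;
# one extra or missing keyword fails the assertion even though the behaviour
# under test is unchanged.
auth_handler.complete_sso_login.assert_called_once_with(
    "@test_user:test", "cas", "request", "redirect_uri", None, new_user=True
)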
@@ -110,13 +98,7 @@ def test_map_cas_user_to_existing_user(self): self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") ) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", - "cas", - request, - "redirect_uri", - None, - new_user=False, - auth_provider_session_id=None, + "@test_user:test", "cas", request, "redirect_uri", None, new_user=False ) def test_map_cas_user_to_invalid_localpart(self): @@ -134,13 +116,7 @@ def test_map_cas_user_to_invalid_localpart(self): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@f=c3=b6=c3=b6:test", - "cas", - request, - "redirect_uri", - None, - new_user=True, - auth_provider_session_id=None, + "@f=c3=b6=c3=b6:test", "cas", request, "redirect_uri", None, new_user=True ) @override_config( @@ -184,13 +160,7 @@ def test_required_attributes(self): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", - "cas", - request, - "redirect_uri", - None, - new_user=True, - auth_provider_session_id=None, + "@test_user:test", "cas", request, "redirect_uri", None, new_user=True ) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index cfe3de526682..a25c89bd5bd3 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -252,6 +252,13 @@ async def patched_load_metadata(): with patch.object(self.provider, "load_metadata", patched_load_metadata): self.get_failure(self.provider.load_jwks(force=True), RuntimeError) + # Return empty key set if JWKS are not used + self.provider._scopes = [] # not asking the openid scope + self.http_client.get_json.reset_mock() + jwks = self.get_success(self.provider.load_jwks(force=True)) + self.http_client.get_json.assert_not_called() + self.assertEqual(jwks, {"keys": []}) + @override_config({"oidc_config": DEFAULT_CONFIG}) def test_validate_config(self): """Provider metadatas are extensively validated.""" @@ -448,13 +455,7 @@ def test_callback(self): self.get_success(self.handler.handle_oidc_callback(request)) auth_handler.complete_sso_login.assert_called_once_with( - expected_user_id, - "oidc", - request, - client_redirect_url, - None, - new_user=True, - auth_provider_session_id=None, + expected_user_id, "oidc", request, client_redirect_url, None, new_user=True ) self.provider._exchange_code.assert_called_once_with(code) self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) @@ -481,58 +482,17 @@ def test_callback(self): self.provider._fetch_userinfo.reset_mock() # With userinfo fetching - self.provider._user_profile_method = "userinfo_endpoint" - token = { - "type": "bearer", - "access_token": "access_token", - } - self.provider._exchange_code = simple_async_mock(return_value=token) + self.provider._scopes = [] # do not ask the "openid" scope self.get_success(self.handler.handle_oidc_callback(request)) auth_handler.complete_sso_login.assert_called_once_with( - expected_user_id, - "oidc", - request, - client_redirect_url, - None, - new_user=False, - auth_provider_session_id=None, + expected_user_id, "oidc", request, client_redirect_url, None, new_user=False ) self.provider._exchange_code.assert_called_once_with(code) self.provider._parse_id_token.assert_not_called() self.provider._fetch_userinfo.assert_called_once_with(token) self.render_error.assert_not_called() - # With an ID token, userinfo fetching and sid in the ID token - self.provider._user_profile_method = "userinfo_endpoint" - token = { - 
"type": "bearer", - "access_token": "access_token", - "id_token": "id_token", - } - id_token = { - "sid": "abcdefgh", - } - self.provider._parse_id_token = simple_async_mock(return_value=id_token) - self.provider._exchange_code = simple_async_mock(return_value=token) - auth_handler.complete_sso_login.reset_mock() - self.provider._fetch_userinfo.reset_mock() - self.get_success(self.handler.handle_oidc_callback(request)) - - auth_handler.complete_sso_login.assert_called_once_with( - expected_user_id, - "oidc", - request, - client_redirect_url, - None, - new_user=False, - auth_provider_session_id=id_token["sid"], - ) - self.provider._exchange_code.assert_called_once_with(code) - self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) - self.provider._fetch_userinfo.assert_called_once_with(token) - self.render_error.assert_not_called() - # Handle userinfo fetching error self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) self.get_success(self.handler.handle_oidc_callback(request)) @@ -816,7 +776,6 @@ def test_extra_attributes(self): client_redirect_url, {"phone": "1234567"}, new_user=True, - auth_provider_session_id=None, ) @override_config({"oidc_config": DEFAULT_CONFIG}) @@ -831,13 +790,7 @@ def test_map_userinfo_to_user(self): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", - "oidc", - ANY, - ANY, - None, - new_user=True, - auth_provider_session_id=None, + "@test_user:test", "oidc", ANY, ANY, None, new_user=True ) auth_handler.complete_sso_login.reset_mock() @@ -848,13 +801,7 @@ def test_map_userinfo_to_user(self): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user_2:test", - "oidc", - ANY, - ANY, - None, - new_user=True, - auth_provider_session_id=None, + "@test_user_2:test", "oidc", ANY, ANY, None, new_user=True ) auth_handler.complete_sso_login.reset_mock() @@ -891,26 +838,14 @@ def test_map_userinfo_to_existing_user(self): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - user.to_string(), - "oidc", - ANY, - ANY, - None, - new_user=False, - auth_provider_session_id=None, + user.to_string(), "oidc", ANY, ANY, None, new_user=False ) auth_handler.complete_sso_login.reset_mock() # Subsequent calls should map to the same mxid. 
self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - user.to_string(), - "oidc", - ANY, - ANY, - None, - new_user=False, - auth_provider_session_id=None, + user.to_string(), "oidc", ANY, ANY, None, new_user=False ) auth_handler.complete_sso_login.reset_mock() @@ -925,13 +860,7 @@ def test_map_userinfo_to_existing_user(self): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - user.to_string(), - "oidc", - ANY, - ANY, - None, - new_user=False, - auth_provider_session_id=None, + user.to_string(), "oidc", ANY, ANY, None, new_user=False ) auth_handler.complete_sso_login.reset_mock() @@ -967,13 +896,7 @@ def test_map_userinfo_to_existing_user(self): self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - "@TEST_USER_2:test", - "oidc", - ANY, - ANY, - None, - new_user=False, - auth_provider_session_id=None, + "@TEST_USER_2:test", "oidc", ANY, ANY, None, new_user=False ) @override_config({"oidc_config": DEFAULT_CONFIG}) @@ -1011,13 +934,7 @@ def test_map_userinfo_to_user_retries(self): # test_user is already taken, so test_user1 gets registered instead. auth_handler.complete_sso_login.assert_called_once_with( - "@test_user1:test", - "oidc", - ANY, - ANY, - None, - new_user=True, - auth_provider_session_id=None, + "@test_user1:test", "oidc", ANY, ANY, None, new_user=True ) auth_handler.complete_sso_login.reset_mock() @@ -1101,13 +1018,7 @@ def test_attribute_requirements(self): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@tester:test", - "oidc", - ANY, - ANY, - None, - new_user=True, - auth_provider_session_id=None, + "@tester:test", "oidc", ANY, ANY, None, new_user=True ) @override_config( @@ -1132,13 +1043,7 @@ def test_attribute_requirements_contains(self): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@tester:test", - "oidc", - ANY, - ANY, - None, - new_user=True, - auth_provider_session_id=None, + "@tester:test", "oidc", ANY, ANY, None, new_user=True ) @override_config( @@ -1251,7 +1156,7 @@ async def _make_callback_with_userinfo( handler = hs.get_oidc_handler() provider = handler._providers["oidc"] - provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) + provider._exchange_code = simple_async_mock(return_value={}) provider._parse_id_token = simple_async_mock(return_value=userinfo) provider._fetch_userinfo = simple_async_mock(return_value=userinfo) diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py index e5a6a6c747bf..7b95844b55d3 100644 --- a/tests/handlers/test_room_summary.py +++ b/tests/handlers/test_room_summary.py @@ -32,7 +32,7 @@ from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer -from synapse.types import JsonDict, UserID, create_requester +from synapse.types import JsonDict, UserID from tests import unittest @@ -249,7 +249,7 @@ def test_simple_space(self): self._assert_rooms(result, expected) result = self.get_success( - self.handler.get_room_hierarchy(create_requester(self.user), self.space) + self.handler.get_room_hierarchy(self.user, self.space) ) self._assert_hierarchy(result, expected) @@ -263,9 +263,7 @@ def test_visibility(self): expected = [(self.space, [self.room]), (self.room, ())] self._assert_rooms(result, 
expected) - result = self.get_success( - self.handler.get_room_hierarchy(create_requester(user2), self.space) - ) + result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) self._assert_hierarchy(result, expected) # If the space is made invite-only, it should no longer be viewable. @@ -276,10 +274,7 @@ def test_visibility(self): tok=self.token, ) self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) - self.get_failure( - self.handler.get_room_hierarchy(create_requester(user2), self.space), - AuthError, - ) + self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError) # If the space is made world-readable it should return a result. self.helper.send_state( @@ -291,9 +286,7 @@ def test_visibility(self): result = self.get_success(self.handler.get_space_summary(user2, self.space)) self._assert_rooms(result, expected) - result = self.get_success( - self.handler.get_room_hierarchy(create_requester(user2), self.space) - ) + result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) self._assert_hierarchy(result, expected) # Make it not world-readable again and confirm it results in an error. @@ -304,10 +297,7 @@ def test_visibility(self): tok=self.token, ) self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) - self.get_failure( - self.handler.get_room_hierarchy(create_requester(user2), self.space), - AuthError, - ) + self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError) # Join the space and results should be returned. self.helper.invite(self.space, targ=user2, tok=self.token) @@ -315,9 +305,7 @@ def test_visibility(self): result = self.get_success(self.handler.get_space_summary(user2, self.space)) self._assert_rooms(result, expected) - result = self.get_success( - self.handler.get_room_hierarchy(create_requester(user2), self.space) - ) + result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) self._assert_hierarchy(result, expected) # Attempting to view an unknown room returns the same error. @@ -326,9 +314,7 @@ def test_visibility(self): AuthError, ) self.get_failure( - self.handler.get_room_hierarchy( - create_requester(user2), "#not-a-space:" + self.hs.hostname - ), + self.handler.get_room_hierarchy(user2, "#not-a-space:" + self.hs.hostname), AuthError, ) @@ -336,10 +322,10 @@ def test_room_hierarchy_cache(self) -> None: """In-flight room hierarchy requests are deduplicated.""" # Run two `get_room_hierarchy` calls up until they block. deferred1 = ensureDeferred( - self.handler.get_room_hierarchy(create_requester(self.user), self.space) + self.handler.get_room_hierarchy(self.user, self.space) ) deferred2 = ensureDeferred( - self.handler.get_room_hierarchy(create_requester(self.user), self.space) + self.handler.get_room_hierarchy(self.user, self.space) ) # Complete the two calls. @@ -354,7 +340,7 @@ def test_room_hierarchy_cache(self) -> None: # A subsequent `get_room_hierarchy` call should not reuse the result. result3 = self.get_success( - self.handler.get_room_hierarchy(create_requester(self.user), self.space) + self.handler.get_room_hierarchy(self.user, self.space) ) self._assert_hierarchy(result3, expected) self.assertIsNot(result1, result3) @@ -373,11 +359,9 @@ def test_room_hierarchy_cache_sharing(self) -> None: # Run two `get_room_hierarchy` calls for different users up until they block. 
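The property these two tests pin down is in-flight deduplication keyed by requester: concurrent identical calls share one computation, while a second user never receives the first user's result. A minimal asyncio sketch of that shape (hypothetical names, not Synapse's actual cache):

import asyncio
from typing import Any, Awaitable, Callable, Dict, Tuple

_in_flight: Dict[Tuple[str, str], "asyncio.Task[Any]"] = {}

async def get_hierarchy(
    user_id: str, room_id: str, compute: Callable[[str, str], Awaitable[Any]]
) -> Any:
    key = (user_id, room_id)  # including user_id prevents cross-user sharing
    task = _in_flight.get(key)
    if task is None:
        task = asyncio.ensure_future(compute(user_id, room_id))
        _in_flight[key] = task
        # Drop the entry on completion so later calls recompute fresh results.
        task.add_done_callback(lambda _: _in_flight.pop(key, None))
    return await task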
deferred1 = ensureDeferred( - self.handler.get_room_hierarchy(create_requester(self.user), self.space) - ) - deferred2 = ensureDeferred( - self.handler.get_room_hierarchy(create_requester(user2), self.space) + self.handler.get_room_hierarchy(self.user, self.space) ) + deferred2 = ensureDeferred(self.handler.get_room_hierarchy(user2, self.space)) # Complete the two calls. result1 = self.get_success(deferred1) @@ -481,9 +465,7 @@ def test_filtering(self): ] self._assert_rooms(result, expected) - result = self.get_success( - self.handler.get_room_hierarchy(create_requester(user2), self.space) - ) + result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) self._assert_hierarchy(result, expected) def test_complex_space(self): @@ -525,7 +507,7 @@ def test_complex_space(self): self._assert_rooms(result, expected) result = self.get_success( - self.handler.get_room_hierarchy(create_requester(self.user), self.space) + self.handler.get_room_hierarchy(self.user, self.space) ) self._assert_hierarchy(result, expected) @@ -540,9 +522,7 @@ def test_pagination(self): room_ids.append(self.room) result = self.get_success( - self.handler.get_room_hierarchy( - create_requester(self.user), self.space, limit=7 - ) + self.handler.get_room_hierarchy(self.user, self.space, limit=7) ) # The result should have the space and all of the links, plus some of the # rooms and a pagination token. @@ -554,10 +534,7 @@ def test_pagination(self): # Check the next page. result = self.get_success( self.handler.get_room_hierarchy( - create_requester(self.user), - self.space, - limit=5, - from_token=result["next_batch"], + self.user, self.space, limit=5, from_token=result["next_batch"] ) ) # The result should have the space and the room in it, along with a link @@ -577,22 +554,20 @@ def test_invalid_pagination_token(self): room_ids.append(self.room) result = self.get_success( - self.handler.get_room_hierarchy( - create_requester(self.user), self.space, limit=7 - ) + self.handler.get_room_hierarchy(self.user, self.space, limit=7) ) self.assertIn("next_batch", result) # Changing the room ID, suggested-only, or max-depth causes an error. self.get_failure( self.handler.get_room_hierarchy( - create_requester(self.user), self.room, from_token=result["next_batch"] + self.user, self.room, from_token=result["next_batch"] ), SynapseError, ) self.get_failure( self.handler.get_room_hierarchy( - create_requester(self.user), + self.user, self.space, suggested_only=True, from_token=result["next_batch"], @@ -601,19 +576,14 @@ def test_invalid_pagination_token(self): ) self.get_failure( self.handler.get_room_hierarchy( - create_requester(self.user), - self.space, - max_depth=0, - from_token=result["next_batch"], + self.user, self.space, max_depth=0, from_token=result["next_batch"] ), SynapseError, ) # An invalid token is ignored. self.get_failure( - self.handler.get_room_hierarchy( - create_requester(self.user), self.space, from_token="foo" - ), + self.handler.get_room_hierarchy(self.user, self.space, from_token="foo"), SynapseError, ) @@ -639,18 +609,14 @@ def test_max_depth(self): # Test just the space itself. result = self.get_success( - self.handler.get_room_hierarchy( - create_requester(self.user), self.space, max_depth=0 - ) + self.handler.get_room_hierarchy(self.user, self.space, max_depth=0) ) expected: List[Tuple[str, Iterable[str]]] = [(spaces[0], [rooms[0], spaces[1]])] self._assert_hierarchy(result, expected) # A single additional layer. 
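What max_depth expresses in these assertions is a depth-limited breadth-first walk of the space graph; a toy version over a plain adjacency dict (illustrative data, not the handler's real traversal):

from collections import deque
from typing import Dict, List

def walk(children: Dict[str, List[str]], root: str, max_depth: int) -> List[str]:
    seen, order = {root}, [root]
    queue = deque([(root, 0)])
    while queue:
        room, depth = queue.popleft()
        if depth == max_depth:
            continue  # don't descend past the requested depth
        for child in children.get(room, []):
            if child not in seen:
                seen.add(child)
                order.append(child)
                queue.append((child, depth + 1))
    return order

# walk({"space": ["room0", "subspace"], "subspace": ["room1"]}, "space", 1)
# returns ["space", "room0", "subspace"]: the root plus one layer.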
result = self.get_success( - self.handler.get_room_hierarchy( - create_requester(self.user), self.space, max_depth=1 - ) + self.handler.get_room_hierarchy(self.user, self.space, max_depth=1) ) expected += [ (rooms[0], ()), @@ -660,9 +626,7 @@ def test_max_depth(self): # A few layers. result = self.get_success( - self.handler.get_room_hierarchy( - create_requester(self.user), self.space, max_depth=3 - ) + self.handler.get_room_hierarchy(self.user, self.space, max_depth=3) ) expected += [ (rooms[1], ()), @@ -693,7 +657,7 @@ def test_unknown_room_version(self): self._assert_rooms(result, expected) result = self.get_success( - self.handler.get_room_hierarchy(create_requester(self.user), self.space) + self.handler.get_room_hierarchy(self.user, self.space) ) self._assert_hierarchy(result, expected) @@ -775,7 +739,7 @@ async def summarize_remote_room_hierarchy(_self, room, suggested_only): new=summarize_remote_room_hierarchy, ): result = self.get_success( - self.handler.get_room_hierarchy(create_requester(self.user), self.space) + self.handler.get_room_hierarchy(self.user, self.space) ) self._assert_hierarchy(result, expected) @@ -942,7 +906,7 @@ async def summarize_remote_room_hierarchy(_self, room, suggested_only): new=summarize_remote_room_hierarchy, ): result = self.get_success( - self.handler.get_room_hierarchy(create_requester(self.user), self.space) + self.handler.get_room_hierarchy(self.user, self.space) ) self._assert_hierarchy(result, expected) @@ -1000,7 +964,7 @@ async def summarize_remote_room_hierarchy(_self, room, suggested_only): new=summarize_remote_room_hierarchy, ): result = self.get_success( - self.handler.get_room_hierarchy(create_requester(self.user), self.space) + self.handler.get_room_hierarchy(self.user, self.space) ) self._assert_hierarchy(result, expected) diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index 50551aa6e3c2..8cfc184fefc9 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -130,13 +130,7 @@ def test_map_saml_response_to_user(self): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", - "saml", - request, - "redirect_uri", - None, - new_user=True, - auth_provider_session_id=None, + "@test_user:test", "saml", request, "redirect_uri", None, new_user=True ) @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}}) @@ -162,13 +156,7 @@ def test_map_saml_response_to_existing_user(self): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", - "saml", - request, - "", - None, - new_user=False, - auth_provider_session_id=None, + "@test_user:test", "saml", request, "", None, new_user=False ) # Subsequent calls should map to the same mxid. @@ -177,13 +165,7 @@ def test_map_saml_response_to_existing_user(self): self.handler._handle_authn_response(request, saml_response, "") ) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", - "saml", - request, - "", - None, - new_user=False, - auth_provider_session_id=None, + "@test_user:test", "saml", request, "", None, new_user=False ) def test_map_saml_response_to_invalid_localpart(self): @@ -231,13 +213,7 @@ def test_map_saml_response_to_user_retries(self): # test_user is already taken, so test_user1 gets registered instead. 
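The retry this test (like its OIDC counterpart above) depends on: when the mapped localpart collides with an existing account, numeric suffixes are tried until a free one is found. Sketched with a hypothetical is_taken predicate:

from typing import Callable

def pick_localpart(base: str, is_taken: Callable[[str], bool]) -> str:
    if not is_taken(base):
        return base
    suffix = 1
    while is_taken(f"{base}{suffix}"):
        suffix += 1
    return f"{base}{suffix}"

# pick_localpart("test_user", {"test_user"}.__contains__) == "test_user1"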
auth_handler.complete_sso_login.assert_called_once_with( - "@test_user1:test", - "saml", - request, - "", - None, - new_user=True, - auth_provider_session_id=None, + "@test_user1:test", "saml", request, "", None, new_user=True ) auth_handler.complete_sso_login.reset_mock() @@ -333,13 +309,7 @@ def test_attribute_requirements(self): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", - "saml", - request, - "redirect_uri", - None, - new_user=True, - auth_provider_session_id=None, + "@test_user:test", "saml", request, "redirect_uri", None, new_user=True ) diff --git a/tests/push/test_email.py b/tests/push/test_email.py index f8cba7b64584..90f800e564b4 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -128,7 +128,6 @@ def prepare(self, reactor, clock, hs): ) self.auth_handler = hs.get_auth_handler() - self.store = hs.get_datastore() def test_need_validated_email(self): """Test that we can only add an email pusher if the user has validated @@ -409,7 +408,13 @@ def test_remove_unlinked_pushers_background_job(self): self.hs.get_datastore().db_pool.updates._all_done = False # Now let's actually drive the updates to completion - self.wait_for_background_updates() + while not self.get_success( + self.hs.get_datastore().db_pool.updates.has_completed_background_updates() + ): + self.get_success( + self.hs.get_datastore().db_pool.updates.do_next_background_update(100), + by=0.1, + ) # Check that all pushers with unlinked addresses were deleted pushers = self.get_success( diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py index 596ba5a0c9f4..0a6e4795ee92 100644 --- a/tests/replication/test_sharded_event_persister.py +++ b/tests/replication/test_sharded_event_persister.py @@ -17,7 +17,6 @@ from synapse.api.room_versions import RoomVersion from synapse.rest import admin from synapse.rest.client import login, room, sync -from synapse.storage.util.id_generators import MultiWriterIdGenerator from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.server import make_request @@ -194,10 +193,7 @@ def test_vector_clock_token(self): # # Worker2's event stream position will not advance until we call # __aexit__ again. - worker_store2 = worker_hs2.get_datastore() - assert isinstance(worker_store2._stream_id_gen, MultiWriterIdGenerator) - - actx = worker_store2._stream_id_gen.get_next() + actx = worker_hs2.get_datastore()._stream_id_gen.get_next() self.get_success(actx.__aenter__()) response = self.helper.send(room_id1, body="Hi!", tok=self.other_access_token) diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 3adadcb46bc4..af849bd47138 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
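The assertion churn in the next file is cosmetic as far as values go, because http.HTTPStatus is an IntEnum: the named constants the revert removes compare equal to the bare integers it restores.

from http import HTTPStatus

assert HTTPStatus.OK == 200
assert HTTPStatus.NOT_FOUND == 404
assert int(HTTPStatus.FORBIDDEN) == 403  # explicit int() also works, as the tests do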
+import json import os import urllib.parse -from http import HTTPStatus from unittest.mock import Mock from twisted.internet.defer import Deferred @@ -41,7 +41,7 @@ def create_test_resource(self): def test_version_string(self): channel = self.make_request("GET", self.url, shorthand=False) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( {"server_version", "python_version"}, set(channel.json_body.keys()) ) @@ -70,11 +70,11 @@ def test_delete_group(self): content={"localpart": "test"}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) group_id = channel.json_body["group_id"] - self._check_group(group_id, expect_code=HTTPStatus.OK) + self._check_group(group_id, expect_code=200) # Invite/join another user @@ -82,13 +82,13 @@ def test_delete_group(self): channel = self.make_request( "PUT", url.encode("ascii"), access_token=self.admin_user_tok, content={} ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) url = "/groups/%s/self/accept_invite" % (group_id,) channel = self.make_request( "PUT", url.encode("ascii"), access_token=self.other_user_token, content={} ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Check other user knows they're in the group self.assertIn(group_id, self._get_groups_user_is_in(self.admin_user_tok)) @@ -103,10 +103,10 @@ def test_delete_group(self): content={"localpart": "test"}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - # Check group returns HTTPStatus.NOT_FOUND - self._check_group(group_id, expect_code=HTTPStatus.NOT_FOUND) + # Check group returns 404 + self._check_group(group_id, expect_code=404) # Check users don't think they're in the group self.assertNotIn(group_id, self._get_groups_user_is_in(self.admin_user_tok)) @@ -122,13 +122,15 @@ def _check_group(self, group_id, expect_code): "GET", url.encode("ascii"), access_token=self.admin_user_tok ) - self.assertEqual(expect_code, channel.code, msg=channel.json_body) + self.assertEqual( + expect_code, int(channel.result["code"]), msg=channel.result["body"] + ) def _get_groups_user_is_in(self, access_token): """Returns the list of groups the user is in (given their access token)""" channel = self.make_request("GET", b"/joined_groups", access_token=access_token) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) return channel.json_body["groups"] @@ -208,10 +210,10 @@ def _ensure_quarantined(self, admin_user_tok, server_and_media_id): # Should be quarantined self.assertEqual( - HTTPStatus.NOT_FOUND, - channel.code, + 404, + int(channel.code), msg=( - "Expected to receive a HTTPStatus.NOT_FOUND on accessing quarantined media: %s" + "Expected to receive a 404 on accessing quarantined media: %s" % server_and_media_id ), ) @@ -230,8 +232,8 @@ def test_quarantine_media_requires_admin(self): # Expect a forbidden error self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, + 403, + int(channel.result["code"]), msg="Expected forbidden on quarantining media as a non-admin", ) @@ -245,8 +247,8 
@@ def test_quarantine_media_requires_admin(self): # Expect a forbidden error self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, + 403, + int(channel.result["code"]), msg="Expected forbidden on quarantining media as a non-admin", ) @@ -277,7 +279,7 @@ def test_quarantine_media_by_id(self): ) # Should be successful - self.assertEqual(HTTPStatus.OK, channel.code) + self.assertEqual(200, int(channel.code), msg=channel.result["body"]) # Quarantine the media url = "/_synapse/admin/v1/media/quarantine/%s/%s" % ( @@ -290,7 +292,7 @@ def test_quarantine_media_by_id(self): access_token=admin_user_tok, ) self.pump(1.0) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.code), msg=channel.result["body"]) # Attempt to access the media self._ensure_quarantined(admin_user_tok, server_name_and_media_id) @@ -346,9 +348,11 @@ def test_quarantine_all_media_in_room(self, override_url_template=None): access_token=admin_user_tok, ) self.pump(1.0) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.code), msg=channel.result["body"]) self.assertEqual( - channel.json_body, {"num_quarantined": 2}, "Expected 2 quarantined items" + json.loads(channel.result["body"].decode("utf-8")), + {"num_quarantined": 2}, + "Expected 2 quarantined items", ) # Convert mxc URLs to server/media_id strings @@ -392,9 +396,11 @@ def test_quarantine_all_media_by_user(self): access_token=admin_user_tok, ) self.pump(1.0) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( - channel.json_body, {"num_quarantined": 2}, "Expected 2 quarantined items" + json.loads(channel.result["body"].decode("utf-8")), + {"num_quarantined": 2}, + "Expected 2 quarantined items", ) # Attempt to access each piece of media @@ -426,7 +432,7 @@ def test_cannot_quarantine_safe_media(self): url = "/_synapse/admin/v1/media/protect/%s" % (urllib.parse.quote(media_id_2),) channel = self.make_request("POST", url, access_token=admin_user_tok) self.pump(1.0) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.code), msg=channel.result["body"]) # Quarantine all media by this user url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote( @@ -438,9 +444,11 @@ def test_cannot_quarantine_safe_media(self): access_token=admin_user_tok, ) self.pump(1.0) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( - channel.json_body, {"num_quarantined": 1}, "Expected 1 quarantined item" + json.loads(channel.result["body"].decode("utf-8")), + {"num_quarantined": 1}, + "Expected 1 quarantined item", ) # Attempt to access each piece of media, the first should fail, the @@ -459,10 +467,10 @@ def test_cannot_quarantine_safe_media(self): # Shouldn't be quarantined self.assertEqual( - HTTPStatus.OK, - channel.code, + 200, + int(channel.code), msg=( - "Expected to receive a HTTPStatus.OK on accessing not-quarantined media: %s" + "Expected to receive a 200 on accessing not-quarantined media: %s" % server_and_media_id_2 ), ) @@ -491,7 +499,7 @@ def prepare(self, reactor, clock, hs): def test_purge_history(self): """ Simple test of purge history API. - Test only that is is possible to call, get status HTTPStatus.OK and purge_id. 
+ Test only that it is possible to call, get status 200 and purge_id. """ channel = self.make_request( @@ -501,7 +509,7 @@ def test_purge_history(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("purge_id", channel.json_body) purge_id = channel.json_body["purge_id"] @@ -512,5 +520,5 @@ def test_purge_history(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("complete", channel.json_body["status"]) diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py index 4d152c0d66c2..cd5c60b65c1f 100644 --- a/tests/rest/admin/test_background_updates.py +++ b/tests/rest/admin/test_background_updates.py @@ -16,14 +16,11 @@ from parameterized import parameterized -from twisted.test.proto_helpers import MemoryReactor - import synapse.rest.admin from synapse.api.errors import Codes from synapse.rest.client import login from synapse.server import HomeServer from synapse.storage.background_updates import BackgroundUpdater -from synapse.util import Clock from tests import unittest @@ -34,7 +31,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + def prepare(self, reactor, clock, hs: HomeServer): self.store = hs.get_datastore() self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -47,9 +44,9 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: ("POST", "/_synapse/admin/v1/background_updates/start_job"), ] ) - def test_requester_is_no_admin(self, method: str, url: str) -> None: + def test_requester_is_no_admin(self, method: str, url: str): """ - If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. + If the user is not a server admin, an error 403 is returned. """ self.register_user("user", "pass", admin=False) @@ -65,7 +62,7 @@ def test_requester_is_no_admin(self, method: str, url: str) -> None: self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_invalid_parameter(self) -> None: + def test_invalid_parameter(self): """ If parameters are invalid, an error is returned.
""" @@ -93,7 +90,7 @@ def test_invalid_parameter(self) -> None: self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) - def _register_bg_update(self) -> None: + def _register_bg_update(self): "Adds a bg update but doesn't start it" async def _fake_update(progress, batch_size) -> int: @@ -115,7 +112,7 @@ async def _fake_update(progress, batch_size) -> int: ) ) - def test_status_empty(self) -> None: + def test_status_empty(self): """Test the status API works.""" channel = self.make_request( @@ -130,7 +127,7 @@ def test_status_empty(self) -> None: channel.json_body, {"current_updates": {}, "enabled": True} ) - def test_status_bg_update(self) -> None: + def test_status_bg_update(self): """Test the status API works with a background update.""" # Create a new background update @@ -138,7 +135,7 @@ def test_status_bg_update(self) -> None: self._register_bg_update() self.store.db_pool.updates.start_doing_background_updates() - self.reactor.pump([1.0, 1.0, 1.0]) + self.reactor.pump([1.0, 1.0]) channel = self.make_request( "GET", @@ -165,7 +162,7 @@ def test_status_bg_update(self) -> None: }, ) - def test_enabled(self) -> None: + def test_enabled(self): """Test the enabled API works.""" # Create a new background update @@ -302,7 +299,7 @@ def test_enabled(self) -> None: ), ] ) - def test_start_backround_job(self, job_name: str, updates: Collection[str]) -> None: + def test_start_backround_job(self, job_name: str, updates: Collection[str]): """ Test that background updates add to database and be processed. @@ -344,7 +341,7 @@ def test_start_backround_job(self, job_name: str, updates: Collection[str]) -> N ) ) - def test_start_backround_job_twice(self) -> None: + def test_start_backround_job_twice(self): """Test that add a background update twice return an error.""" # add job to database diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py index f7080bda8796..a3679be20539 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py @@ -11,18 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import urllib.parse -from http import HTTPStatus from parameterized import parameterized -from twisted.test.proto_helpers import MemoryReactor - import synapse.rest.admin from synapse.api.errors import Codes from synapse.rest.client import login -from synapse.server import HomeServer -from synapse.util import Clock from tests import unittest @@ -34,7 +30,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + def prepare(self, reactor, clock, hs): self.handler = hs.get_device_handler() self.admin_user = self.register_user("admin", "pass", admin=True) @@ -51,21 +47,17 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: ) @parameterized.expand(["GET", "PUT", "DELETE"]) - def test_no_auth(self, method: str) -> None: + def test_no_auth(self, method: str): """ Try to get a device of an user without authentication. 
""" channel = self.make_request(method, self.url, b"{}") - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "PUT", "DELETE"]) - def test_requester_is_no_admin(self, method: str) -> None: + def test_requester_is_no_admin(self, method: str): """ If the user is not a server admin, an error is returned. """ @@ -75,17 +67,13 @@ def test_requester_is_no_admin(self, method: str) -> None: access_token=self.other_user_token, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "PUT", "DELETE"]) - def test_user_does_not_exist(self, method: str) -> None: + def test_user_does_not_exist(self, method: str): """ - Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND + Tests that a lookup for a user that does not exist returns a 404 """ url = ( "/_synapse/admin/v2/users/@unknown_person:test/devices/%s" @@ -98,13 +86,13 @@ def test_user_does_not_exist(self, method: str) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @parameterized.expand(["GET", "PUT", "DELETE"]) - def test_user_is_not_local(self, method: str) -> None: + def test_user_is_not_local(self, method: str): """ - Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST + Tests that a lookup for a user that is not a local returns a 400 """ url = ( "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices/%s" @@ -117,12 +105,12 @@ def test_user_is_not_local(self, method: str) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) - def test_unknown_device(self) -> None: + def test_unknown_device(self): """ - Tests that a lookup for a device that does not exist returns either HTTPStatus.NOT_FOUND or HTTPStatus.OK. + Tests that a lookup for a device that does not exist returns either 404 or 200. 
""" url = "/_synapse/admin/v2/users/%s/devices/unknown_device" % urllib.parse.quote( self.other_user @@ -134,7 +122,7 @@ def test_unknown_device(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) channel = self.make_request( @@ -143,7 +131,7 @@ def test_unknown_device(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) channel = self.make_request( "DELETE", @@ -151,10 +139,10 @@ def test_unknown_device(self) -> None: access_token=self.admin_user_tok, ) - # Delete unknown device returns status HTTPStatus.OK - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + # Delete unknown device returns status 200 + self.assertEqual(200, channel.code, msg=channel.json_body) - def test_update_device_too_long_display_name(self) -> None: + def test_update_device_too_long_display_name(self): """ Update a device with a display name that is invalid (too long). """ @@ -179,7 +167,7 @@ def test_update_device_too_long_display_name(self) -> None: content=update, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"]) # Ensure the display name was not updated. @@ -189,12 +177,12 @@ def test_update_device_too_long_display_name(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("new display", channel.json_body["display_name"]) - def test_update_no_display_name(self) -> None: + def test_update_no_display_name(self): """ - Tests that a update for a device without JSON returns a HTTPStatus.OK + Tests that a update for a device without JSON returns a 200 """ # Set iniital display name. update = {"display_name": "new display"} @@ -210,7 +198,7 @@ def test_update_no_display_name(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Ensure the display name was not updated. 
channel = self.make_request( @@ -219,10 +207,10 @@ def test_update_no_display_name(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("new display", channel.json_body["display_name"]) - def test_update_display_name(self) -> None: + def test_update_display_name(self): """ Tests a normal successful update of display name """ @@ -234,7 +222,7 @@ def test_update_display_name(self) -> None: content={"display_name": "new displayname"}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Check new display_name channel = self.make_request( @@ -243,10 +231,10 @@ def test_update_display_name(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("new displayname", channel.json_body["display_name"]) - def test_get_device(self) -> None: + def test_get_device(self): """ Tests that a normal lookup for a device is successfully """ @@ -256,7 +244,7 @@ def test_get_device(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["user_id"]) # Check that all fields are available self.assertIn("user_id", channel.json_body) @@ -265,7 +253,7 @@ def test_get_device(self) -> None: self.assertIn("last_seen_ip", channel.json_body) self.assertIn("last_seen_ts", channel.json_body) - def test_delete_device(self) -> None: + def test_delete_device(self): """ Tests that a remove of a device is successfully """ @@ -281,7 +269,7 @@ def test_delete_device(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Ensure that the number of devices is decreased res = self.get_success(self.handler.get_devices_by_user(self.other_user)) @@ -295,7 +283,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + def prepare(self, reactor, clock, hs): self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -305,20 +293,16 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.other_user ) - def test_no_auth(self) -> None: + def test_no_auth(self): """ Try to list devices of an user without authentication. """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self) -> None: + def test_requester_is_no_admin(self): """ If the user is not a server admin, an error is returned. 
""" @@ -330,16 +314,12 @@ def test_requester_is_no_admin(self) -> None: access_token=other_user_token, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_user_does_not_exist(self) -> None: + def test_user_does_not_exist(self): """ - Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND + Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v2/users/@unknown_person:test/devices" channel = self.make_request( @@ -348,12 +328,12 @@ def test_user_does_not_exist(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - def test_user_is_not_local(self) -> None: + def test_user_is_not_local(self): """ - Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST + Tests that a lookup for a user that is not a local returns a 400 """ url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices" @@ -363,10 +343,10 @@ def test_user_is_not_local(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) - def test_user_has_no_devices(self) -> None: + def test_user_has_no_devices(self): """ Tests that a normal lookup for devices is successfully if user has no devices @@ -379,11 +359,11 @@ def test_user_has_no_devices(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["devices"])) - def test_get_devices(self) -> None: + def test_get_devices(self): """ Tests that a normal lookup for devices is successfully """ @@ -399,7 +379,7 @@ def test_get_devices(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(number_devices, channel.json_body["total"]) self.assertEqual(number_devices, len(channel.json_body["devices"])) self.assertEqual(self.other_user, channel.json_body["devices"][0]["user_id"]) @@ -419,7 +399,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + def prepare(self, reactor, clock, hs): self.handler = hs.get_device_handler() self.admin_user = self.register_user("admin", "pass", admin=True) @@ -431,20 +411,16 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.other_user ) - def test_no_auth(self) -> None: + def test_no_auth(self): """ Try to delete devices of an user without authentication. 
""" channel = self.make_request("POST", self.url, b"{}") - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self) -> None: + def test_requester_is_no_admin(self): """ If the user is not a server admin, an error is returned. """ @@ -456,16 +432,12 @@ def test_requester_is_no_admin(self) -> None: access_token=other_user_token, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_user_does_not_exist(self) -> None: + def test_user_does_not_exist(self): """ - Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND + Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v2/users/@unknown_person:test/delete_devices" channel = self.make_request( @@ -474,12 +446,12 @@ def test_user_does_not_exist(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - def test_user_is_not_local(self) -> None: + def test_user_is_not_local(self): """ - Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST + Tests that a lookup for a user that is not a local returns a 400 """ url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/delete_devices" @@ -489,12 +461,12 @@ def test_user_is_not_local(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) - def test_unknown_devices(self) -> None: + def test_unknown_devices(self): """ - Tests that a remove of a device that does not exist returns HTTPStatus.OK. + Tests that a remove of a device that does not exist returns 200. """ channel = self.make_request( "POST", @@ -503,10 +475,10 @@ def test_unknown_devices(self) -> None: content={"devices": ["unknown_device1", "unknown_device2"]}, ) - # Delete unknown devices returns status HTTPStatus.OK - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + # Delete unknown devices returns status 200 + self.assertEqual(200, channel.code, msg=channel.json_body) - def test_delete_devices(self) -> None: + def test_delete_devices(self): """ Tests that a remove of devices is successfully """ @@ -533,7 +505,7 @@ def test_delete_devices(self) -> None: content={"devices": device_ids}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) res = self.get_success(self.handler.get_devices_by_user(self.other_user)) self.assertEqual(0, len(res)) diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py index 4f89f8b534ff..e9ef89731ffe 100644 --- a/tests/rest/admin/test_event_reports.py +++ b/tests/rest/admin/test_event_reports.py @@ -11,17 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -from http import HTTPStatus -from typing import List -from twisted.test.proto_helpers import MemoryReactor +import json import synapse.rest.admin from synapse.api.errors import Codes from synapse.rest.client import login, report_event, room -from synapse.server import HomeServer -from synapse.types import JsonDict -from synapse.util import Clock from tests import unittest @@ -34,7 +29,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): report_event.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + def prepare(self, reactor, clock, hs): self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -75,22 +70,18 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.url = "/_synapse/admin/v1/event_reports" - def test_no_auth(self) -> None: + def test_no_auth(self): """ Try to get an event report without authentication. """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self) -> None: + def test_requester_is_no_admin(self): """ - If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. + If the user is not a server admin, an error 403 is returned. """ channel = self.make_request( @@ -99,14 +90,10 @@ def test_requester_is_no_admin(self) -> None: access_token=self.other_user_tok, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_default_success(self) -> None: + def test_default_success(self): """ Testing list of reported events """ @@ -117,13 +104,13 @@ def test_default_success(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 20) self.assertNotIn("next_token", channel.json_body) self._check_fields(channel.json_body["event_reports"]) - def test_limit(self) -> None: + def test_limit(self): """ Testing list of reported events with limit """ @@ -134,13 +121,13 @@ def test_limit(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 5) self.assertEqual(channel.json_body["next_token"], 5) self._check_fields(channel.json_body["event_reports"]) - def test_from(self) -> None: + def test_from(self): """ Testing list of reported events with a defined starting point (from) """ @@ -151,13 +138,13 @@ def test_from(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 
20) self.assertEqual(len(channel.json_body["event_reports"]), 15) self.assertNotIn("next_token", channel.json_body) self._check_fields(channel.json_body["event_reports"]) - def test_limit_and_from(self) -> None: + def test_limit_and_from(self): """ Testing list of reported events with a defined starting point and limit """ @@ -168,13 +155,13 @@ def test_limit_and_from(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(channel.json_body["next_token"], 15) self.assertEqual(len(channel.json_body["event_reports"]), 10) self._check_fields(channel.json_body["event_reports"]) - def test_filter_room(self) -> None: + def test_filter_room(self): """ Testing list of reported events with a filter of room """ @@ -185,7 +172,7 @@ def test_filter_room(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 10) self.assertEqual(len(channel.json_body["event_reports"]), 10) self.assertNotIn("next_token", channel.json_body) @@ -194,7 +181,7 @@ def test_filter_room(self) -> None: for report in channel.json_body["event_reports"]: self.assertEqual(report["room_id"], self.room_id1) - def test_filter_user(self) -> None: + def test_filter_user(self): """ Testing list of reported events with a filter of user """ @@ -205,7 +192,7 @@ def test_filter_user(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 10) self.assertEqual(len(channel.json_body["event_reports"]), 10) self.assertNotIn("next_token", channel.json_body) @@ -214,7 +201,7 @@ def test_filter_user(self) -> None: for report in channel.json_body["event_reports"]: self.assertEqual(report["user_id"], self.other_user) - def test_filter_user_and_room(self) -> None: + def test_filter_user_and_room(self): """ Testing list of reported events with a filter of user and room """ @@ -225,7 +212,7 @@ def test_filter_user_and_room(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 5) self.assertEqual(len(channel.json_body["event_reports"]), 5) self.assertNotIn("next_token", channel.json_body) @@ -235,7 +222,7 @@ def test_filter_user_and_room(self) -> None: self.assertEqual(report["user_id"], self.other_user) self.assertEqual(report["room_id"], self.room_id1) - def test_valid_search_order(self) -> None: + def test_valid_search_order(self): """ Testing search order. Order by timestamps. 
""" @@ -247,7 +234,7 @@ def test_valid_search_order(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 20) report = 1 @@ -265,7 +252,7 @@ def test_valid_search_order(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 20) report = 1 @@ -276,9 +263,9 @@ def test_valid_search_order(self) -> None: ) report += 1 - def test_invalid_search_order(self) -> None: + def test_invalid_search_order(self): """ - Testing that a invalid search order returns a HTTPStatus.BAD_REQUEST + Testing that a invalid search order returns a 400 """ channel = self.make_request( @@ -287,17 +274,13 @@ def test_invalid_search_order(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual("Unknown direction: bar", channel.json_body["error"]) - def test_limit_is_negative(self) -> None: + def test_limit_is_negative(self): """ - Testing that a negative limit parameter returns a HTTPStatus.BAD_REQUEST + Testing that a negative limit parameter returns a 400 """ channel = self.make_request( @@ -306,16 +289,12 @@ def test_limit_is_negative(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) - def test_from_is_negative(self) -> None: + def test_from_is_negative(self): """ - Testing that a negative from parameter returns a HTTPStatus.BAD_REQUEST + Testing that a negative from parameter returns a 400 """ channel = self.make_request( @@ -324,14 +303,10 @@ def test_from_is_negative(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) - def test_next_token(self) -> None: + def test_next_token(self): """ Testing that `next_token` appears at the right place """ @@ -344,7 +319,7 @@ def test_next_token(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 20) self.assertNotIn("next_token", channel.json_body) @@ -357,7 +332,7 @@ def test_next_token(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 20) 
self.assertNotIn("next_token", channel.json_body) @@ -370,7 +345,7 @@ def test_next_token(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 19) self.assertEqual(channel.json_body["next_token"], 19) @@ -384,12 +359,12 @@ def test_next_token(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 1) self.assertNotIn("next_token", channel.json_body) - def _create_event_and_report(self, room_id: str, user_tok: str) -> None: + def _create_event_and_report(self, room_id, user_tok): """Create and report events""" resp = self.helper.send(room_id, tok=user_tok) event_id = resp["event_id"] @@ -397,14 +372,12 @@ def _create_event_and_report(self, room_id: str, user_tok: str) -> None: channel = self.make_request( "POST", "rooms/%s/report/%s" % (room_id, event_id), - {"score": -100, "reason": "this makes me sad"}, + json.dumps({"score": -100, "reason": "this makes me sad"}), access_token=user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - def _create_event_and_report_without_parameters( - self, room_id: str, user_tok: str - ) -> None: + def _create_event_and_report_without_parameters(self, room_id, user_tok): """Create and report an event, but omit reason and score""" resp = self.helper.send(room_id, tok=user_tok) event_id = resp["event_id"] @@ -412,12 +385,12 @@ def _create_event_and_report_without_parameters( channel = self.make_request( "POST", "rooms/%s/report/%s" % (room_id, event_id), - {}, + json.dumps({}), access_token=user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - def _check_fields(self, content: List[JsonDict]) -> None: + def _check_fields(self, content): """Checks that all attributes are present in an event report""" for c in content: self.assertIn("id", c) @@ -440,7 +413,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): report_event.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + def prepare(self, reactor, clock, hs): self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -460,22 +433,18 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # first created event report gets `id`=2 self.url = "/_synapse/admin/v1/event_reports/2" - def test_no_auth(self) -> None: + def test_no_auth(self): """ Try to get event report without authentication. 
""" channel = self.make_request("GET", self.url, b"{}") - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self) -> None: + def test_requester_is_no_admin(self): """ - If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. + If the user is not a server admin, an error 403 is returned. """ channel = self.make_request( @@ -484,14 +453,10 @@ def test_requester_is_no_admin(self) -> None: access_token=self.other_user_tok, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_default_success(self) -> None: + def test_default_success(self): """ Testing get a reported event """ @@ -502,12 +467,12 @@ def test_default_success(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self._check_fields(channel.json_body) - def test_invalid_report_id(self) -> None: + def test_invalid_report_id(self): """ - Testing that an invalid `report_id` returns a HTTPStatus.BAD_REQUEST. + Testing that an invalid `report_id` returns a 400. """ # `report_id` is negative @@ -517,11 +482,7 @@ def test_invalid_report_id(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "The report_id parameter must be a string representing a positive integer.", @@ -535,11 +496,7 @@ def test_invalid_report_id(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "The report_id parameter must be a string representing a positive integer.", @@ -553,20 +510,16 @@ def test_invalid_report_id(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "The report_id parameter must be a string representing a positive integer.", channel.json_body["error"], ) - def test_report_id_not_found(self) -> None: + def test_report_id_not_found(self): """ - Testing that a not existing `report_id` returns a HTTPStatus.NOT_FOUND. + Testing that a not existing `report_id` returns a 404. 
""" channel = self.make_request( @@ -575,15 +528,11 @@ def test_report_id_not_found(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.NOT_FOUND, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) self.assertEqual("Event report not found", channel.json_body["error"]) - def _create_event_and_report(self, room_id: str, user_tok: str) -> None: + def _create_event_and_report(self, room_id, user_tok): """Create and report events""" resp = self.helper.send(room_id, tok=user_tok) event_id = resp["event_id"] @@ -591,12 +540,12 @@ def _create_event_and_report(self, room_id: str, user_tok: str) -> None: channel = self.make_request( "POST", "rooms/%s/report/%s" % (room_id, event_id), - {"score": -100, "reason": "this makes me sad"}, + json.dumps({"score": -100, "reason": "this makes me sad"}), access_token=user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - def _check_fields(self, content: JsonDict) -> None: + def _check_fields(self, content): """Checks that all attributes are present in a event report""" self.assertIn("id", content) self.assertIn("received_ts", content) diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py deleted file mode 100644 index 5188499ef2d6..000000000000 --- a/tests/rest/admin/test_federation.py +++ /dev/null @@ -1,456 +0,0 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from http import HTTPStatus -from typing import List, Optional - -from parameterized import parameterized - -import synapse.rest.admin -from synapse.api.errors import Codes -from synapse.rest.client import login -from synapse.server import HomeServer -from synapse.types import JsonDict - -from tests import unittest - - -class FederationTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - ] - - def prepare(self, reactor, clock, hs: HomeServer): - self.store = hs.get_datastore() - self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") - - self.url = "/_synapse/admin/v1/federation/destinations" - - @parameterized.expand( - [ - ("/_synapse/admin/v1/federation/destinations",), - ("/_synapse/admin/v1/federation/destinations/dummy",), - ] - ) - def test_requester_is_no_admin(self, url: str): - """ - If the user is not a server admin, an error 403 is returned. 
- """ - - self.register_user("user", "pass", admin=False) - other_user_tok = self.login("user", "pass") - - channel = self.make_request( - "GET", - url, - content={}, - access_token=other_user_tok, - ) - - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) - self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - - def test_invalid_parameter(self): - """ - If parameters are invalid, an error is returned. - """ - - # negative limit - channel = self.make_request( - "GET", - self.url + "?limit=-5", - access_token=self.admin_user_tok, - ) - - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) - self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) - - # negative from - channel = self.make_request( - "GET", - self.url + "?from=-5", - access_token=self.admin_user_tok, - ) - - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) - self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) - - # unkown order_by - channel = self.make_request( - "GET", - self.url + "?order_by=bar", - access_token=self.admin_user_tok, - ) - - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) - self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) - - # invalid search order - channel = self.make_request( - "GET", - self.url + "?dir=bar", - access_token=self.admin_user_tok, - ) - - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) - self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) - - # invalid destination - channel = self.make_request( - "GET", - self.url + "/dummy", - access_token=self.admin_user_tok, - ) - - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) - self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - - def test_limit(self): - """ - Testing list of destinations with limit - """ - - number_destinations = 20 - self._create_destinations(number_destinations) - - channel = self.make_request( - "GET", - self.url + "?limit=5", - access_token=self.admin_user_tok, - ) - - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - self.assertEqual(channel.json_body["total"], number_destinations) - self.assertEqual(len(channel.json_body["destinations"]), 5) - self.assertEqual(channel.json_body["next_token"], "5") - self._check_fields(channel.json_body["destinations"]) - - def test_from(self): - """ - Testing list of destinations with a defined starting point (from) - """ - - number_destinations = 20 - self._create_destinations(number_destinations) - - channel = self.make_request( - "GET", - self.url + "?from=5", - access_token=self.admin_user_tok, - ) - - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - self.assertEqual(channel.json_body["total"], number_destinations) - self.assertEqual(len(channel.json_body["destinations"]), 15) - self.assertNotIn("next_token", channel.json_body) - self._check_fields(channel.json_body["destinations"]) - - def test_limit_and_from(self): - """ - Testing list of destinations with a defined starting point and limit - """ - - number_destinations = 20 - self._create_destinations(number_destinations) - - channel = self.make_request( - "GET", - self.url + "?from=5&limit=10", - access_token=self.admin_user_tok, - ) - - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - self.assertEqual(channel.json_body["total"], number_destinations) - self.assertEqual(channel.json_body["next_token"], "15") - 
self.assertEqual(len(channel.json_body["destinations"]), 10) - self._check_fields(channel.json_body["destinations"]) - - def test_next_token(self): - """ - Testing that `next_token` appears at the right place - """ - - number_destinations = 20 - self._create_destinations(number_destinations) - - # `next_token` does not appear - # Number of results is the number of entries - channel = self.make_request( - "GET", - self.url + "?limit=20", - access_token=self.admin_user_tok, - ) - - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - self.assertEqual(channel.json_body["total"], number_destinations) - self.assertEqual(len(channel.json_body["destinations"]), number_destinations) - self.assertNotIn("next_token", channel.json_body) - - # `next_token` does not appear - # Number of max results is larger than the number of entries - channel = self.make_request( - "GET", - self.url + "?limit=21", - access_token=self.admin_user_tok, - ) - - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - self.assertEqual(channel.json_body["total"], number_destinations) - self.assertEqual(len(channel.json_body["destinations"]), number_destinations) - self.assertNotIn("next_token", channel.json_body) - - # `next_token` does appear - # Number of max results is smaller than the number of entries - channel = self.make_request( - "GET", - self.url + "?limit=19", - access_token=self.admin_user_tok, - ) - - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - self.assertEqual(channel.json_body["total"], number_destinations) - self.assertEqual(len(channel.json_body["destinations"]), 19) - self.assertEqual(channel.json_body["next_token"], "19") - - # Check - # Set `from` to value of `next_token` for request remaining entries - # `next_token` does not appear - channel = self.make_request( - "GET", - self.url + "?from=19", - access_token=self.admin_user_tok, - ) - - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - self.assertEqual(channel.json_body["total"], number_destinations) - self.assertEqual(len(channel.json_body["destinations"]), 1) - self.assertNotIn("next_token", channel.json_body) - - def test_list_all_destinations(self): - """ - List all destinations. - """ - number_destinations = 5 - self._create_destinations(number_destinations) - - channel = self.make_request( - "GET", - self.url, - {}, - access_token=self.admin_user_tok, - ) - - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - self.assertEqual(number_destinations, len(channel.json_body["destinations"])) - self.assertEqual(number_destinations, channel.json_body["total"]) - - # Check that all fields are available - self._check_fields(channel.json_body["destinations"]) - - def test_order_by(self): - """ - Testing order list with parameter `order_by` - """ - - def _order_test( - expected_destination_list: List[str], - order_by: Optional[str], - dir: Optional[str] = None, - ): - """Request the list of destinations in a certain order. - Assert that order is what we expect - - Args: - expected_destination_list: The list of user_id in the order - we expect to get back from the server - order_by: The type of ordering to give the server - dir: The direction of ordering to give the server - """ - - url = f"{self.url}?" 
- if order_by is not None: - url += f"order_by={order_by}&" - if dir is not None and dir in ("b", "f"): - url += f"dir={dir}" - channel = self.make_request( - "GET", - url, - access_token=self.admin_user_tok, - ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - self.assertEqual(channel.json_body["total"], len(expected_destination_list)) - - returned_order = [ - row["destination"] for row in channel.json_body["destinations"] - ] - self.assertEqual(expected_destination_list, returned_order) - self._check_fields(channel.json_body["destinations"]) - - # create destinations - dest = [ - ("sub-a.example.com", 100, 300, 200, 300), - ("sub-b.example.com", 200, 200, 100, 100), - ("sub-c.example.com", 300, 100, 300, 200), - ] - for ( - destination, - failure_ts, - retry_last_ts, - retry_interval, - last_successful_stream_ordering, - ) in dest: - self.get_success( - self.store.set_destination_retry_timings( - destination, failure_ts, retry_last_ts, retry_interval - ) - ) - self.get_success( - self.store.set_destination_last_successful_stream_ordering( - destination, last_successful_stream_ordering - ) - ) - - # order by default (destination) - _order_test([dest[0][0], dest[1][0], dest[2][0]], None) - _order_test([dest[0][0], dest[1][0], dest[2][0]], None, "f") - _order_test([dest[2][0], dest[1][0], dest[0][0]], None, "b") - - # order by destination - _order_test([dest[0][0], dest[1][0], dest[2][0]], "destination") - _order_test([dest[0][0], dest[1][0], dest[2][0]], "destination", "f") - _order_test([dest[2][0], dest[1][0], dest[0][0]], "destination", "b") - - # order by failure_ts - _order_test([dest[0][0], dest[1][0], dest[2][0]], "failure_ts") - _order_test([dest[0][0], dest[1][0], dest[2][0]], "failure_ts", "f") - _order_test([dest[2][0], dest[1][0], dest[0][0]], "failure_ts", "b") - - # order by retry_last_ts - _order_test([dest[2][0], dest[1][0], dest[0][0]], "retry_last_ts") - _order_test([dest[2][0], dest[1][0], dest[0][0]], "retry_last_ts", "f") - _order_test([dest[0][0], dest[1][0], dest[2][0]], "retry_last_ts", "b") - - # order by retry_interval - _order_test([dest[1][0], dest[0][0], dest[2][0]], "retry_interval") - _order_test([dest[1][0], dest[0][0], dest[2][0]], "retry_interval", "f") - _order_test([dest[2][0], dest[0][0], dest[1][0]], "retry_interval", "b") - - # order by last_successful_stream_ordering - _order_test( - [dest[1][0], dest[2][0], dest[0][0]], "last_successful_stream_ordering" - ) - _order_test( - [dest[1][0], dest[2][0], dest[0][0]], "last_successful_stream_ordering", "f" - ) - _order_test( - [dest[0][0], dest[2][0], dest[1][0]], "last_successful_stream_ordering", "b" - ) - - def test_search_term(self): - """Test that searching for a destination works correctly""" - - def _search_test( - expected_destination: Optional[str], - search_term: str, - ): - """Search for a destination and check that the returned destinationis a match - - Args: - expected_destination: The room_id expected to be returned by the API. 
- Set to None to expect zero results for the search - search_term: The term to search for room names with - """ - url = f"{self.url}?destination={search_term}" - channel = self.make_request( - "GET", - url.encode("ascii"), - access_token=self.admin_user_tok, - ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - - # Check that destinations were returned - self.assertTrue("destinations" in channel.json_body) - self._check_fields(channel.json_body["destinations"]) - destinations = channel.json_body["destinations"] - - # Check that the expected number of destinations were returned - expected_destination_count = 1 if expected_destination else 0 - self.assertEqual(len(destinations), expected_destination_count) - self.assertEqual(channel.json_body["total"], expected_destination_count) - - if expected_destination: - # Check that the first returned destination is correct - self.assertEqual(expected_destination, destinations[0]["destination"]) - - number_destinations = 3 - self._create_destinations(number_destinations) - - # Test searching - _search_test("sub0.example.com", "0") - _search_test("sub0.example.com", "sub0") - - _search_test("sub1.example.com", "1") - _search_test("sub1.example.com", "1.") - - # Test case insensitive - _search_test("sub0.example.com", "SUB0") - - _search_test(None, "foo") - _search_test(None, "bar") - - def test_get_single_destination(self): - """ - Get one specific destinations. - """ - self._create_destinations(5) - - channel = self.make_request( - "GET", - self.url + "/sub0.example.com", - access_token=self.admin_user_tok, - ) - - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - self.assertEqual("sub0.example.com", channel.json_body["destination"]) - - # Check that all fields are available - # convert channel.json_body into a List - self._check_fields([channel.json_body]) - - def _create_destinations(self, number_destinations: int): - """Create a number of destinations - - Args: - number_destinations: Number of destinations to be created - """ - for i in range(0, number_destinations): - dest = f"sub{i}.example.com" - self.get_success(self.store.set_destination_retry_timings(dest, 50, 50, 50)) - self.get_success( - self.store.set_destination_last_successful_stream_ordering(dest, 100) - ) - - def _check_fields(self, content: List[JsonDict]): - """Checks that the expected destination attributes are present in content - - Args: - content: List that is checked for content - """ - for c in content: - self.assertIn("destination", c) - self.assertIn("retry_last_ts", c) - self.assertIn("retry_interval", c) - self.assertIn("failure_ts", c) - self.assertIn("last_successful_stream_ordering", c) diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index 81e578fd26c1..db0e78c03995 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -12,19 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ +import json import os -from http import HTTPStatus from parameterized import parameterized -from twisted.test.proto_helpers import MemoryReactor - import synapse.rest.admin from synapse.api.errors import Codes from synapse.rest.client import login, profile, room from synapse.rest.media.v1.filepath import MediaFilePaths -from synapse.server import HomeServer -from synapse.util import Clock from tests import unittest from tests.server import FakeSite, make_request @@ -42,7 +39,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + def prepare(self, reactor, clock, hs): self.media_repo = hs.get_media_repository_resource() self.server_name = hs.hostname @@ -51,7 +48,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.filepaths = MediaFilePaths(hs.config.media.media_store_path) - def test_no_auth(self) -> None: + def test_no_auth(self): """ Try to delete media without authentication. """ url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345") channel = self.make_request("DELETE", url, b"{}") - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self) -> None: + def test_requester_is_no_admin(self): """ If the user is not a server admin, an error is returned. """ self.other_user = self.register_user("user", "pass") self.other_user_token = self.login("user", "pass") @@ -81,16 +74,12 @@ def test_requester_is_no_admin(self) -> None: access_token=self.other_user_token, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_media_does_not_exist(self) -> None: + def test_media_does_not_exist(self): """ - Tests that a lookup for a media that does not exist returns a HTTPStatus.NOT_FOUND + Tests that a lookup for a media that does not exist returns a 404 """ url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345") channel = self.make_request( @@ -100,12 +89,12 @@ def test_media_does_not_exist(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - def test_media_is_not_local(self) -> None: + def test_media_is_not_local(self): """ - Tests that a lookup for a media that is not a local returns a HTTPStatus.BAD_REQUEST + Tests that a lookup for media that is not local returns a 400 """ url = "/_synapse/admin/v1/media/%s/%s" % ("unknown_domain", "12345") channel = self.make_request( @@ -115,10 +104,10 @@ def test_media_is_not_local(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only delete local media", channel.json_body["error"]) - def test_delete_media(self) -> None: + def test_delete_media(self): """ Tests that delete a media is successfully """ download_resource = self.media_repo.children[b"download"] upload_resource = self.media_repo.children[b"upload"] # Upload some media into the room response = self.helper.upload_media( - upload_resource, - SMALL_PNG, - tok=self.admin_user_tok, - expect_code=HTTPStatus.OK, + upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200
) # Extract media ID from the response server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' @@ -151,11 +137,10 @@ def test_delete_media(self) -> None: # Should be successful self.assertEqual( - HTTPStatus.OK, + 200, channel.code, msg=( - "Expected to receive a HTTPStatus.OK on accessing media: %s" - % server_and_media_id + "Expected to receive a 200 on accessing media: %s" % server_and_media_id ), ) @@ -172,7 +157,7 @@ def test_delete_media(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( media_id, @@ -189,10 +174,10 @@ def test_delete_media(self) -> None: access_token=self.admin_user_tok, ) self.assertEqual( - HTTPStatus.NOT_FOUND, + 404, channel.code, msg=( - "Expected to receive a HTTPStatus.NOT_FOUND on accessing deleted media: %s" + "Expected to receive a 404 on accessing deleted media: %s" % server_and_media_id ), ) @@ -211,7 +196,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + def prepare(self, reactor, clock, hs): self.media_repo = hs.get_media_repository_resource() self.server_name = hs.hostname @@ -224,21 +209,17 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # Move clock up to somewhat realistic time self.reactor.advance(1000000000) - def test_no_auth(self) -> None: + def test_no_auth(self): """ Try to delete media without authentication. """ channel = self.make_request("POST", self.url, b"{}") - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self) -> None: + def test_requester_is_no_admin(self): """ If the user is not a server admin, an error is returned. """ @@ -251,16 +232,12 @@ def test_requester_is_no_admin(self) -> None: access_token=self.other_user_token, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_media_is_not_local(self) -> None: + def test_media_is_not_local(self): """ - Tests that a lookup for media that is not local returns a HTTPStatus.BAD_REQUEST + Tests that a lookup for media that is not local returns a 400 """ url = "/_synapse/admin/v1/media/%s/delete" % "unknown_domain" @@ -270,10 +247,10 @@ def test_media_is_not_local(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only delete local media", channel.json_body["error"]) - def test_missing_parameter(self) -> None: + def test_missing_parameter(self): """ If the parameter `before_ts` is missing, an error is returned. 
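The delete-by-ID tests above reduce to a single admin call once media has been uploaded. A minimal sketch, with response fields matching the assertions above (the `deleted_media` key name is an assumption where the hunk is truncated):

    url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, media_id)
    channel = self.make_request(
        "DELETE",
        url,
        access_token=self.admin_user_tok,
    )
    self.assertEqual(200, channel.code, msg=channel.json_body)
    # One item removed; the response echoes which media IDs were deleted.
    self.assertEqual(1, channel.json_body["total"])
    self.assertEqual(media_id, channel.json_body["deleted_media"][0])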
""" @@ -283,17 +260,13 @@ def test_missing_parameter(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) self.assertEqual( "Missing integer query parameter 'before_ts'", channel.json_body["error"] ) - def test_invalid_parameter(self) -> None: + def test_invalid_parameter(self): """ If parameters are invalid, an error is returned. """ @@ -303,11 +276,7 @@ def test_invalid_parameter(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "Query parameter before_ts must be a positive integer.", @@ -320,11 +289,7 @@ def test_invalid_parameter(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "Query parameter before_ts you provided is from the year 1970. " @@ -338,11 +303,7 @@ def test_invalid_parameter(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "Query parameter size_gt must be a string representing a positive integer.", @@ -355,18 +316,14 @@ def test_invalid_parameter(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) self.assertEqual( "Boolean query parameter 'keep_profiles' must be one of ['true', 'false']", channel.json_body["error"], ) - def test_delete_media_never_accessed(self) -> None: + def test_delete_media_never_accessed(self): """ Tests that media deleted if it is older than `before_ts` and never accessed `last_access_ts` is `NULL` and `created_ts` < `before_ts` @@ -388,7 +345,7 @@ def test_delete_media_never_accessed(self) -> None: self.url + "?before_ts=" + str(now_ms), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( media_id, @@ -397,7 +354,7 @@ def test_delete_media_never_accessed(self) -> None: self._access_media(server_and_media_id, False) - def test_keep_media_by_date(self) -> None: + def test_keep_media_by_date(self): """ Tests that media is not deleted if it is newer than `before_ts` """ @@ -413,7 +370,7 @@ def test_keep_media_by_date(self) -> None: self.url + "?before_ts=" + str(now_ms), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self._access_media(server_and_media_id) @@ -425,7 +382,7 @@ def 
test_keep_media_by_date(self) -> None: self.url + "?before_ts=" + str(now_ms), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( server_and_media_id.split("/")[1], @@ -434,7 +391,7 @@ def test_keep_media_by_date(self) -> None: self._access_media(server_and_media_id, False) - def test_keep_media_by_size(self) -> None: + def test_keep_media_by_size(self): """ Tests that media is not deleted if its size is smaller than or equal to `size_gt` @@ -449,7 +406,7 @@ def test_keep_media_by_size(self) -> None: self.url + "?before_ts=" + str(now_ms) + "&size_gt=67", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self._access_media(server_and_media_id) @@ -460,7 +417,7 @@ def test_keep_media_by_size(self) -> None: self.url + "?before_ts=" + str(now_ms) + "&size_gt=66", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( server_and_media_id.split("/")[1], @@ -469,7 +426,7 @@ def test_keep_media_by_size(self) -> None: self._access_media(server_and_media_id, False) - def test_keep_media_by_user_avatar(self) -> None: + def test_keep_media_by_user_avatar(self): """ Tests that we do not delete media if is used as a user avatar Tests parameter `keep_profiles` @@ -482,10 +439,10 @@ def test_keep_media_by_user_avatar(self) -> None: channel = self.make_request( "PUT", "/profile/%s/avatar_url" % (self.admin_user,), - content={"avatar_url": "mxc://%s" % (server_and_media_id,)}, + content=json.dumps({"avatar_url": "mxc://%s" % (server_and_media_id,)}), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) now_ms = self.clock.time_msec() channel = self.make_request( @@ -493,7 +450,7 @@ def test_keep_media_by_user_avatar(self) -> None: self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self._access_media(server_and_media_id) @@ -504,7 +461,7 @@ def test_keep_media_by_user_avatar(self) -> None: self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( server_and_media_id.split("/")[1], @@ -513,7 +470,7 @@ def test_keep_media_by_user_avatar(self) -> None: self._access_media(server_and_media_id, False) - def test_keep_media_by_room_avatar(self) -> None: + def test_keep_media_by_room_avatar(self): """ Tests that we do not delete media if it is used as a room avatar Tests parameter `keep_profiles` @@ -527,10 +484,10 @@ def test_keep_media_by_room_avatar(self) -> None: channel = self.make_request( "PUT", "/rooms/%s/state/m.room.avatar" % (room_id,), - content={"url": "mxc://%s" % (server_and_media_id,)}, + 
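A pattern that recurs through the remainder of the patch: the newer code passed a plain dict as `content` and let `make_request` serialise it, while the restored code pre-serialises with `json.dumps`. Both forms as a sketch:

    # Newer form (being reverted): hand over the dict directly.
    content={"avatar_url": "mxc://%s" % (server_and_media_id,)}

    # Restored form: serialise by hand, which is why `import json` returns
    # to the top of these files.
    content=json.dumps({"avatar_url": "mxc://%s" % (server_and_media_id,)})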
content=json.dumps({"url": "mxc://%s" % (server_and_media_id,)}), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) now_ms = self.clock.time_msec() channel = self.make_request( @@ -538,7 +495,7 @@ def test_keep_media_by_room_avatar(self) -> None: self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self._access_media(server_and_media_id) @@ -549,7 +506,7 @@ def test_keep_media_by_room_avatar(self) -> None: self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( server_and_media_id.split("/")[1], @@ -558,7 +515,7 @@ def test_keep_media_by_room_avatar(self) -> None: self._access_media(server_and_media_id, False) - def _create_media(self) -> str: + def _create_media(self): """ Create a media and return media_id and server_and_media_id """ @@ -566,10 +523,7 @@ def _create_media(self) -> str: # Upload some media into the room response = self.helper.upload_media( - upload_resource, - SMALL_PNG, - tok=self.admin_user_tok, - expect_code=HTTPStatus.OK, + upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200 ) # Extract media ID from the response server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' @@ -580,7 +534,7 @@ def _create_media(self) -> str: return server_and_media_id - def _access_media(self, server_and_media_id, expect_success=True) -> None: + def _access_media(self, server_and_media_id, expect_success=True): """ Try to access a media and check the result """ @@ -600,10 +554,10 @@ def _access_media(self, server_and_media_id, expect_success=True) -> None: if expect_success: self.assertEqual( - HTTPStatus.OK, + 200, channel.code, msg=( - "Expected to receive a HTTPStatus.OK on accessing media: %s" + "Expected to receive a 200 on accessing media: %s" % server_and_media_id ), ) @@ -611,10 +565,10 @@ def _access_media(self, server_and_media_id, expect_success=True) -> None: self.assertTrue(os.path.exists(local_path)) else: self.assertEqual( - HTTPStatus.NOT_FOUND, + 404, channel.code, msg=( - "Expected to receive a HTTPStatus.NOT_FOUND on accessing deleted media: %s" + "Expected to receive a 404 on accessing deleted media: %s" % (server_and_media_id) ), ) @@ -630,7 +584,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + def prepare(self, reactor, clock, hs): media_repo = hs.get_media_repository_resource() self.store = hs.get_datastore() self.server_name = hs.hostname @@ -643,10 +597,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # Upload some media into the room response = self.helper.upload_media( - upload_resource, - SMALL_PNG, - tok=self.admin_user_tok, - expect_code=HTTPStatus.OK, + upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200 ) # Extract media ID from the response server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' @@ -655,7 +606,7 @@ def prepare(self, reactor: MemoryReactor, 
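`_access_media` above verifies deletion on disk as well as over HTTP. A minimal sketch of the filesystem side, assuming `MediaFilePaths.local_media_filepath` as the path helper (this module constructs `MediaFilePaths` in `prepare`):

    media_id = server_and_media_id.split("/")[1]
    local_path = self.filepaths.local_media_filepath(media_id)
    if expect_success:
        self.assertTrue(os.path.exists(local_path))   # still present on disk
    else:
        self.assertFalse(os.path.exists(local_path))  # purged from disk too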
clock: Clock, hs: HomeServer) -> None: self.url = "/_synapse/admin/v1/media/%s/%s/%s" @parameterized.expand(["quarantine", "unquarantine"]) - def test_no_auth(self, action: str) -> None: + def test_no_auth(self, action: str): """ Try to protect media without authentication. """ @@ -666,15 +617,11 @@ def test_no_auth(self, action: str) -> None: b"{}", ) - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["quarantine", "unquarantine"]) - def test_requester_is_no_admin(self, action: str) -> None: + def test_requester_is_no_admin(self, action: str): """ If the user is not a server admin, an error is returned. """ @@ -687,14 +634,10 @@ def test_requester_is_no_admin(self, action: str) -> None: access_token=self.other_user_token, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_quarantine_media(self) -> None: + def test_quarantine_media(self): """ Tests that quarantining and remove from quarantine a media is successfully """ @@ -709,7 +652,7 @@ def test_quarantine_media(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertFalse(channel.json_body) media_info = self.get_success(self.store.get_local_media(self.media_id)) @@ -722,13 +665,13 @@ def test_quarantine_media(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertFalse(channel.json_body) media_info = self.get_success(self.store.get_local_media(self.media_id)) self.assertFalse(media_info["quarantined_by"]) - def test_quarantine_protected_media(self) -> None: + def test_quarantine_protected_media(self): """ Tests that quarantining from protected media fails """ @@ -747,7 +690,7 @@ def test_quarantine_protected_media(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertFalse(channel.json_body) # verify that is not in quarantine @@ -763,7 +706,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + def prepare(self, reactor, clock, hs): media_repo = hs.get_media_repository_resource() self.store = hs.get_datastore() @@ -775,10 +718,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # Upload some media into the room response = self.helper.upload_media( - upload_resource, - SMALL_PNG, - tok=self.admin_user_tok, - expect_code=HTTPStatus.OK, + upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200 ) # Extract media ID from the response server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' @@ -787,22 +727,18 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.url = "/_synapse/admin/v1/media/%s/%s" @parameterized.expand(["protect", "unprotect"]) - def test_no_auth(self, action: str) -> None: + def test_no_auth(self, 
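`parameterized.expand` lets the quarantine and unquarantine actions (and, further down, protect and unprotect) share one test body; the decorator generates one test per string. The pattern, lifted from the hunk above:

    @parameterized.expand(["quarantine", "unquarantine"])
    def test_no_auth(self, action: str):
        # Runs twice, once per action string substituted into the URL.
        channel = self.make_request(
            "POST",
            self.url % (action, self.server_name, self.media_id),
            b"{}",
        )
        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])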
action: str): """ Try to protect media without authentication. """ channel = self.make_request("POST", self.url % (action, self.media_id), b"{}") - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["protect", "unprotect"]) - def test_requester_is_no_admin(self, action: str) -> None: + def test_requester_is_no_admin(self, action: str): """ If the user is not a server admin, an error is returned. """ @@ -815,14 +751,10 @@ def test_requester_is_no_admin(self, action: str) -> None: access_token=self.other_user_token, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_protect_media(self) -> None: + def test_protect_media(self): """ Tests that protect and unprotect a media is successfully """ @@ -837,7 +769,7 @@ def test_protect_media(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertFalse(channel.json_body) media_info = self.get_success(self.store.get_local_media(self.media_id)) @@ -850,7 +782,7 @@ def test_protect_media(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertFalse(channel.json_body) media_info = self.get_success(self.store.get_local_media(self.media_id)) @@ -867,7 +799,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + def prepare(self, reactor, clock, hs): self.media_repo = hs.get_media_repository_resource() self.server_name = hs.hostname @@ -877,21 +809,17 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.filepaths = MediaFilePaths(hs.config.media.media_store_path) self.url = "/_synapse/admin/v1/purge_media_cache" - def test_no_auth(self) -> None: + def test_no_auth(self): """ Try to delete media without authentication. """ channel = self.make_request("POST", self.url, b"{}") - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_not_admin(self) -> None: + def test_requester_is_not_admin(self): """ If the user is not a server admin, an error is returned. """ @@ -904,14 +832,10 @@ def test_requester_is_not_admin(self) -> None: access_token=self.other_user_token, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_invalid_parameter(self) -> None: + def test_invalid_parameter(self): """ If parameters are invalid, an error is returned. 
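Quarantine state is recorded on the media row itself, which is why these tests read back `get_local_media` after each POST. A sketch of the round trip (assuming `quarantined_by` holds the quarantining admin, per the assertions above):

    channel = self.make_request(
        "POST",
        self.url % ("quarantine", self.server_name, self.media_id),
        access_token=self.admin_user_tok,
    )
    self.assertEqual(200, channel.code, msg=channel.json_body)

    media_info = self.get_success(self.store.get_local_media(self.media_id))
    self.assertTrue(media_info["quarantined_by"])  # cleared again on unquarantine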
""" @@ -921,11 +845,7 @@ def test_invalid_parameter(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "Query parameter before_ts must be a positive integer.", @@ -938,11 +858,7 @@ def test_invalid_parameter(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "Query parameter before_ts you provided is from the year 1970. " diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py index 350a62dda672..9bac423ae048 100644 --- a/tests/rest/admin/test_registration_tokens.py +++ b/tests/rest/admin/test_registration_tokens.py @@ -11,17 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import random import string -from http import HTTPStatus - -from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.errors import Codes from synapse.rest.client import login -from synapse.server import HomeServer -from synapse.util import Clock from tests import unittest @@ -32,7 +28,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -42,7 +38,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.url = "/_synapse/admin/v1/registration_tokens" - def _new_token(self, **kwargs) -> str: + def _new_token(self, **kwargs): """Helper function to create a token.""" token = kwargs.get( "token", @@ -64,17 +60,13 @@ def _new_token(self, **kwargs) -> str: # CREATION - def test_create_no_auth(self) -> None: + def test_create_no_auth(self): """Try to create a token without authentication.""" channel = self.make_request("POST", self.url + "/new", {}) - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_create_requester_not_admin(self) -> None: + def test_create_requester_not_admin(self): """Try to create a token while not an admin.""" channel = self.make_request( "POST", @@ -82,14 +74,10 @@ def test_create_requester_not_admin(self) -> None: {}, access_token=self.other_user_tok, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_create_using_defaults(self) -> None: + def test_create_using_defaults(self): """Create a token using all the defaults.""" channel = self.make_request( "POST", @@ -98,14 +86,14 @@ def test_create_using_defaults(self) -> None: 
access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(len(channel.json_body["token"]), 16) self.assertIsNone(channel.json_body["uses_allowed"]) self.assertIsNone(channel.json_body["expiry_time"]) self.assertEqual(channel.json_body["pending"], 0) self.assertEqual(channel.json_body["completed"], 0) - def test_create_specifying_fields(self) -> None: + def test_create_specifying_fields(self): """Create a token specifying the value of all fields.""" # As many of the allowed characters as possible with length <= 64 token = "adefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789._~-" @@ -122,14 +110,14 @@ def test_create_specifying_fields(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["token"], token) self.assertEqual(channel.json_body["uses_allowed"], 1) self.assertEqual(channel.json_body["expiry_time"], data["expiry_time"]) self.assertEqual(channel.json_body["pending"], 0) self.assertEqual(channel.json_body["completed"], 0) - def test_create_with_null_value(self) -> None: + def test_create_with_null_value(self): """Create a token specifying unlimited uses and no expiry.""" data = { "uses_allowed": None, @@ -143,14 +131,14 @@ def test_create_with_null_value(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(len(channel.json_body["token"]), 16) self.assertIsNone(channel.json_body["uses_allowed"]) self.assertIsNone(channel.json_body["expiry_time"]) self.assertEqual(channel.json_body["pending"], 0) self.assertEqual(channel.json_body["completed"], 0) - def test_create_token_too_long(self) -> None: + def test_create_token_too_long(self): """Check token longer than 64 chars is invalid.""" data = {"token": "a" * 65} @@ -161,14 +149,10 @@ def test_create_token_too_long(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) - def test_create_token_invalid_chars(self) -> None: + def test_create_token_invalid_chars(self): """Check you can't create token with invalid characters.""" data = { "token": "abc/def", @@ -181,14 +165,10 @@ def test_create_token_invalid_chars(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) - def test_create_token_already_exists(self) -> None: + def test_create_token_already_exists(self): """Check you can't create token that already exists.""" data = { "token": "abcd", @@ -200,7 +180,7 @@ def test_create_token_already_exists(self) -> None: data, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel1.code, msg=channel1.json_body) + self.assertEqual(200, int(channel1.result["code"]), msg=channel1.result["body"]) channel2 = self.make_request( "POST", @@ -208,10 +188,10 @@ def 
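The creation tests above pin down the token format: at most 64 characters from ASCII letters, digits, and `._~-`, and no duplicates. A hypothetical client-side pre-check mirroring those server rules (the helper and regex are illustrative, not part of this module):

    import re

    # Hypothetical mirror of the server-side rules tested above.
    TOKEN_RE = re.compile(r"^[A-Za-z0-9._~-]{1,64}$")

    def is_valid_registration_token(token: str) -> bool:
        return bool(TOKEN_RE.match(token))

    assert is_valid_registration_token("a" * 64)
    assert not is_valid_registration_token("a" * 65)   # too long: server returns 400
    assert not is_valid_registration_token("abc/def")  # '/' not allowed: 400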
test_create_token_already_exists(self) -> None: data, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel2.code, msg=channel2.json_body) + self.assertEqual(400, int(channel2.result["code"]), msg=channel2.result["body"]) self.assertEqual(channel2.json_body["errcode"], Codes.INVALID_PARAM) - def test_create_unable_to_generate_token(self) -> None: + def test_create_unable_to_generate_token(self): """Check right error is raised when server can't generate unique token.""" # Create all possible single character tokens tokens = [] @@ -240,9 +220,9 @@ def test_create_unable_to_generate_token(self) -> None: {"length": 1}, access_token=self.admin_user_tok, ) - self.assertEqual(500, channel.code, msg=channel.json_body) + self.assertEqual(500, int(channel.result["code"]), msg=channel.result["body"]) - def test_create_uses_allowed(self) -> None: + def test_create_uses_allowed(self): """Check you can only create a token with good values for uses_allowed.""" # Should work with 0 (token is invalid from the start) channel = self.make_request( @@ -251,7 +231,7 @@ def test_create_uses_allowed(self) -> None: {"uses_allowed": 0}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["uses_allowed"], 0) # Should fail with negative integer @@ -261,11 +241,7 @@ def test_create_uses_allowed(self) -> None: {"uses_allowed": -5}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with float @@ -275,14 +251,10 @@ def test_create_uses_allowed(self) -> None: {"uses_allowed": 1.5}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) - def test_create_expiry_time(self) -> None: + def test_create_expiry_time(self): """Check you can't create a token with an invalid expiry_time.""" # Should fail with a time in the past channel = self.make_request( @@ -291,11 +263,7 @@ def test_create_expiry_time(self) -> None: {"expiry_time": self.clock.time_msec() - 10000}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with float @@ -305,14 +273,10 @@ def test_create_expiry_time(self) -> None: {"expiry_time": self.clock.time_msec() + 1000000.5}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) - def test_create_length(self) -> None: + def test_create_length(self): """Check you can only generate a token with a valid length.""" # Should work with 64 channel = self.make_request( @@ -321,7 +285,7 @@ def test_create_length(self) -> None: {"length": 64}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, 
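`uses_allowed` must be a non-negative integer (zero is accepted and simply exhausts the token immediately) and `expiry_time` a millisecond timestamp that is not in the past; floats fail both. A sketch condensing the rejection cases below:

    for bad_body in (
        {"uses_allowed": -5},                                 # negative
        {"uses_allowed": 1.5},                                # not an integer
        {"expiry_time": self.clock.time_msec() - 10000},      # in the past
        {"expiry_time": self.clock.time_msec() + 1000000.5},  # not an integer
    ):
        channel = self.make_request(
            "POST", self.url + "/new", bad_body, access_token=self.admin_user_tok
        )
        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)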
channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(len(channel.json_body["token"]), 64) # Should fail with 0 @@ -331,11 +295,7 @@ def test_create_length(self) -> None: {"length": 0}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with a negative integer @@ -345,11 +305,7 @@ def test_create_length(self) -> None: {"length": -5}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with a float @@ -359,11 +315,7 @@ def test_create_length(self) -> None: {"length": 8.5}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with 65 @@ -373,30 +325,22 @@ def test_create_length(self) -> None: {"length": 65}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # UPDATING - def test_update_no_auth(self) -> None: + def test_update_no_auth(self): """Try to update a token without authentication.""" channel = self.make_request( "PUT", self.url + "/1234", # Token doesn't exist but that doesn't matter {}, ) - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_update_requester_not_admin(self) -> None: + def test_update_requester_not_admin(self): """Try to update a token while not an admin.""" channel = self.make_request( "PUT", @@ -404,14 +348,10 @@ def test_update_requester_not_admin(self) -> None: {}, access_token=self.other_user_tok, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_update_non_existent(self) -> None: + def test_update_non_existent(self): """Try to update a token that doesn't exist.""" channel = self.make_request( "PUT", @@ -420,14 +360,10 @@ def test_update_non_existent(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.NOT_FOUND, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_update_uses_allowed(self) -> None: + def test_update_uses_allowed(self): """Test updating just uses_allowed.""" # Create new token using default values token = self._new_token() @@ -439,7 +375,7 @@ def test_update_uses_allowed(self) -> None: {"uses_allowed": 1}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, 
channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["uses_allowed"], 1) self.assertIsNone(channel.json_body["expiry_time"]) @@ -450,7 +386,7 @@ def test_update_uses_allowed(self) -> None: {"uses_allowed": 0}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["uses_allowed"], 0) self.assertIsNone(channel.json_body["expiry_time"]) @@ -461,7 +397,7 @@ def test_update_uses_allowed(self) -> None: {"uses_allowed": None}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertIsNone(channel.json_body["uses_allowed"]) self.assertIsNone(channel.json_body["expiry_time"]) @@ -472,11 +408,7 @@ def test_update_uses_allowed(self) -> None: {"uses_allowed": 1.5}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with a negative integer @@ -486,14 +418,10 @@ def test_update_uses_allowed(self) -> None: {"uses_allowed": -5}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) - def test_update_expiry_time(self) -> None: + def test_update_expiry_time(self): """Test updating just expiry_time.""" # Create new token using default values token = self._new_token() @@ -506,7 +434,7 @@ def test_update_expiry_time(self) -> None: {"expiry_time": new_expiry_time}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["expiry_time"], new_expiry_time) self.assertIsNone(channel.json_body["uses_allowed"]) @@ -517,7 +445,7 @@ def test_update_expiry_time(self) -> None: {"expiry_time": None}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertIsNone(channel.json_body["expiry_time"]) self.assertIsNone(channel.json_body["uses_allowed"]) @@ -529,11 +457,7 @@ def test_update_expiry_time(self) -> None: {"expiry_time": past_time}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail a float @@ -543,14 +467,10 @@ def test_update_expiry_time(self) -> None: {"expiry_time": new_expiry_time + 0.5}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) - def test_update_both(self) -> None: + def 
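Updates use PUT on the token's own URL and accept the same fields; a JSON `null` clears a limit rather than leaving it untouched. A sketch:

    channel = self.make_request(
        "PUT",
        self.url + "/" + token,
        {"uses_allowed": None},  # null lifts the use limit entirely
        access_token=self.admin_user_tok,
    )
    self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
    self.assertIsNone(channel.json_body["uses_allowed"])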
test_update_both(self): """Test updating both uses_allowed and expiry_time.""" # Create new token using default values token = self._new_token() @@ -568,11 +488,11 @@ def test_update_both(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["uses_allowed"], 1) self.assertEqual(channel.json_body["expiry_time"], new_expiry_time) - def test_update_invalid_type(self) -> None: + def test_update_invalid_type(self): """Test using invalid types doesn't work.""" # Create new token using default values token = self._new_token() @@ -589,30 +509,22 @@ def test_update_invalid_type(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # DELETING - def test_delete_no_auth(self) -> None: + def test_delete_no_auth(self): """Try to delete a token without authentication.""" channel = self.make_request( "DELETE", self.url + "/1234", # Token doesn't exist but that doesn't matter {}, ) - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_delete_requester_not_admin(self) -> None: + def test_delete_requester_not_admin(self): """Try to delete a token while not an admin.""" channel = self.make_request( "DELETE", @@ -620,14 +532,10 @@ def test_delete_requester_not_admin(self) -> None: {}, access_token=self.other_user_tok, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_delete_non_existent(self) -> None: + def test_delete_non_existent(self): """Try to delete a token that doesn't exist.""" channel = self.make_request( "DELETE", @@ -636,14 +544,10 @@ def test_delete_non_existent(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.NOT_FOUND, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_delete(self) -> None: + def test_delete(self): """Test deleting a token.""" # Create new token using default values token = self._new_token() @@ -655,25 +559,21 @@ def test_delete(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # GETTING ONE - def test_get_no_auth(self) -> None: + def test_get_no_auth(self): """Try to get a token without authentication.""" channel = self.make_request( "GET", self.url + "/1234", # Token doesn't exist but that doesn't matter {}, ) - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_get_requester_not_admin(self) -> None: + def test_get_requester_not_admin(self): """Try to 
get a token while not an admin.""" channel = self.make_request( "GET", @@ -681,14 +581,10 @@ def test_get_requester_not_admin(self) -> None: {}, access_token=self.other_user_tok, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_get_non_existent(self) -> None: + def test_get_non_existent(self): """Try to get a token that doesn't exist.""" channel = self.make_request( "GET", @@ -697,14 +593,10 @@ def test_get_non_existent(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.NOT_FOUND, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_get(self) -> None: + def test_get(self): """Test getting a token.""" # Create new token using default values token = self._new_token() @@ -716,7 +608,7 @@ def test_get(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["token"], token) self.assertIsNone(channel.json_body["uses_allowed"]) self.assertIsNone(channel.json_body["expiry_time"]) @@ -725,17 +617,13 @@ def test_get(self) -> None: # LISTING - def test_list_no_auth(self) -> None: + def test_list_no_auth(self): """Try to list tokens without authentication.""" channel = self.make_request("GET", self.url, {}) - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_list_requester_not_admin(self) -> None: + def test_list_requester_not_admin(self): """Try to list tokens while not an admin.""" channel = self.make_request( "GET", @@ -743,14 +631,10 @@ def test_list_requester_not_admin(self) -> None: {}, access_token=self.other_user_tok, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_list_all(self) -> None: + def test_list_all(self): """Test listing all tokens.""" # Create new token using default values token = self._new_token() @@ -762,7 +646,7 @@ def test_list_all(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(len(channel.json_body["registration_tokens"]), 1) token_info = channel.json_body["registration_tokens"][0] self.assertEqual(token_info["token"], token) @@ -771,7 +655,7 @@ def test_list_all(self) -> None: self.assertEqual(token_info["pending"], 0) self.assertEqual(token_info["completed"], 0) - def test_list_invalid_query_parameter(self) -> None: + def test_list_invalid_query_parameter(self): """Test with `valid` query parameter not `true` or `false`.""" channel = self.make_request( "GET", @@ -780,13 +664,9 @@ def test_list_invalid_query_parameter(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + 
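Listing accepts an optional `valid` query parameter that must be the literal string `true` or `false`; anything else draws a 400, as the next test shows. A sketch of a filtered listing, assuming a `tokens` list of expected values as in the helper below:

    channel = self.make_request(
        "GET",
        self.url + "?valid=true",
        access_token=self.admin_user_tok,
    )
    self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
    for token_info in channel.json_body["registration_tokens"]:
        # Each entry carries token, uses_allowed, expiry_time, pending, completed.
        self.assertIn(token_info["token"], tokens)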
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - def _test_list_query_parameter(self, valid: str) -> None: + def _test_list_query_parameter(self, valid: str): """Helper used to test both valid=true and valid=false.""" # Create 2 valid and 2 invalid tokens. now = self.hs.get_clock().time_msec() @@ -816,17 +696,17 @@ def _test_list_query_parameter(self, valid: str) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(len(channel.json_body["registration_tokens"]), 2) token_info_1 = channel.json_body["registration_tokens"][0] token_info_2 = channel.json_body["registration_tokens"][1] self.assertIn(token_info_1["token"], tokens) self.assertIn(token_info_2["token"], tokens) - def test_list_valid(self) -> None: + def test_list_valid(self): """Test listing just valid tokens.""" self._test_list_query_parameter(valid="true") - def test_list_invalid(self) -> None: + def test_list_invalid(self): """Test listing just invalid tokens.""" self._test_list_query_parameter(valid="false") diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 22f9aa62346a..07077aff78d6 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +import json import urllib.parse from http import HTTPStatus from typing import List, Optional @@ -18,15 +20,11 @@ from parameterized import parameterized -from twisted.test.proto_helpers import MemoryReactor - import synapse.rest.admin from synapse.api.constants import EventTypes, Membership from synapse.api.errors import Codes from synapse.handlers.pagination import PaginationHandler from synapse.rest.client import directory, events, login, room -from synapse.server import HomeServer -from synapse.util import Clock from tests import unittest @@ -42,7 +40,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): room.register_deprecated_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + def prepare(self, reactor, clock, hs): self.event_creation_handler = hs.get_event_creation_handler() hs.config.consent.user_consent_version = "1" @@ -68,7 +66,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def test_requester_is_no_admin(self): """ - If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. + If the user is not a server admin, an error 403 is returned. """ channel = self.make_request( @@ -78,12 +76,12 @@ def test_requester_is_no_admin(self): access_token=self.other_user_tok, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_room_does_not_exist(self): """ - Check that unknown rooms/server return 200 + Check that unknown rooms/server return error 404. 
""" url = "/_synapse/admin/v1/rooms/%s" % "!unknown:test" @@ -94,11 +92,12 @@ def test_room_does_not_exist(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_room_is_not_valid(self): """ - Check that invalid room names, return an error HTTPStatus.BAD_REQUEST. + Check that invalid room names, return an error 400. """ url = "/_synapse/admin/v1/rooms/%s" % "invalidroom" @@ -109,7 +108,7 @@ def test_room_is_not_valid(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "invalidroom is not a legal room ID", channel.json_body["error"], @@ -119,15 +118,16 @@ def test_new_room_user_does_not_exist(self): """ Tests that the user ID must be from local server but it does not have to exist. """ + body = json.dumps({"new_room_user_id": "@unknown:test"}) channel = self.make_request( "DELETE", self.url, - content={"new_room_user_id": "@unknown:test"}, + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("new_room_id", channel.json_body) self.assertIn("kicked_users", channel.json_body) self.assertIn("failed_to_kick_users", channel.json_body) @@ -137,15 +137,16 @@ def test_new_room_user_is_not_local(self): """ Check that only local users can create new room to move members. """ + body = json.dumps({"new_room_user_id": "@not:exist.bla"}) channel = self.make_request( "DELETE", self.url, - content={"new_room_user_id": "@not:exist.bla"}, + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "User must be our own: @not:exist.bla", channel.json_body["error"], @@ -155,30 +156,32 @@ def test_block_is_not_bool(self): """ If parameter `block` is not boolean, return an error """ + body = json.dumps({"block": "NotBool"}) channel = self.make_request( "DELETE", self.url, - content={"block": "NotBool"}, + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) def test_purge_is_not_bool(self): """ If parameter `purge` is not boolean, return an error """ + body = json.dumps({"purge": "NotBool"}) channel = self.make_request( "DELETE", self.url, - content={"purge": "NotBool"}, + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) def test_purge_room_and_block(self): @@ -195,14 +198,16 @@ def test_purge_room_and_block(self): # Assert one user in room self._is_member(room_id=self.room_id, user_id=self.other_user) + body = json.dumps({"block": True, "purge": True}) + channel = self.make_request( "DELETE", self.url.encode("ascii"), - content={"block": True, "purge": True}, + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + 
self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(None, channel.json_body["new_room_id"]) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("failed_to_kick_users", channel.json_body) @@ -226,14 +231,16 @@ def test_purge_room_and_not_block(self): # Assert one user in room self._is_member(room_id=self.room_id, user_id=self.other_user) + body = json.dumps({"block": False, "purge": True}) + channel = self.make_request( "DELETE", self.url.encode("ascii"), - content={"block": False, "purge": True}, + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(None, channel.json_body["new_room_id"]) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("failed_to_kick_users", channel.json_body) @@ -258,14 +265,16 @@ def test_block_room_and_not_purge(self): # Assert one user in room self._is_member(room_id=self.room_id, user_id=self.other_user) + body = json.dumps({"block": True, "purge": False}) + channel = self.make_request( "DELETE", self.url.encode("ascii"), - content={"block": True, "purge": False}, + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(None, channel.json_body["new_room_id"]) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("failed_to_kick_users", channel.json_body) @@ -296,7 +305,9 @@ def test_block_unknown_room(self, purge: bool) -> None: ) # The room is now blocked. - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual( + HTTPStatus.OK, int(channel.result["code"]), msg=channel.result["body"] + ) self._is_blocked(room_id) def test_shutdown_room_consent(self): @@ -316,10 +327,7 @@ def test_shutdown_room_consent(self): # Assert that the user is getting consent error self.helper.send( - self.room_id, - body="foo", - tok=self.other_user_tok, - expect_code=HTTPStatus.FORBIDDEN, + self.room_id, body="foo", tok=self.other_user_tok, expect_code=403 ) # Test that room is not purged @@ -333,11 +341,11 @@ def test_shutdown_room_consent(self): channel = self.make_request( "DELETE", self.url, - {"new_room_user_id": self.admin_user}, + json.dumps({"new_room_user_id": self.admin_user}), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("new_room_id", channel.json_body) self.assertIn("failed_to_kick_users", channel.json_body) @@ -363,10 +371,10 @@ def test_shutdown_room_block_peek(self): channel = self.make_request( "PUT", url.encode("ascii"), - {"history_visibility": "world_readable"}, + json.dumps({"history_visibility": "world_readable"}), access_token=self.other_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Test that room is not purged with self.assertRaises(AssertionError): @@ -379,11 +387,11 @@ def test_shutdown_room_block_peek(self): channel = self.make_request( "DELETE", self.url, - {"new_room_user_id": self.admin_user}, + json.dumps({"new_room_user_id": self.admin_user}), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, 
channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("new_room_id", channel.json_body) self.assertIn("failed_to_kick_users", channel.json_body) @@ -398,7 +406,7 @@ def test_shutdown_room_block_peek(self): self._has_no_members(self.room_id) # Assert we can no longer peek into the room - self._assert_peek(self.room_id, expect_code=HTTPStatus.FORBIDDEN) + self._assert_peek(self.room_id, expect_code=403) def _is_blocked(self, room_id, expect=True): """Assert that the room is blocked or not""" @@ -457,7 +465,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): room.register_deprecated_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + def prepare(self, reactor, clock, hs): self.event_creation_handler = hs.get_event_creation_handler() hs.config.consent.user_consent_version = "1" @@ -494,7 +502,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: ) def test_requester_is_no_admin(self, method: str, url: str): """ - If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. + If the user is not a server admin, an error 403 is returned. """ channel = self.make_request( @@ -507,36 +515,27 @@ def test_requester_is_no_admin(self, method: str, url: str): self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_room_does_not_exist(self): + @parameterized.expand( + [ + ("DELETE", "/_synapse/admin/v2/rooms/%s"), + ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"), + ("GET", "/_synapse/admin/v2/rooms/delete_status/%s"), + ] + ) + def test_room_does_not_exist(self, method: str, url: str): """ - Check that unknown rooms/server return 200 - - This is important, as it allows incomplete vestiges of rooms to be cleared up - even if the create event/etc is missing. + Check that unknown rooms/server return error 404. """ - room_id = "!unknown:test" - channel = self.make_request( - "DELETE", - f"/_synapse/admin/v2/rooms/{room_id}", - content={}, - access_token=self.admin_user_tok, - ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - self.assertIn("delete_id", channel.json_body) - delete_id = channel.json_body["delete_id"] - - # get status channel = self.make_request( - "GET", - f"/_synapse/admin/v2/rooms/{room_id}/delete_status", + method, + url % "!unknown:test", + content={}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - self.assertEqual(1, len(channel.json_body["results"])) - self.assertEqual("complete", channel.json_body["results"][0]["status"]) - self.assertEqual(delete_id, channel.json_body["results"][0]["delete_id"]) + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @parameterized.expand( [ @@ -546,7 +545,7 @@ def test_room_does_not_exist(self): ) def test_room_is_not_valid(self, method: str, url: str): """ - Check that invalid room names, return an error HTTPStatus.BAD_REQUEST. + Check that invalid room names, return an error 400. 
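The v2 endpoint above is asynchronous: the DELETE returns a `delete_id` straight away, and completion is observed through the status endpoints. A sketch of the round trip that the replaced `test_room_does_not_exist` performed:

    channel = self.make_request(
        "DELETE",
        "/_synapse/admin/v2/rooms/%s" % room_id,
        content={},
        access_token=self.admin_user_tok,
    )
    self.assertEqual(200, channel.code, msg=channel.json_body)
    delete_id = channel.json_body["delete_id"]

    # Check the background deletion's progress for this room.
    channel = self.make_request(
        "GET",
        "/_synapse/admin/v2/rooms/%s/delete_status" % room_id,
        access_token=self.admin_user_tok,
    )
    self.assertEqual(200, channel.code, msg=channel.json_body)
    self.assertEqual("complete", channel.json_body["results"][0]["status"])
    self.assertEqual(delete_id, channel.json_body["results"][0]["delete_id"])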
""" channel = self.make_request( @@ -855,10 +854,7 @@ def test_shutdown_room_consent(self): # Assert that the user is getting consent error self.helper.send( - self.room_id, - body="foo", - tok=self.other_user_tok, - expect_code=HTTPStatus.FORBIDDEN, + self.room_id, body="foo", tok=self.other_user_tok, expect_code=403 ) # Test that room is not purged @@ -955,7 +951,7 @@ def test_shutdown_room_block_peek(self): self._has_no_members(self.room_id) # Assert we can no longer peek into the room - self._assert_peek(self.room_id, expect_code=HTTPStatus.FORBIDDEN) + self._assert_peek(self.room_id, expect_code=403) def _is_blocked(self, room_id: str, expect: bool = True) -> None: """Assert that the room is blocked or not""" @@ -1073,12 +1069,12 @@ class RoomTestCase(unittest.HomeserverTestCase): directory.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + def prepare(self, reactor, clock, hs): # Create user self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") - def test_list_rooms(self) -> None: + def test_list_rooms(self): """Test that we can list rooms""" # Create 3 test rooms total_rooms = 3 @@ -1098,7 +1094,7 @@ def test_list_rooms(self) -> None: ) # Check request completed successfully - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Check that response json body contains a "rooms" key self.assertTrue( @@ -1142,7 +1138,7 @@ def test_list_rooms(self) -> None: # We shouldn't receive a next token here as there's no further rooms to show self.assertNotIn("next_batch", channel.json_body) - def test_list_rooms_pagination(self) -> None: + def test_list_rooms_pagination(self): """Test that we can get a full list of rooms through pagination""" # Create 5 test rooms total_rooms = 5 @@ -1182,7 +1178,7 @@ def test_list_rooms_pagination(self) -> None: url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertTrue("rooms" in channel.json_body) for r in channel.json_body["rooms"]: @@ -1222,9 +1218,9 @@ def test_list_rooms_pagination(self) -> None: url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) - def test_correct_room_attributes(self) -> None: + def test_correct_room_attributes(self): """Test the correct attributes for a room are returned""" # Create a test room room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) @@ -1245,7 +1241,7 @@ def test_correct_room_attributes(self) -> None: {"room_id": room_id}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Set this new alias as the canonical alias for this room self.helper.send_state( @@ -1277,7 +1273,7 @@ def test_correct_room_attributes(self) -> None: url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Check that rooms were returned self.assertTrue("rooms" in channel.json_body) @@ -1305,7 +1301,7 @@ def test_correct_room_attributes(self) -> None: self.assertEqual(test_room_name, r["name"]) 
self.assertEqual(test_alias, r["canonical_alias"]) - def test_room_list_sort_order(self) -> None: + def test_room_list_sort_order(self): """Test room list sort ordering. alphabetical name versus number of members, reversing the order, etc. """ @@ -1314,7 +1310,7 @@ def _order_test( order_type: str, expected_room_list: List[str], reverse: bool = False, - ) -> None: + ): """Request the list of rooms in a certain order. Assert that order is what we expect @@ -1332,7 +1328,7 @@ def _order_test( url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Check that rooms were returned self.assertTrue("rooms" in channel.json_body) @@ -1443,7 +1439,7 @@ def _order_test( _order_test("state_events", [room_id_3, room_id_2, room_id_1]) _order_test("state_events", [room_id_1, room_id_2, room_id_3], reverse=True) - def test_search_term(self) -> None: + def test_search_term(self): """Test that searching for a room works correctly""" # Create two test rooms room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) @@ -1471,8 +1467,8 @@ def test_search_term(self) -> None: def _search_test( expected_room_id: Optional[str], search_term: str, - expected_http_code: int = HTTPStatus.OK, - ) -> None: + expected_http_code: int = 200, + ): """Search for a room and check that the returned room's id is a match Args: @@ -1489,7 +1485,7 @@ def _search_test( ) self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) - if expected_http_code != HTTPStatus.OK: + if expected_http_code != 200: return # Check that rooms were returned @@ -1532,7 +1528,7 @@ def _search_test( _search_test(None, "foo") _search_test(None, "bar") - _search_test(None, "", expected_http_code=HTTPStatus.BAD_REQUEST) + _search_test(None, "", expected_http_code=400) # Test that the whole room id returns the room _search_test(room_id_1, room_id_1) @@ -1546,7 +1542,7 @@ def _search_test( # Test search local part of alias _search_test(room_id_1, "alias1") - def test_search_term_non_ascii(self) -> None: + def test_search_term_non_ascii(self): """Test that searching for a room with non-ASCII characters works correctly""" # Create test room @@ -1569,11 +1565,11 @@ def test_search_term_non_ascii(self) -> None: url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(room_id, channel.json_body.get("rooms")[0].get("room_id")) self.assertEqual("ж", channel.json_body.get("rooms")[0].get("name")) - def test_single_room(self) -> None: + def test_single_room(self): """Test that a single room can be requested correctly""" # Create two test rooms room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) @@ -1602,7 +1598,7 @@ def test_single_room(self) -> None: url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("room_id", channel.json_body) self.assertIn("name", channel.json_body) @@ -1624,7 +1620,7 @@ def test_single_room(self) -> None: self.assertEqual(room_id_1, channel.json_body["room_id"]) - def test_single_room_devices(self) -> None: + def test_single_room_devices(self): """Test that `joined_local_devices` can be requested correctly""" room_id_1 = 
self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) @@ -1634,7 +1630,7 @@ def test_single_room_devices(self) -> None: url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["joined_local_devices"]) # Have another user join the room @@ -1648,7 +1644,7 @@ def test_single_room_devices(self) -> None: url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(2, channel.json_body["joined_local_devices"]) # leave room @@ -1660,10 +1656,10 @@ def test_single_room_devices(self) -> None: url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["joined_local_devices"]) - def test_room_members(self) -> None: + def test_room_members(self): """Test that room members can be requested correctly""" # Create two test rooms room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) @@ -1691,7 +1687,7 @@ def test_room_members(self) -> None: url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertCountEqual( ["@admin:test", "@foo:test", "@bar:test"], channel.json_body["members"] @@ -1704,14 +1700,14 @@ def test_room_members(self) -> None: url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertCountEqual( ["@admin:test", "@bar:test", "@foobar:test"], channel.json_body["members"] ) self.assertEqual(channel.json_body["total"], 3) - def test_room_state(self) -> None: + def test_room_state(self): """Test that room state can be requested correctly""" # Create two test rooms room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) @@ -1722,15 +1718,13 @@ def test_room_state(self) -> None: url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("state", channel.json_body) # testing that the state events match is painful and not done here. We assume that # the create_room already does the right thing, so no need to verify that we got # the state events it created. 
- def _set_canonical_alias( - self, room_id: str, test_alias: str, admin_user_tok: str - ) -> None: + def _set_canonical_alias(self, room_id: str, test_alias: str, admin_user_tok: str): # Create a new alias to this room url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),) channel = self.make_request( @@ -1739,7 +1733,7 @@ def _set_canonical_alias( {"room_id": room_id}, access_token=admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Set this new alias as the canonical alias for this room self.helper.send_state( @@ -1765,7 +1759,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + def prepare(self, reactor, clock, homeserver): self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -1780,117 +1774,124 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: ) self.url = f"/_synapse/admin/v1/join/{self.public_room_id}" - def test_requester_is_no_admin(self) -> None: + def test_requester_is_no_admin(self): """ - If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. + If the user is not a server admin, an error 403 is returned. """ + body = json.dumps({"user_id": self.second_user_id}) channel = self.make_request( "POST", self.url, - content={"user_id": self.second_user_id}, + content=body, access_token=self.second_tok, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_invalid_parameter(self) -> None: + def test_invalid_parameter(self): """ If a parameter is missing, return an error """ + body = json.dumps({"unknown_parameter": "@unknown:test"}) channel = self.make_request( "POST", self.url, - content={"unknown_parameter": "@unknown:test"}, + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) - def test_local_user_does_not_exist(self) -> None: + def test_local_user_does_not_exist(self): """ - Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND + Tests that a lookup for a user that does not exist returns a 404 """ + body = json.dumps({"user_id": "@unknown:test"}) channel = self.make_request( "POST", self.url, - content={"user_id": "@unknown:test"}, + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - def test_remote_user(self) -> None: + def test_remote_user(self): """ Check that only local user can join rooms. 
""" + body = json.dumps({"user_id": "@not:exist.bla"}) channel = self.make_request( "POST", self.url, - content={"user_id": "@not:exist.bla"}, + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "This endpoint can only be used with local users", channel.json_body["error"], ) - def test_room_does_not_exist(self) -> None: + def test_room_does_not_exist(self): """ - Check that unknown rooms/server return error HTTPStatus.NOT_FOUND. + Check that unknown rooms/server return error 404. """ + body = json.dumps({"user_id": self.second_user_id}) url = "/_synapse/admin/v1/join/!unknown:test" channel = self.make_request( "POST", url, - content={"user_id": self.second_user_id}, + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual("No known servers", channel.json_body["error"]) - def test_room_is_not_valid(self) -> None: + def test_room_is_not_valid(self): """ - Check that invalid room names, return an error HTTPStatus.BAD_REQUEST. + Check that invalid room names, return an error 400. """ + body = json.dumps({"user_id": self.second_user_id}) url = "/_synapse/admin/v1/join/invalidroom" channel = self.make_request( "POST", url, - content={"user_id": self.second_user_id}, + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "invalidroom was not legal room ID or room alias", channel.json_body["error"], ) - def test_join_public_room(self) -> None: + def test_join_public_room(self): """ Test joining a local user to a public room with "JoinRules.PUBLIC" """ + body = json.dumps({"user_id": self.second_user_id}) channel = self.make_request( "POST", self.url, - content={"user_id": self.second_user_id}, + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.public_room_id, channel.json_body["room_id"]) # Validate if user is a member of the room @@ -1900,10 +1901,10 @@ def test_join_public_room(self) -> None: "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEquals(200, channel.code, msg=channel.json_body) self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0]) - def test_join_private_room_if_not_member(self) -> None: + def test_join_private_room_if_not_member(self): """ Test joining a local user to a private room with "JoinRules.INVITE" when server admin is not member of this room. 
@@ -1912,18 +1913,19 @@ def test_join_private_room_if_not_member(self) -> None:
             self.creator, tok=self.creator_tok, is_public=False
         )
         url = f"/_synapse/admin/v1/join/{private_room_id}"
+        body = json.dumps({"user_id": self.second_user_id})
 
         channel = self.make_request(
             "POST",
             url,
-            content={"user_id": self.second_user_id},
+            content=body,
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+        self.assertEqual(403, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
-    def test_join_private_room_if_member(self) -> None:
+    def test_join_private_room_if_member(self):
         """
         Test joining a local user to a private room with "JoinRules.INVITE",
         when server admin is member of this room.
@@ -1948,20 +1950,21 @@ def test_join_private_room_if_member(self) -> None:
             "/_matrix/client/r0/joined_rooms",
             access_token=self.admin_user_tok,
         )
-        self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEquals(200, channel.code, msg=channel.json_body)
         self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
 
         # Join user to room.
 
         url = f"/_synapse/admin/v1/join/{private_room_id}"
+        body = json.dumps({"user_id": self.second_user_id})
 
         channel = self.make_request(
             "POST",
             url,
-            content={"user_id": self.second_user_id},
+            content=body,
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(private_room_id, channel.json_body["room_id"])
 
         # Validate if user is a member of the room
@@ -1971,10 +1974,10 @@ def test_join_private_room_if_member(self) -> None:
             "/_matrix/client/r0/joined_rooms",
             access_token=self.second_tok,
         )
-        self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEquals(200, channel.code, msg=channel.json_body)
         self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
 
-    def test_join_private_room_if_owner(self) -> None:
+    def test_join_private_room_if_owner(self):
         """
         Test joining a local user to a private room with "JoinRules.INVITE",
         when server admin is owner of this room.
@@ -1983,15 +1986,16 @@ def test_join_private_room_if_owner(self) -> None:
             self.admin_user, tok=self.admin_user_tok, is_public=False
         )
         url = f"/_synapse/admin/v1/join/{private_room_id}"
+        body = json.dumps({"user_id": self.second_user_id})
 
         channel = self.make_request(
             "POST",
             url,
-            content={"user_id": self.second_user_id},
+            content=body,
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(private_room_id, channel.json_body["room_id"])
 
         # Validate if user is a member of the room
@@ -2001,10 +2005,10 @@ def test_join_private_room_if_owner(self) -> None:
             "/_matrix/client/r0/joined_rooms",
             access_token=self.second_tok,
         )
-        self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEquals(200, channel.code, msg=channel.json_body)
         self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
 
-    def test_context_as_non_admin(self) -> None:
+    def test_context_as_non_admin(self):
         """
         Test that, without being admin, one cannot use the context admin API
         """
@@ -2035,10 +2039,10 @@ def test_context_as_non_admin(self) -> None:
             % (room_id, events[midway]["event_id"]),
             access_token=tok,
         )
-        self.assertEquals(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+        self.assertEquals(403, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
-    def test_context_as_admin(self) -> None:
+    def test_context_as_admin(self):
         """
         Test that, as admin, we can find the context of an event without having joined the room.
         """
@@ -2065,7 +2069,7 @@ def test_context_as_admin(self) -> None:
             % (room_id, events[midway]["event_id"]),
             access_token=self.admin_user_tok,
         )
-        self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEquals(200, channel.code, msg=channel.json_body)
         self.assertEquals(
             channel.json_body["event"]["event_id"], events[midway]["event_id"]
         )
@@ -2094,7 +2098,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
         login.register_servlets,
     ]
 
-    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+    def prepare(self, reactor, clock, homeserver):
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
@@ -2111,7 +2115,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
             self.public_room_id
         )
 
-    def test_public_room(self) -> None:
+    def test_public_room(self):
         """Test that getting admin in a public room works."""
         room_id = self.helper.create_room_as(
             self.creator, tok=self.creator_tok, is_public=True
@@ -2124,7 +2128,7 @@ def test_public_room(self) -> None:
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(200, channel.code, msg=channel.json_body)
 
         # Now we test that we can join the room and ban a user.
         self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok)
@@ -2136,7 +2140,7 @@ def test_public_room(self) -> None:
             tok=self.admin_user_tok,
         )
 
-    def test_private_room(self) -> None:
+    def test_private_room(self):
         """Test that getting admin in a private room works and we get invited."""
         room_id = self.helper.create_room_as(
             self.creator,
@@ -2151,7 +2155,7 @@ def test_private_room(self) -> None:
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(200, channel.code, msg=channel.json_body)
 
         # Now we test that we can join the room (we should have received an
         # invite) and can ban a user.
@@ -2164,7 +2168,7 @@ def test_private_room(self) -> None:
             tok=self.admin_user_tok,
        )
 
-    def test_other_user(self) -> None:
+    def test_other_user(self):
         """Test that giving admin in a public room to a non-admin user works."""
         room_id = self.helper.create_room_as(
             self.creator, tok=self.creator_tok, is_public=True
@@ -2177,7 +2181,7 @@ def test_other_user(self) -> None:
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(200, channel.code, msg=channel.json_body)
 
         # Now we test that we can join the room and ban a user.
         self.helper.join(room_id, self.second_user_id, tok=self.second_tok)
@@ -2189,7 +2193,7 @@ def test_other_user(self) -> None:
             tok=self.second_tok,
         )
 
-    def test_not_enough_power(self) -> None:
+    def test_not_enough_power(self):
         """Test that we get a sensible error if there are no local room admins."""
         room_id = self.helper.create_room_as(
             self.creator, tok=self.creator_tok, is_public=True
@@ -2211,11 +2215,11 @@ def test_not_enough_power(self) -> None:
             access_token=self.admin_user_tok,
         )
 
-        # We expect this to fail with a HTTPStatus.BAD_REQUEST as there are no room admins.
+        # We expect this to fail with a 400 as there are no room admins.
         #
         # (Note we assert the error message to ensure that it's not denied for
         # some other reason)
-        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual(
             channel.json_body["error"],
             "No local admin user in room with power to update power levels.",
         )
@@ -2229,7 +2233,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
         login.register_servlets,
     ]
 
-    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+    def prepare(self, reactor, clock, hs):
         self._store = hs.get_datastore()
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
@@ -2244,8 +2248,8 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.url = "/_synapse/admin/v1/rooms/%s/block"
 
     @parameterized.expand([("PUT",), ("GET",)])
-    def test_requester_is_no_admin(self, method: str) -> None:
-        """If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned."""
+    def test_requester_is_no_admin(self, method: str):
+        """If the user is not a server admin, an error 403 is returned."""
 
         channel = self.make_request(
             method,
@@ -2258,8 +2262,8 @@ def test_requester_is_no_admin(self, method: str) -> None:
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
     @parameterized.expand([("PUT",), ("GET",)])
-    def test_room_is_not_valid(self, method: str) -> None:
-        """Check that invalid room names, return an error HTTPStatus.BAD_REQUEST."""
+    def test_room_is_not_valid(self, method: str):
+        """Check that invalid room names return an error 400."""
 
         channel = self.make_request(
             method,
@@ -2274,7 +2278,7 @@ def test_room_is_not_valid(self, method: str) -> None:
             channel.json_body["error"],
         )
 
-    def test_block_is_not_valid(self) -> None:
+    def test_block_is_not_valid(self):
         """If parameter `block` is not valid, return an error."""
 
         # `block` is not valid
@@ -2309,7 +2313,7 @@ def test_block_is_not_valid(self) -> None:
         self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"])
 
-    def test_block_room(self) -> None:
+    def test_block_room(self):
         """Test that blocking a room is successful."""
 
         def _request_and_test_block_room(room_id: str) -> None:
@@ -2333,7 +2337,7 @@ def _request_and_test_block_room(room_id: str) -> None:
         # unknown remote room
         _request_and_test_block_room("!unknown:remote")
 
-    def test_block_room_twice(self) -> None:
+    def test_block_room_twice(self):
         """Test that blocking a room that is already blocked is successful."""
 
         self._is_blocked(self.room_id, expect=False)
@@ -2348,7 +2352,7 @@ def test_block_room_twice(self) -> None:
         self.assertTrue(channel.json_body["block"])
         self._is_blocked(self.room_id, expect=True)
 
-    def test_unblock_room(self) -> None:
+    def test_unblock_room(self):
         """Test that unblocking a room is successful."""
 
         def _request_and_test_unblock_room(room_id: str) -> None:
@@ -2373,7 +2377,7 @@ def _request_and_test_unblock_room(room_id: str) -> None:
         # unknown remote room
         _request_and_test_unblock_room("!unknown:remote")
 
-    def test_unblock_room_twice(self) -> None:
+    def test_unblock_room_twice(self):
         """Test that unblocking a room that is not blocked is successful."""
 
         self._block_room(self.room_id)
@@ -2388,7 +2392,7 @@ def test_unblock_room_twice(self) -> None:
         self.assertFalse(channel.json_body["block"])
         self._is_blocked(self.room_id, expect=False)
 
-    def test_get_blocked_room(self) -> None:
+    def test_get_blocked_room(self):
         """Test getting the status of a blocked room"""
def _request_blocked_room(room_id: str) -> None: @@ -2412,7 +2416,7 @@ def _request_blocked_room(room_id: str) -> None: # unknown remote room _request_blocked_room("!unknown:remote") - def test_get_unblocked_room(self) -> None: + def test_get_unblocked_room(self): """Test get status of a unblocked room""" def _request_unblocked_room(room_id: str) -> None: diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py index 3c59f5f766bc..fbceba325494 100644 --- a/tests/rest/admin/test_server_notice.py +++ b/tests/rest/admin/test_server_notice.py @@ -11,18 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from http import HTTPStatus -from typing import List -from twisted.test.proto_helpers import MemoryReactor +from typing import List import synapse.rest.admin from synapse.api.errors import Codes from synapse.rest.client import login, room, sync -from synapse.server import HomeServer from synapse.storage.roommember import RoomsForUser from synapse.types import JsonDict -from synapse.util import Clock from tests import unittest from tests.unittest import override_config @@ -37,7 +33,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): sync.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() self.room_shutdown_handler = hs.get_room_shutdown_handler() self.pagination_handler = hs.get_pagination_handler() @@ -52,18 +48,14 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.url = "/_synapse/admin/v1/send_server_notice" - def test_no_auth(self) -> None: + def test_no_auth(self): """Try to send a server notice without authentication.""" channel = self.make_request("POST", self.url) - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self) -> None: + def test_requester_is_no_admin(self): """If the user is not a server admin, an error is returned.""" channel = self.make_request( "POST", @@ -71,16 +63,12 @@ def test_requester_is_no_admin(self) -> None: access_token=self.other_user_token, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) - def test_user_does_not_exist(self) -> None: - """Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND""" + def test_user_does_not_exist(self): + """Tests that a lookup for a user that does not exist returns a 404""" channel = self.make_request( "POST", self.url, @@ -88,13 +76,13 @@ def test_user_does_not_exist(self) -> None: content={"user_id": "@unknown_person:test", "content": ""}, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) - def test_user_is_not_local(self) -> None: + def 
test_user_is_not_local(self):
         """
-        Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+        Tests that a lookup for a user that is not local returns a 400
         """
         channel = self.make_request(
             "POST",
             self.url,
@@ -106,13 +94,13 @@ def test_user_is_not_local(self) -> None:
             },
         )
 
-        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual(
             "Server notices can only be sent to local users", channel.json_body["error"]
         )
 
     @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
-    def test_invalid_parameter(self) -> None:
+    def test_invalid_parameter(self):
         """If parameters are invalid, an error is returned."""
 
         # no content, no user
@@ -122,7 +110,7 @@ def test_invalid_parameter(self) -> None:
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"])
 
         # no content
@@ -133,7 +121,7 @@ def test_invalid_parameter(self) -> None:
             content={"user_id": self.other_user},
         )
 
-        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
 
         # no body
@@ -144,7 +132,7 @@ def test_invalid_parameter(self) -> None:
             content={"user_id": self.other_user, "content": ""},
         )
 
-        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
         self.assertEqual("'body' not in content", channel.json_body["error"])
 
@@ -156,11 +144,11 @@ def test_invalid_parameter(self) -> None:
             content={"user_id": self.other_user, "content": {"body": ""}},
         )
-        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
         self.assertEqual("'msgtype' not in content", channel.json_body["error"])
 
-    def test_server_notice_disabled(self) -> None:
+    def test_server_notice_disabled(self):
         """Tests that the server returns an error if server notices are disabled"""
         channel = self.make_request(
             "POST",
             self.url,
@@ -172,14 +160,14 @@ def test_server_notice_disabled(self) -> None:
             },
         )
 
-        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
         self.assertEqual(
             "Server notices are not enabled on this server", channel.json_body["error"]
         )
 
     @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
-    def test_send_server_notice(self) -> None:
+    def test_send_server_notice(self):
         """
         Tests that sending two server notices is successful,
         the server uses the same room and does not send messages twice.
@@ -197,7 +185,7 @@ def test_send_server_notice(self) -> None: "content": {"msgtype": "m.text", "body": "test msg one"}, }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # user has one invite invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) @@ -228,7 +216,7 @@ def test_send_server_notice(self) -> None: "content": {"msgtype": "m.text", "body": "test msg two"}, }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # user has no new invites or memberships self._check_invite_and_join_status(self.other_user, 0, 1) @@ -243,7 +231,7 @@ def test_send_server_notice(self) -> None: self.assertEqual(messages[1]["sender"], "@notices:test") @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) - def test_send_server_notice_leave_room(self) -> None: + def test_send_server_notice_leave_room(self): """ Tests that sending a server notices is successfully. The user leaves the room and the second message appears @@ -262,7 +250,7 @@ def test_send_server_notice_leave_room(self) -> None: "content": {"msgtype": "m.text", "body": "test msg one"}, }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # user has one invite invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) @@ -305,7 +293,7 @@ def test_send_server_notice_leave_room(self) -> None: "content": {"msgtype": "m.text", "body": "test msg two"}, }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # user has one invite invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) @@ -327,7 +315,7 @@ def test_send_server_notice_leave_room(self) -> None: self.assertNotEqual(first_room_id, second_room_id) @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) - def test_send_server_notice_delete_room(self) -> None: + def test_send_server_notice_delete_room(self): """ Tests that the user get server notice in a new room after the first server notice room was deleted. @@ -345,7 +333,7 @@ def test_send_server_notice_delete_room(self) -> None: "content": {"msgtype": "m.text", "body": "test msg one"}, }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # user has one invite invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) @@ -394,7 +382,7 @@ def test_send_server_notice_delete_room(self) -> None: "content": {"msgtype": "m.text", "body": "test msg two"}, }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # user has one invite invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) @@ -417,7 +405,7 @@ def test_send_server_notice_delete_room(self) -> None: def _check_invite_and_join_status( self, user_id: str, expected_invites: int, expected_memberships: int - ) -> List[RoomsForUser]: + ) -> RoomsForUser: """Check invite and room membership status of a user. 
Args @@ -452,7 +440,7 @@ def _sync_and_get_messages(self, room_id: str, token: str) -> List[JsonDict]: channel = self.make_request( "GET", "/_matrix/client/r0/sync", access_token=token ) - self.assertEqual(channel.code, HTTPStatus.OK) + self.assertEqual(channel.code, 200) # Get the messages room = channel.json_body["rooms"]["join"][room_id] diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py index 7cb8ec57bad9..ece89a65ac28 100644 --- a/tests/rest/admin/test_statistics.py +++ b/tests/rest/admin/test_statistics.py @@ -12,17 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from http import HTTPStatus -from typing import List, Optional -from twisted.test.proto_helpers import MemoryReactor +import json +from typing import Any, Dict, List, Optional import synapse.rest.admin from synapse.api.errors import Codes from synapse.rest.client import login -from synapse.server import HomeServer -from synapse.types import JsonDict -from synapse.util import Clock from tests import unittest from tests.test_utils import SMALL_PNG @@ -34,7 +30,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + def prepare(self, reactor, clock, hs): self.media_repo = hs.get_media_repository_resource() self.admin_user = self.register_user("admin", "pass", admin=True) @@ -45,38 +41,30 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.url = "/_synapse/admin/v1/statistics/users/media" - def test_no_auth(self) -> None: + def test_no_auth(self): """ Try to list users without authentication. """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self) -> None: + def test_requester_is_no_admin(self): """ - If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. + If the user is not a server admin, an error 403 is returned. """ channel = self.make_request( "GET", self.url, - {}, + json.dumps({}), access_token=self.other_user_tok, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_invalid_parameter(self) -> None: + def test_invalid_parameter(self): """ If parameters are invalid, an error is returned. 
""" @@ -87,11 +75,7 @@ def test_invalid_parameter(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from @@ -101,11 +85,7 @@ def test_invalid_parameter(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative limit @@ -115,11 +95,7 @@ def test_invalid_parameter(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from_ts @@ -129,11 +105,7 @@ def test_invalid_parameter(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative until_ts @@ -143,11 +115,7 @@ def test_invalid_parameter(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # until_ts smaller from_ts @@ -157,11 +125,7 @@ def test_invalid_parameter(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # empty search term @@ -171,11 +135,7 @@ def test_invalid_parameter(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid search order @@ -185,14 +145,10 @@ def test_invalid_parameter(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) - def test_limit(self) -> None: + def test_limit(self): """ Testing list of media with limit """ @@ -204,13 +160,13 @@ def test_limit(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 10) self.assertEqual(len(channel.json_body["users"]), 5) self.assertEqual(channel.json_body["next_token"], 5) self._check_fields(channel.json_body["users"]) - def test_from(self) -> None: + def test_from(self): """ Testing list of media with a defined starting point (from) """ @@ -222,13 +178,13 @@ def 
test_from(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["users"]), 15) self.assertNotIn("next_token", channel.json_body) self._check_fields(channel.json_body["users"]) - def test_limit_and_from(self) -> None: + def test_limit_and_from(self): """ Testing list of media with a defined starting point and limit """ @@ -240,13 +196,13 @@ def test_limit_and_from(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(channel.json_body["next_token"], 15) self.assertEqual(len(channel.json_body["users"]), 10) self._check_fields(channel.json_body["users"]) - def test_next_token(self) -> None: + def test_next_token(self): """ Testing that `next_token` appears at the right place """ @@ -262,7 +218,7 @@ def test_next_token(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), number_users) self.assertNotIn("next_token", channel.json_body) @@ -275,7 +231,7 @@ def test_next_token(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), number_users) self.assertNotIn("next_token", channel.json_body) @@ -288,7 +244,7 @@ def test_next_token(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 19) self.assertEqual(channel.json_body["next_token"], 19) @@ -301,12 +257,12 @@ def test_next_token(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 1) self.assertNotIn("next_token", channel.json_body) - def test_no_media(self) -> None: + def test_no_media(self): """ Tests that a normal lookup for statistics is successfully if users have no media created @@ -318,11 +274,11 @@ def test_no_media(self) -> None: access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["users"])) - def test_order_by(self) -> None: + def test_order_by(self): """ Testing order list with parameter `order_by` """ @@ -400,7 +356,7 @@ def test_order_by(self) -> None: "b", ) - def test_from_until_ts(self) -> None: + def test_from_until_ts(self): """ Testing filter by time with parameters 
`from_ts` and `until_ts` """ @@ -415,7 +371,7 @@ def test_from_until_ts(self) -> None: self.url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["users"][0]["media_count"], 3) # filter media starting at `ts1` after creating first media @@ -425,7 +381,7 @@ def test_from_until_ts(self) -> None: self.url + "?from_ts=%s" % (ts1,), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 0) self._create_media(self.other_user_tok, 3) @@ -440,7 +396,7 @@ def test_from_until_ts(self) -> None: self.url + "?from_ts=%s&until_ts=%s" % (ts1, ts2), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["users"][0]["media_count"], 3) # filter media until `ts2` and earlier @@ -449,10 +405,10 @@ def test_from_until_ts(self) -> None: self.url + "?until_ts=%s" % (ts2,), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["users"][0]["media_count"], 6) - def test_search_term(self) -> None: + def test_search_term(self): self._create_users_with_media(20, 1) # check without filter get all users @@ -461,7 +417,7 @@ def test_search_term(self) -> None: self.url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) # filter user 1 and 10-19 by `user_id` @@ -470,7 +426,7 @@ def test_search_term(self) -> None: self.url + "?search_term=foo_user_1", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 11) # filter on this user in `displayname` @@ -479,7 +435,7 @@ def test_search_term(self) -> None: self.url + "?search_term=bar_user_10", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["users"][0]["displayname"], "bar_user_10") self.assertEqual(channel.json_body["total"], 1) @@ -489,10 +445,10 @@ def test_search_term(self) -> None: self.url + "?search_term=foobar", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 0) - def _create_users_with_media(self, number_users: int, media_per_user: int) -> None: + def _create_users_with_media(self, number_users: int, media_per_user: int): """ Create a number of users with a number of media Args: @@ -504,7 +460,7 @@ def _create_users_with_media(self, number_users: int, media_per_user: int) -> No user_tok = self.login("foo_user_%s" % i, "pass") self._create_media(user_tok, media_per_user) - def 
_create_media(self, user_token: str, number_media: int) -> None: + def _create_media(self, user_token: str, number_media: int): """ Create a number of media for a specific user Args: @@ -515,10 +471,10 @@ def _create_media(self, user_token: str, number_media: int) -> None: for _ in range(number_media): # Upload some media into the room self.helper.upload_media( - upload_resource, SMALL_PNG, tok=user_token, expect_code=HTTPStatus.OK + upload_resource, SMALL_PNG, tok=user_token, expect_code=200 ) - def _check_fields(self, content: List[JsonDict]) -> None: + def _check_fields(self, content: List[Dict[str, Any]]): """Checks that all attributes are present in content Args: content: List that is checked for content @@ -531,7 +487,7 @@ def _check_fields(self, content: List[JsonDict]) -> None: def _order_test( self, order_type: str, expected_user_list: List[str], dir: Optional[str] = None - ) -> None: + ): """Request the list of users in a certain order. Assert that order is what we expect Args: @@ -549,7 +505,7 @@ def _order_test( url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], len(expected_user_list)) returned_order = [row["user_id"] for row in channel.json_body["users"]] diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 4fedd5fd0851..5011e5456353 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -17,7 +17,6 @@ import os import urllib.parse from binascii import unhexlify -from http import HTTPStatus from typing import List, Optional from unittest.mock import Mock, patch @@ -75,7 +74,7 @@ def test_disabled(self): channel = self.make_request("POST", self.url, b"{}") - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "Shared secret registration is not enabled", channel.json_body["error"] ) @@ -107,7 +106,7 @@ def test_expired_nonce(self): body = {"nonce": nonce} channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("username must be specified", channel.json_body["error"]) # 61 seconds @@ -115,7 +114,7 @@ def test_expired_nonce(self): channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("unrecognised nonce", channel.json_body["error"]) def test_register_incorrect_nonce(self): @@ -127,18 +126,18 @@ def test_register_incorrect_nonce(self): want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac.update(b"notthenonce\x00bob\x00abc123\x00admin") - want_mac_str = want_mac.hexdigest() + want_mac = want_mac.hexdigest() body = { "nonce": nonce, "username": "bob", "password": "abc123", "admin": True, - "mac": want_mac_str, + "mac": want_mac, } channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual("HMAC incorrect", channel.json_body["error"]) def test_register_correct_nonce(self): @@ -153,7 +152,7 @@ def test_register_correct_nonce(self): want_mac.update( nonce.encode("ascii") + 
b"\x00bob\x00abc123\x00admin\x00support" ) - want_mac_str = want_mac.hexdigest() + want_mac = want_mac.hexdigest() body = { "nonce": nonce, @@ -161,11 +160,11 @@ def test_register_correct_nonce(self): "password": "abc123", "admin": True, "user_type": UserTypes.SUPPORT, - "mac": want_mac_str, + "mac": want_mac, } channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["user_id"]) def test_nonce_reuse(self): @@ -177,24 +176,24 @@ def test_nonce_reuse(self): want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac.update(nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin") - want_mac_str = want_mac.hexdigest() + want_mac = want_mac.hexdigest() body = { "nonce": nonce, "username": "bob", "password": "abc123", "admin": True, - "mac": want_mac_str, + "mac": want_mac, } channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["user_id"]) # Now, try and reuse it channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("unrecognised nonce", channel.json_body["error"]) def test_missing_parts(self): @@ -215,7 +214,7 @@ def nonce(): # Must be an empty body present channel = self.make_request("POST", self.url, {}) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("nonce must be specified", channel.json_body["error"]) # @@ -225,28 +224,28 @@ def nonce(): # Must be present channel = self.make_request("POST", self.url, {"nonce": nonce()}) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("username must be specified", channel.json_body["error"]) # Must be a string body = {"nonce": nonce(), "username": 1234} channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid username", channel.json_body["error"]) # Must not have null bytes body = {"nonce": nonce(), "username": "abcd\u0000"} channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid username", channel.json_body["error"]) # Must not have null bytes body = {"nonce": nonce(), "username": "a" * 1000} channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid username", channel.json_body["error"]) # @@ -257,28 +256,28 @@ def nonce(): body = {"nonce": nonce(), "username": "a"} channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("password must be specified", channel.json_body["error"]) # Must be a string body = 
{"nonce": nonce(), "username": "a", "password": 1234} channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid password", channel.json_body["error"]) # Must not have null bytes body = {"nonce": nonce(), "username": "a", "password": "abcd\u0000"} channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid password", channel.json_body["error"]) # Super long body = {"nonce": nonce(), "username": "a", "password": "A" * 1000} channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid password", channel.json_body["error"]) # @@ -294,7 +293,7 @@ def nonce(): } channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid user type", channel.json_body["error"]) def test_displayname(self): @@ -308,22 +307,22 @@ def test_displayname(self): want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac.update(nonce.encode("ascii") + b"\x00bob1\x00abc123\x00notadmin") - want_mac_str = want_mac.hexdigest() + want_mac = want_mac.hexdigest() body = { "nonce": nonce, "username": "bob1", "password": "abc123", - "mac": want_mac_str, + "mac": want_mac, } channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob1:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob1:test/displayname") - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("bob1", channel.json_body["displayname"]) # displayname is None @@ -332,22 +331,22 @@ def test_displayname(self): want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac.update(nonce.encode("ascii") + b"\x00bob2\x00abc123\x00notadmin") - want_mac_str = want_mac.hexdigest() + want_mac = want_mac.hexdigest() body = { "nonce": nonce, "username": "bob2", "displayname": None, "password": "abc123", - "mac": want_mac_str, + "mac": want_mac, } channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob2:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob2:test/displayname") - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("bob2", channel.json_body["displayname"]) # displayname is empty @@ -356,22 +355,22 @@ def test_displayname(self): want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac.update(nonce.encode("ascii") + b"\x00bob3\x00abc123\x00notadmin") - want_mac_str = want_mac.hexdigest() + want_mac = want_mac.hexdigest() body = { "nonce": nonce, "username": "bob3", "displayname": "", "password": "abc123", - "mac": want_mac_str, + "mac": want_mac, } channel = 
self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob3:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob3:test/displayname") - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) # set displayname channel = self.make_request("GET", self.url) @@ -379,22 +378,22 @@ def test_displayname(self): want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac.update(nonce.encode("ascii") + b"\x00bob4\x00abc123\x00notadmin") - want_mac_str = want_mac.hexdigest() + want_mac = want_mac.hexdigest() body = { "nonce": nonce, "username": "bob4", "displayname": "Bob's Name", "password": "abc123", - "mac": want_mac_str, + "mac": want_mac, } channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob4:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob4:test/displayname") - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("Bob's Name", channel.json_body["displayname"]) @override_config( @@ -426,7 +425,7 @@ def test_register_mau_limit_reached(self): want_mac.update( nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin\x00support" ) - want_mac_str = want_mac.hexdigest() + want_mac = want_mac.hexdigest() body = { "nonce": nonce, @@ -434,11 +433,11 @@ def test_register_mau_limit_reached(self): "password": "abc123", "admin": True, "user_type": UserTypes.SUPPORT, - "mac": want_mac_str, + "mac": want_mac, } channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["user_id"]) @@ -462,7 +461,7 @@ def test_no_auth(self): """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self): @@ -474,7 +473,7 @@ def test_requester_is_no_admin(self): channel = self.make_request("GET", self.url, access_token=other_user_token) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_all_users(self): @@ -490,7 +489,7 @@ def test_all_users(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(3, len(channel.json_body["users"])) self.assertEqual(3, channel.json_body["total"]) @@ -504,7 +503,7 @@ def _search_test( expected_user_id: Optional[str], search_term: str, search_field: Optional[str] = "name", - expected_http_code: Optional[int] = HTTPStatus.OK, + expected_http_code: Optional[int] = 200, ): """Search for a user and check that the returned user's id is a match @@ -526,7 +525,7 @@ def _search_test( ) self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) - if 
expected_http_code != HTTPStatus.OK: + if expected_http_code != 200: return # Check that users were returned @@ -587,7 +586,7 @@ def test_invalid_parameter(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from @@ -597,7 +596,7 @@ def test_invalid_parameter(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid guests @@ -607,7 +606,7 @@ def test_invalid_parameter(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # invalid deactivated @@ -617,7 +616,7 @@ def test_invalid_parameter(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # unkown order_by @@ -627,7 +626,7 @@ def test_invalid_parameter(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # invalid search order @@ -637,7 +636,7 @@ def test_invalid_parameter(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) def test_limit(self): @@ -655,7 +654,7 @@ def test_limit(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 5) self.assertEqual(channel.json_body["next_token"], "5") @@ -676,7 +675,7 @@ def test_from(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 15) self.assertNotIn("next_token", channel.json_body) @@ -697,7 +696,7 @@ def test_limit_and_from(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(channel.json_body["next_token"], "15") self.assertEqual(len(channel.json_body["users"]), 10) @@ -720,7 +719,7 @@ def test_next_token(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), number_users) self.assertNotIn("next_token", channel.json_body) @@ -733,7 +732,7 @@ def 
test_next_token(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), number_users) self.assertNotIn("next_token", channel.json_body) @@ -746,7 +745,7 @@ def test_next_token(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 19) self.assertEqual(channel.json_body["next_token"], "19") @@ -760,7 +759,7 @@ def test_next_token(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 1) self.assertNotIn("next_token", channel.json_body) @@ -863,14 +862,14 @@ def _order_test( url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], len(expected_user_list)) returned_order = [row["name"] for row in channel.json_body["users"]] self.assertEqual(expected_user_list, returned_order) self._check_fields(channel.json_body["users"]) - def _check_fields(self, content: List[JsonDict]): + def _check_fields(self, content: JsonDict): """Checks that the expected user attributes are present in content Args: content: List that is checked for content @@ -937,7 +936,7 @@ def test_no_auth(self): """ channel = self.make_request("POST", self.url, b"{}") - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_not_admin(self): @@ -948,7 +947,7 @@ def test_requester_is_not_admin(self): channel = self.make_request("POST", url, access_token=self.other_user_token) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual("You are not a server admin", channel.json_body["error"]) channel = self.make_request( @@ -958,12 +957,12 @@ def test_requester_is_not_admin(self): content=b"{}", ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual("You are not a server admin", channel.json_body["error"]) def test_user_does_not_exist(self): """ - Tests that deactivation for a user that does not exist returns a HTTPStatus.NOT_FOUND + Tests that deactivation for a user that does not exist returns a 404 """ channel = self.make_request( @@ -972,7 +971,7 @@ def test_user_does_not_exist(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_erase_is_not_bool(self): @@ -987,18 +986,18 @@ def test_erase_is_not_bool(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + 
self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) def test_user_is_not_local(self): """ - Tests that deactivation for a user that is not a local returns a HTTPStatus.BAD_REQUEST + Tests that deactivation for a user that is not a local returns a 400 """ url = "/_synapse/admin/v1/deactivate/@unknown_person:unknown_domain" channel = self.make_request("POST", url, access_token=self.admin_user_tok) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only deactivate local users", channel.json_body["error"]) def test_deactivate_user_erase_true(self): @@ -1013,7 +1012,7 @@ def test_deactivate_user_erase_true(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(False, channel.json_body["deactivated"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) @@ -1028,7 +1027,7 @@ def test_deactivate_user_erase_true(self): content={"erase": True}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Get user channel = self.make_request( @@ -1037,7 +1036,7 @@ def test_deactivate_user_erase_true(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) self.assertEqual(0, len(channel.json_body["threepids"])) @@ -1058,7 +1057,7 @@ def test_deactivate_user_erase_false(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(False, channel.json_body["deactivated"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) @@ -1073,7 +1072,7 @@ def test_deactivate_user_erase_false(self): content={"erase": False}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Get user channel = self.make_request( @@ -1082,7 +1081,7 @@ def test_deactivate_user_erase_false(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) self.assertEqual(0, len(channel.json_body["threepids"])) @@ -1112,7 +1111,7 @@ def test_deactivate_user_erase_true_no_profile(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(False, channel.json_body["deactivated"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) @@ -1127,7 +1126,7 @@ def test_deactivate_user_erase_true_no_profile(self): content={"erase": True}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + 
self.assertEqual(200, channel.code, msg=channel.json_body) # Get user channel = self.make_request( @@ -1136,7 +1135,7 @@ def test_deactivate_user_erase_true_no_profile(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) self.assertEqual(0, len(channel.json_body["threepids"])) @@ -1196,7 +1195,7 @@ def test_requester_is_no_admin(self): access_token=self.other_user_token, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual("You are not a server admin", channel.json_body["error"]) channel = self.make_request( @@ -1206,12 +1205,12 @@ def test_requester_is_no_admin(self): content=b"{}", ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual("You are not a server admin", channel.json_body["error"]) def test_user_does_not_exist(self): """ - Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND + Tests that a lookup for a user that does not exist returns a 404 """ channel = self.make_request( @@ -1220,7 +1219,7 @@ def test_user_does_not_exist(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"]) def test_invalid_parameter(self): @@ -1235,7 +1234,7 @@ def test_invalid_parameter(self): access_token=self.admin_user_tok, content={"admin": "not_bool"}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) # deactivated not bool @@ -1245,7 +1244,7 @@ def test_invalid_parameter(self): access_token=self.admin_user_tok, content={"deactivated": "not_bool"}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # password not str @@ -1255,7 +1254,7 @@ def test_invalid_parameter(self): access_token=self.admin_user_tok, content={"password": True}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # password not length @@ -1265,7 +1264,7 @@ def test_invalid_parameter(self): access_token=self.admin_user_tok, content={"password": "x" * 513}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # user_type not valid @@ -1275,7 +1274,7 @@ def test_invalid_parameter(self): access_token=self.admin_user_tok, content={"user_type": "new type"}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # external_ids not valid @@ -1287,7 +1286,7 @@ def test_invalid_parameter(self): "external_ids": {"auth_provider": "prov", 
"wrong_external_id": "id"} }, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) channel = self.make_request( @@ -1296,7 +1295,7 @@ def test_invalid_parameter(self): access_token=self.admin_user_tok, content={"external_ids": {"external_id": "id"}}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) # threepids not valid @@ -1306,7 +1305,7 @@ def test_invalid_parameter(self): access_token=self.admin_user_tok, content={"threepids": {"medium": "email", "wrong_address": "id"}}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) channel = self.make_request( @@ -1315,7 +1314,7 @@ def test_invalid_parameter(self): access_token=self.admin_user_tok, content={"threepids": {"address": "value"}}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) def test_get_user(self): @@ -1328,7 +1327,7 @@ def test_get_user(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("User", channel.json_body["displayname"]) self._check_fields(channel.json_body) @@ -1371,7 +1370,7 @@ def test_create_server_admin(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1434,7 +1433,7 @@ def test_create_user(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1462,9 +1461,9 @@ def test_create_user_mau_limit_reached_active_admin(self): # before limit of monthly active users is reached channel = self.make_request("GET", "/sync", access_token=self.admin_user_tok) - if channel.code != HTTPStatus.OK: + if channel.code != 200: raise HttpResponseException( - channel.code, channel.result["reason"], channel.json_body + channel.code, channel.result["reason"], channel.result["body"] ) # Set monthly active users to the limit @@ -1626,7 +1625,7 @@ def test_set_password(self): content={"password": "hahaha"}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self._check_fields(channel.json_body) def test_set_displayname(self): @@ -1642,7 +1641,7 @@ def test_set_displayname(self): content={"displayname": "foobar"}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, 
channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("foobar", channel.json_body["displayname"]) @@ -1653,7 +1652,7 @@ def test_set_displayname(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("foobar", channel.json_body["displayname"]) @@ -1675,7 +1674,7 @@ def test_set_threepid(self): }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["threepids"])) # result does not always have the same sort order, therefore it becomes sorted @@ -1701,7 +1700,7 @@ def test_set_threepid(self): }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["threepids"])) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1717,7 +1716,7 @@ def test_set_threepid(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["threepids"])) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1733,7 +1732,7 @@ def test_set_threepid(self): access_token=self.admin_user_tok, content={"threepids": []}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(0, len(channel.json_body["threepids"])) self._check_fields(channel.json_body) @@ -1760,7 +1759,7 @@ def test_set_duplicate_threepid(self): }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(first_user, channel.json_body["name"]) self.assertEqual(1, len(channel.json_body["threepids"])) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1779,7 +1778,7 @@ def test_set_duplicate_threepid(self): }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(1, len(channel.json_body["threepids"])) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1801,7 +1800,7 @@ def test_set_duplicate_threepid(self): ) # other user has this two threepids - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["threepids"])) # result does not always have the same sort order, therefore it becomes sorted @@ -1820,7 +1819,7 @@ def test_set_duplicate_threepid(self): url_first_user, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(first_user, channel.json_body["name"]) 
self.assertEqual(0, len(channel.json_body["threepids"])) self._check_fields(channel.json_body) @@ -1849,7 +1848,7 @@ def test_set_external_id(self): }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["external_ids"])) # result does not always have the same sort order, therefore it becomes sorted @@ -1881,7 +1880,7 @@ def test_set_external_id(self): }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["external_ids"])) self.assertEqual( @@ -1900,7 +1899,7 @@ def test_set_external_id(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["external_ids"])) self.assertEqual( @@ -1919,7 +1918,7 @@ def test_set_external_id(self): access_token=self.admin_user_tok, content={"external_ids": []}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(0, len(channel.json_body["external_ids"])) @@ -1948,7 +1947,7 @@ def test_set_duplicate_external_id(self): }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(first_user, channel.json_body["name"]) self.assertEqual(1, len(channel.json_body["external_ids"])) self.assertEqual( @@ -1974,7 +1973,7 @@ def test_set_duplicate_external_id(self): }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(1, len(channel.json_body["external_ids"])) self.assertEqual( @@ -2006,7 +2005,7 @@ def test_set_duplicate_external_id(self): ) # must fail - self.assertEqual(HTTPStatus.CONFLICT, channel.code, msg=channel.json_body) + self.assertEqual(409, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) self.assertEqual("External id is already in use.", channel.json_body["error"]) @@ -2017,7 +2016,7 @@ def test_set_duplicate_external_id(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(1, len(channel.json_body["external_ids"])) self.assertEqual( @@ -2035,7 +2034,7 @@ def test_set_duplicate_external_id(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(first_user, channel.json_body["name"]) self.assertEqual(1, len(channel.json_body["external_ids"])) self.assertEqual( @@ -2066,7 +2065,7 @@ def test_deactivate_user(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", 
channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) @@ -2081,7 +2080,7 @@ def test_deactivate_user(self): content={"deactivated": True}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) self.assertIsNone(channel.json_body["password_hash"]) @@ -2097,7 +2096,7 @@ def test_deactivate_user(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) self.assertIsNone(channel.json_body["password_hash"]) @@ -2124,7 +2123,7 @@ def test_change_name_deactivate_user_user_directory(self): content={"deactivated": True}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) @@ -2140,7 +2139,7 @@ def test_change_name_deactivate_user_user_directory(self): content={"displayname": "Foobar"}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) self.assertEqual("Foobar", channel.json_body["displayname"]) @@ -2164,7 +2163,7 @@ def test_reactivate_user(self): access_token=self.admin_user_tok, content={"deactivated": False}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) # Reactivate the user. channel = self.make_request( @@ -2173,7 +2172,7 @@ def test_reactivate_user(self): access_token=self.admin_user_tok, content={"deactivated": False, "password": "foo"}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) self.assertIsNotNone(channel.json_body["password_hash"]) @@ -2195,7 +2194,7 @@ def test_reactivate_user_localdb_disabled(self): access_token=self.admin_user_tok, content={"deactivated": False, "password": "foo"}, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Reactivate the user without a password. 
@@ -2205,7 +2204,7 @@ def test_reactivate_user_localdb_disabled(self): access_token=self.admin_user_tok, content={"deactivated": False}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) self.assertIsNone(channel.json_body["password_hash"]) @@ -2227,7 +2226,7 @@ def test_reactivate_user_password_disabled(self): access_token=self.admin_user_tok, content={"deactivated": False, "password": "foo"}, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Reactivate the user without a password. @@ -2237,7 +2236,7 @@ def test_reactivate_user_password_disabled(self): access_token=self.admin_user_tok, content={"deactivated": False}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) self.assertIsNone(channel.json_body["password_hash"]) @@ -2256,7 +2255,7 @@ def test_set_user_as_admin(self): content={"admin": True}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["admin"]) @@ -2267,7 +2266,7 @@ def test_set_user_as_admin(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["admin"]) @@ -2284,7 +2283,7 @@ def test_set_user_type(self): content={"user_type": UserTypes.SUPPORT}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"]) @@ -2295,7 +2294,7 @@ def test_set_user_type(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"]) @@ -2307,7 +2306,7 @@ def test_set_user_type(self): content={"user_type": None}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertIsNone(channel.json_body["user_type"]) @@ -2318,7 +2317,7 @@ def test_set_user_type(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertIsNone(channel.json_body["user_type"]) @@ -2348,7 +2347,7 @@ def test_accidental_deactivation_prevention(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", 
channel.json_body["name"]) self.assertEqual("bob", channel.json_body["displayname"]) self.assertEqual(0, channel.json_body["deactivated"]) @@ -2361,7 +2360,7 @@ def test_accidental_deactivation_prevention(self): content={"password": "abc123", "deactivated": "false"}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) # Check user is not deactivated channel = self.make_request( @@ -2370,7 +2369,7 @@ def test_accidental_deactivation_prevention(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("bob", channel.json_body["displayname"]) @@ -2395,7 +2394,7 @@ def _deactivate_user(self, user_id: str) -> None: access_token=self.admin_user_tok, content={"deactivated": True}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertTrue(channel.json_body["deactivated"]) self.assertIsNone(channel.json_body["password_hash"]) self._is_erased(user_id, False) @@ -2446,7 +2445,7 @@ def test_no_auth(self): """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self): @@ -2461,7 +2460,7 @@ def test_requester_is_no_admin(self): access_token=other_user_token, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_user_does_not_exist(self): @@ -2475,7 +2474,7 @@ def test_user_does_not_exist(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["joined_rooms"])) @@ -2491,7 +2490,7 @@ def test_user_is_not_local(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["joined_rooms"])) @@ -2507,7 +2506,7 @@ def test_no_memberships(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["joined_rooms"])) @@ -2528,7 +2527,7 @@ def test_get_rooms(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(number_rooms, channel.json_body["total"]) self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"])) @@ -2575,7 +2574,7 @@ def test_get_rooms_with_nonlocal_user(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) 
self.assertEqual([local_and_remote_room_id], channel.json_body["joined_rooms"]) @@ -2604,7 +2603,7 @@ def test_no_auth(self): """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self): @@ -2619,12 +2618,12 @@ def test_requester_is_no_admin(self): access_token=other_user_token, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_user_does_not_exist(self): """ - Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND + Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v1/users/@unknown_person:test/pushers" channel = self.make_request( @@ -2633,12 +2632,12 @@ def test_user_does_not_exist(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_user_is_not_local(self): """ - Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST + Tests that a lookup for a user that is not a local returns a 400 """ url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/pushers" @@ -2648,7 +2647,7 @@ def test_user_is_not_local(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only look up local users", channel.json_body["error"]) def test_get_pushers(self): @@ -2663,7 +2662,7 @@ def test_get_pushers(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) # Register the pusher @@ -2694,7 +2693,7 @@ def test_get_pushers(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) for p in channel.json_body["pushers"]: @@ -2733,7 +2732,7 @@ def test_no_auth(self, method: str): """Try to list media of an user without authentication.""" channel = self.make_request(method, self.url, {}) - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "DELETE"]) @@ -2747,12 +2746,12 @@ def test_requester_is_no_admin(self, method: str): access_token=other_user_token, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "DELETE"]) def test_user_does_not_exist(self, method: str): - """Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND""" + """Tests that a lookup for a user that does not exist returns a 404""" url = "/_synapse/admin/v1/users/@unknown_person:test/media" 
channel = self.make_request( method, @@ -2760,12 +2759,12 @@ def test_user_does_not_exist(self, method: str): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @parameterized.expand(["GET", "DELETE"]) def test_user_is_not_local(self, method: str): - """Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST""" + """Tests that a lookup for a user that is not a local returns a 400""" url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/media" channel = self.make_request( @@ -2774,7 +2773,7 @@ def test_user_is_not_local(self, method: str): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only look up local users", channel.json_body["error"]) def test_limit_GET(self): @@ -2790,7 +2789,7 @@ def test_limit_GET(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), 5) self.assertEqual(channel.json_body["next_token"], 5) @@ -2809,7 +2808,7 @@ def test_limit_DELETE(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 5) self.assertEqual(len(channel.json_body["deleted_media"]), 5) @@ -2826,7 +2825,7 @@ def test_from_GET(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), 15) self.assertNotIn("next_token", channel.json_body) @@ -2845,7 +2844,7 @@ def test_from_DELETE(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 15) self.assertEqual(len(channel.json_body["deleted_media"]), 15) @@ -2862,7 +2861,7 @@ def test_limit_and_from_GET(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(channel.json_body["next_token"], 15) self.assertEqual(len(channel.json_body["media"]), 10) @@ -2881,7 +2880,7 @@ def test_limit_and_from_DELETE(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 10) self.assertEqual(len(channel.json_body["deleted_media"]), 10) @@ -2895,7 +2894,7 @@ def test_invalid_parameter(self, method: str): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # invalid search order @@ -2905,7 +2904,7 @@ def 
test_invalid_parameter(self, method: str): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # negative limit @@ -2915,7 +2914,7 @@ def test_invalid_parameter(self, method: str): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from @@ -2925,7 +2924,7 @@ def test_invalid_parameter(self, method: str): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) def test_next_token(self): @@ -2948,7 +2947,7 @@ def test_next_token(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), number_media) self.assertNotIn("next_token", channel.json_body) @@ -2961,7 +2960,7 @@ def test_next_token(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), number_media) self.assertNotIn("next_token", channel.json_body) @@ -2974,7 +2973,7 @@ def test_next_token(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), 19) self.assertEqual(channel.json_body["next_token"], 19) @@ -2988,7 +2987,7 @@ def test_next_token(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), 1) self.assertNotIn("next_token", channel.json_body) @@ -3005,7 +3004,7 @@ def test_user_has_no_media_GET(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["media"])) @@ -3020,7 +3019,7 @@ def test_user_has_no_media_DELETE(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["deleted_media"])) @@ -3037,7 +3036,7 @@ def test_get_media(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(number_media, channel.json_body["total"]) self.assertEqual(number_media, len(channel.json_body["media"])) self.assertNotIn("next_token", 
channel.json_body) @@ -3063,7 +3062,7 @@ def test_delete_media(self): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(number_media, channel.json_body["total"]) self.assertEqual(number_media, len(channel.json_body["deleted_media"])) self.assertCountEqual(channel.json_body["deleted_media"], media_ids) @@ -3208,7 +3207,7 @@ def _create_media_and_access( # Upload some media into the room response = self.helper.upload_media( - upload_resource, image_data, user_token, filename, expect_code=HTTPStatus.OK + upload_resource, image_data, user_token, filename, expect_code=200 ) # Extract media ID from the response @@ -3226,16 +3225,16 @@ def _create_media_and_access( ) self.assertEqual( - HTTPStatus.OK, + 200, channel.code, msg=( - f"Expected to receive a HTTPStatus.OK on accessing media: {server_and_media_id}" + f"Expected to receive a 200 on accessing media: {server_and_media_id}" ), ) return media_id - def _check_fields(self, content: List[JsonDict]): + def _check_fields(self, content: JsonDict): """Checks that the expected user attributes are present in content Args: content: List that is checked for content @@ -3275,7 +3274,7 @@ def _order_test( url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], len(expected_media_list)) returned_order = [row["media_id"] for row in channel.json_body["media"]] @@ -3311,14 +3310,14 @@ def _get_token(self) -> str: channel = self.make_request( "POST", self.url, b"{}", access_token=self.admin_user_tok ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) return channel.json_body["access_token"] def test_no_auth(self): """Try to login as a user without authentication.""" channel = self.make_request("POST", self.url, b"{}") - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_not_admin(self): @@ -3327,7 +3326,7 @@ def test_not_admin(self): "POST", self.url, b"{}", access_token=self.other_user_tok ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) def test_send_event(self): """Test that sending event as a user works.""" @@ -3352,7 +3351,7 @@ def test_devices(self): channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # We should only see the one device (from the login in `prepare`) self.assertEqual(len(channel.json_body["devices"]), 1) @@ -3364,21 +3363,21 @@ def test_logout(self): # Test that we can successfully make a request channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Logout with the puppet token channel = self.make_request("POST", "logout", b"{}", access_token=puppet_token) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # 
The puppet token should no longer work channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) # .. but the real user's tokens should still work channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) def test_user_logout_all(self): """Tests that the target user calling `/logout/all` does *not* expire @@ -3389,23 +3388,23 @@ def test_user_logout_all(self): # Test that we can successfully make a request channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Logout all with the real user token channel = self.make_request( "POST", "logout/all", b"{}", access_token=self.other_user_tok ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # The puppet token should still work channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # .. but the real user's tokens shouldn't channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) def test_admin_logout_all(self): """Tests that the admin user calling `/logout/all` does expire the @@ -3416,23 +3415,23 @@ def test_admin_logout_all(self): # Test that we can successfully make a request channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Logout all with the admin user token channel = self.make_request( "POST", "logout/all", b"{}", access_token=self.admin_user_tok ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # The puppet token should no longer work channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) # .. but the real user's tokens should still work channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) @unittest.override_config( { @@ -3460,10 +3459,7 @@ def test_consent(self): # Now unaccept it and check that we can't send an event self.get_success(self.store.user_set_consent_version(self.other_user, "0.0")) self.helper.send_event( - room_id, - "com.example.test", - tok=self.other_user_tok, - expect_code=HTTPStatus.FORBIDDEN, + room_id, "com.example.test", tok=self.other_user_tok, expect_code=403 ) # Login in as the user @@ -3481,10 +3477,7 @@ def test_mau_limit(self): # Trying to join as the other user should fail due to reaching MAU limit. 
         self.helper.join(
-            room_id,
-            user=self.other_user,
-            tok=self.other_user_tok,
-            expect_code=HTTPStatus.FORBIDDEN,
+            room_id, user=self.other_user, tok=self.other_user_tok, expect_code=403
         )
 
         # Logging in as the other user and joining a room should work, even
@@ -3519,7 +3512,7 @@ def test_no_auth(self):
         """
         Try to get information of an user without authentication.
         """
         channel = self.make_request("GET", self.url, b"{}")
-        self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+        self.assertEqual(401, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
 
     def test_requester_is_not_admin(self):
@@ -3534,12 +3527,12 @@ def test_requester_is_not_admin(self):
             self.url,
             access_token=other_user2_token,
         )
-        self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+        self.assertEqual(403, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
     def test_user_is_not_local(self):
         """
-        Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+        Tests that a lookup for a user that is not local returns a 400
         """
         url = self.url_prefix % "@unknown_person:unknown_domain"
 
@@ -3548,7 +3541,7 @@ def test_user_is_not_local(self):
             url,
             access_token=self.admin_user_tok,
         )
-        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Can only whois a local user", channel.json_body["error"])
 
     def test_get_whois_admin(self):
@@ -3560,7 +3553,7 @@ def test_get_whois_admin(self):
             self.url,
             access_token=self.admin_user_tok,
         )
-        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(self.other_user, channel.json_body["user_id"])
         self.assertIn("devices", channel.json_body)
 
@@ -3575,7 +3568,7 @@ def test_get_whois_user(self):
             self.url,
             access_token=other_user_token,
         )
-        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(self.other_user, channel.json_body["user_id"])
         self.assertIn("devices", channel.json_body)
 
@@ -3605,7 +3598,7 @@ def test_no_auth(self, method: str):
         Try to get information of an user without authentication.
""" channel = self.make_request(method, self.url) - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["POST", "DELETE"]) @@ -3616,18 +3609,18 @@ def test_requester_is_not_admin(self, method: str): other_user_token = self.login("user", "pass") channel = self.make_request(method, self.url, access_token=other_user_token) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @parameterized.expand(["POST", "DELETE"]) def test_user_is_not_local(self, method: str): """ - Tests that shadow-banning for a user that is not a local returns a HTTPStatus.BAD_REQUEST + Tests that shadow-banning for a user that is not a local returns a 400 """ url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain" channel = self.make_request(method, url, access_token=self.admin_user_tok) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) def test_success(self): """ @@ -3639,7 +3632,7 @@ def test_success(self): self.assertFalse(result.shadow_banned) channel = self.make_request("POST", self.url, access_token=self.admin_user_tok) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual({}, channel.json_body) # Ensure the user is shadow-banned (and the cache was cleared). @@ -3650,7 +3643,7 @@ def test_success(self): channel = self.make_request( "DELETE", self.url, access_token=self.admin_user_tok ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual({}, channel.json_body) # Ensure the user is no longer shadow-banned (and the cache was cleared). 
@@ -3684,7 +3677,7 @@ def test_no_auth(self, method: str): """ channel = self.make_request(method, self.url, b"{}") - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "POST", "DELETE"]) @@ -3700,13 +3693,13 @@ def test_requester_is_no_admin(self, method: str): access_token=other_user_token, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "POST", "DELETE"]) def test_user_does_not_exist(self, method: str): """ - Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND + Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v1/users/@unknown_person:test/override_ratelimit" @@ -3716,7 +3709,7 @@ def test_user_does_not_exist(self, method: str): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @parameterized.expand( @@ -3728,7 +3721,7 @@ def test_user_does_not_exist(self, method: str): ) def test_user_is_not_local(self, method: str, error_msg: str): """ - Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST + Tests that a lookup for a user that is not a local returns a 400 """ url = ( "/_synapse/admin/v1/users/@unknown_person:unknown_domain/override_ratelimit" @@ -3740,7 +3733,7 @@ def test_user_is_not_local(self, method: str, error_msg: str): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(error_msg, channel.json_body["error"]) def test_invalid_parameter(self): @@ -3755,7 +3748,7 @@ def test_invalid_parameter(self): content={"messages_per_second": "string"}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # messages_per_second is negative @@ -3766,7 +3759,7 @@ def test_invalid_parameter(self): content={"messages_per_second": -1}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # burst_count is a string @@ -3777,7 +3770,7 @@ def test_invalid_parameter(self): content={"burst_count": "string"}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # burst_count is negative @@ -3788,7 +3781,7 @@ def test_invalid_parameter(self): content={"burst_count": -1}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) def test_return_zero_when_null(self): @@ -3813,7 +3806,7 @@ def test_return_zero_when_null(self): self.url, access_token=self.admin_user_tok, ) - 
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["messages_per_second"]) self.assertEqual(0, channel.json_body["burst_count"]) @@ -3827,7 +3820,7 @@ def test_success(self): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertNotIn("messages_per_second", channel.json_body) self.assertNotIn("burst_count", channel.json_body) @@ -3838,7 +3831,7 @@ def test_success(self): access_token=self.admin_user_tok, content={"messages_per_second": 10, "burst_count": 11}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(10, channel.json_body["messages_per_second"]) self.assertEqual(11, channel.json_body["burst_count"]) @@ -3849,7 +3842,7 @@ def test_success(self): access_token=self.admin_user_tok, content={"messages_per_second": 20, "burst_count": 21}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(20, channel.json_body["messages_per_second"]) self.assertEqual(21, channel.json_body["burst_count"]) @@ -3859,7 +3852,7 @@ def test_success(self): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(20, channel.json_body["messages_per_second"]) self.assertEqual(21, channel.json_body["burst_count"]) @@ -3869,7 +3862,7 @@ def test_success(self): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertNotIn("messages_per_second", channel.json_body) self.assertNotIn("burst_count", channel.json_body) @@ -3879,6 +3872,6 @@ def test_success(self): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertNotIn("messages_per_second", channel.json_body) self.assertNotIn("burst_count", channel.json_body) diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py index 7978626e7197..4e1c49c28b8d 100644 --- a/tests/rest/admin/test_username_available.py +++ b/tests/rest/admin/test_username_available.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from http import HTTPStatus
-
 import synapse.rest.admin
 from synapse.api.errors import Codes, SynapseError
 from synapse.rest.client import login
@@ -35,38 +33,30 @@ def prepare(self, reactor, clock, hs):
         async def check_username(username):
             if username == "allowed":
                 return True
-            raise SynapseError(
-                HTTPStatus.BAD_REQUEST,
-                "User ID already taken.",
-                errcode=Codes.USER_IN_USE,
-            )
+            raise SynapseError(400, "User ID already taken.", errcode=Codes.USER_IN_USE)
 
         handler = self.hs.get_registration_handler()
         handler.check_username = check_username
 
     def test_username_available(self):
         """
-        The endpoint should return a HTTPStatus.OK response if the username does not exist
+        The endpoint should return a 200 response if the username does not exist
         """
         url = "%s?username=%s" % (self.url, "allowed")
         channel = self.make_request("GET", url, None, self.admin_user_tok)
-        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertTrue(channel.json_body["available"])
 
     def test_username_unavailable(self):
         """
-        The endpoint should return a HTTPStatus.OK response if the username does not exist
+        The endpoint should return a 400 response if the username is already in use
         """
         url = "%s?username=%s" % (self.url, "disallowed")
         channel = self.make_request("GET", url, None, self.admin_user_tok)
-        self.assertEqual(
-            HTTPStatus.BAD_REQUEST,
-            channel.code,
-            msg=channel.json_body,
-        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(channel.json_body["errcode"], "M_USER_IN_USE")
         self.assertEqual(channel.json_body["error"], "User ID already taken.")
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index 72bbc87b4a0c..855267143138 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from http import HTTPStatus
 from typing import Optional, Union
 
 from twisted.internet.defer import succeed
@@ -514,39 +513,12 @@ def prepare(self, reactor, clock, hs):
         self.user_pass = "pass"
         self.user = self.register_user("test", self.user_pass)
 
-    def use_refresh_token(self, refresh_token: str) -> FakeChannel:
-        """
-        Helper that makes a request to use a refresh token.
-        """
-        return self.make_request(
-            "POST",
-            "/_matrix/client/v1/refresh",
-            {"refresh_token": refresh_token},
-        )
-
-    def is_access_token_valid(self, access_token) -> bool:
-        """
-        Checks whether an access token is valid, returning whether it is or not.
-        """
-        code = self.make_request(
-            "GET", "/_matrix/client/v3/account/whoami", access_token=access_token
-        ).code
-
-        # Either 200 or 401 is what we get back; anything else is a bug.
-        assert code in {HTTPStatus.OK, HTTPStatus.UNAUTHORIZED}
-
-        return code == HTTPStatus.OK
-
     def test_login_issue_refresh_token(self):
         """
         A login response should include a refresh_token only if asked.
""" # Test login - body = { - "type": "m.login.password", - "user": "test", - "password": self.user_pass, - } + body = {"type": "m.login.password", "user": "test", "password": self.user_pass} login_without_refresh = self.make_request( "POST", "/_matrix/client/r0/login", body @@ -556,8 +528,8 @@ def test_login_issue_refresh_token(self): login_with_refresh = self.make_request( "POST", - "/_matrix/client/r0/login", - {"refresh_token": True, **body}, + "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", + body, ) self.assertEqual(login_with_refresh.code, 200, login_with_refresh.result) self.assertIn("refresh_token", login_with_refresh.json_body) @@ -583,12 +555,11 @@ def test_register_issue_refresh_token(self): register_with_refresh = self.make_request( "POST", - "/_matrix/client/r0/register", + "/_matrix/client/r0/register?org.matrix.msc2918.refresh_token=true", { "username": "test3", "password": self.user_pass, "auth": {"type": LoginType.DUMMY}, - "refresh_token": True, }, ) self.assertEqual(register_with_refresh.code, 200, register_with_refresh.result) @@ -599,22 +570,17 @@ def test_token_refresh(self): """ A refresh token can be used to issue a new access token. """ - body = { - "type": "m.login.password", - "user": "test", - "password": self.user_pass, - "refresh_token": True, - } + body = {"type": "m.login.password", "user": "test", "password": self.user_pass} login_response = self.make_request( "POST", - "/_matrix/client/r0/login", + "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", body, ) self.assertEqual(login_response.code, 200, login_response.result) refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual(refresh_response.code, 200, refresh_response.result) @@ -633,19 +599,14 @@ def test_token_refresh(self): ) @override_config({"refreshable_access_token_lifetime": "1m"}) - def test_refreshable_access_token_expiration(self): + def test_refresh_token_expiration(self): """ The access token should have some time as specified in the config. 
""" - body = { - "type": "m.login.password", - "user": "test", - "password": self.user_pass, - "refresh_token": True, - } + body = {"type": "m.login.password", "user": "test", "password": self.user_pass} login_response = self.make_request( "POST", - "/_matrix/client/r0/login", + "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", body, ) self.assertEqual(login_response.code, 200, login_response.result) @@ -655,198 +616,13 @@ def test_refreshable_access_token_expiration(self): refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual(refresh_response.code, 200, refresh_response.result) self.assertApproximates( refresh_response.json_body["expires_in_ms"], 60 * 1000, 100 ) - access_token = refresh_response.json_body["access_token"] - - # Advance 59 seconds in the future (just shy of 1 minute, the time of expiry) - self.reactor.advance(59.0) - # Check that our token is valid - self.assertEqual( - self.make_request( - "GET", "/_matrix/client/v3/account/whoami", access_token=access_token - ).code, - HTTPStatus.OK, - ) - - # Advance 2 more seconds (just past the time of expiry) - self.reactor.advance(2.0) - # Check that our token is invalid - self.assertEqual( - self.make_request( - "GET", "/_matrix/client/v3/account/whoami", access_token=access_token - ).code, - HTTPStatus.UNAUTHORIZED, - ) - - @override_config( - { - "refreshable_access_token_lifetime": "1m", - "nonrefreshable_access_token_lifetime": "10m", - } - ) - def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens(self): - """ - Tests that the expiry times for refreshable and non-refreshable access - tokens can be different. - """ - body = { - "type": "m.login.password", - "user": "test", - "password": self.user_pass, - } - login_response1 = self.make_request( - "POST", - "/_matrix/client/r0/login", - {"refresh_token": True, **body}, - ) - self.assertEqual(login_response1.code, 200, login_response1.result) - self.assertApproximates( - login_response1.json_body["expires_in_ms"], 60 * 1000, 100 - ) - refreshable_access_token = login_response1.json_body["access_token"] - - login_response2 = self.make_request( - "POST", - "/_matrix/client/r0/login", - body, - ) - self.assertEqual(login_response2.code, 200, login_response2.result) - nonrefreshable_access_token = login_response2.json_body["access_token"] - - # Advance 59 seconds in the future (just shy of 1 minute, the time of expiry) - self.reactor.advance(59.0) - - # Both tokens should still be valid. - self.assertTrue(self.is_access_token_valid(refreshable_access_token)) - self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token)) - - # Advance to 61 s (just past 1 minute, the time of expiry) - self.reactor.advance(2.0) - - # Only the non-refreshable token is still valid. - self.assertFalse(self.is_access_token_valid(refreshable_access_token)) - self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token)) - - # Advance to 599 s (just shy of 10 minutes, the time of expiry) - self.reactor.advance(599.0 - 61.0) - - # It's still the case that only the non-refreshable token is still valid. - self.assertFalse(self.is_access_token_valid(refreshable_access_token)) - self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token)) - - # Advance to 601 s (just past 10 minutes, the time of expiry) - self.reactor.advance(2.0) - - # Now neither token is valid. 
- self.assertFalse(self.is_access_token_valid(refreshable_access_token)) - self.assertFalse(self.is_access_token_valid(nonrefreshable_access_token)) - - @override_config( - {"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"} - ) - def test_refresh_token_expiry(self): - """ - The refresh token can be configured to have a limited lifetime. - When that lifetime has ended, the refresh token can no longer be used to - refresh the session. - """ - - body = { - "type": "m.login.password", - "user": "test", - "password": self.user_pass, - "refresh_token": True, - } - login_response = self.make_request( - "POST", - "/_matrix/client/r0/login", - body, - ) - self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result) - refresh_token1 = login_response.json_body["refresh_token"] - - # Advance 119 seconds in the future (just shy of 2 minutes) - self.reactor.advance(119.0) - - # Refresh our session. The refresh token should still JUST be valid right now. - # By doing so, we get a new access token and a new refresh token. - refresh_response = self.use_refresh_token(refresh_token1) - self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result) - self.assertIn( - "refresh_token", - refresh_response.json_body, - "No new refresh token returned after refresh.", - ) - refresh_token2 = refresh_response.json_body["refresh_token"] - - # Advance 121 seconds in the future (just a bit more than 2 minutes) - self.reactor.advance(121.0) - - # Try to refresh our session, but instead notice that the refresh token is - # not valid (it just expired). - refresh_response = self.use_refresh_token(refresh_token2) - self.assertEqual( - refresh_response.code, HTTPStatus.FORBIDDEN, refresh_response.result - ) - - @override_config( - { - "refreshable_access_token_lifetime": "2m", - "refresh_token_lifetime": "2m", - "session_lifetime": "3m", - } - ) - def test_ultimate_session_expiry(self): - """ - The session can be configured to have an ultimate, limited lifetime. - """ - - body = { - "type": "m.login.password", - "user": "test", - "password": self.user_pass, - "refresh_token": True, - } - login_response = self.make_request( - "POST", - "/_matrix/client/r0/login", - body, - ) - self.assertEqual(login_response.code, 200, login_response.result) - refresh_token = login_response.json_body["refresh_token"] - - # Advance shy of 2 minutes into the future - self.reactor.advance(119.0) - - # Refresh our session. The refresh token should still be valid right now. - refresh_response = self.use_refresh_token(refresh_token) - self.assertEqual(refresh_response.code, 200, refresh_response.result) - self.assertIn( - "refresh_token", - refresh_response.json_body, - "No new refresh token returned after refresh.", - ) - # Notice that our access token lifetime has been diminished to match the - # session lifetime. - # 3 minutes - 119 seconds = 61 seconds. - self.assertEqual(refresh_response.json_body["expires_in_ms"], 61_000) - refresh_token = refresh_response.json_body["refresh_token"] - - # Advance 61 seconds into the future. Our session should have expired - # now, because we've had our 3 minutes. - self.reactor.advance(61.0) - - # Try to issue a new, refreshed, access token. - # This should fail because the refresh token's lifetime has also been - # diminished as our session expired. 
- refresh_response = self.use_refresh_token(refresh_token) - self.assertEqual(refresh_response.code, 403, refresh_response.result) def test_refresh_token_invalidation(self): """Refresh tokens are invalidated after first use of the next token. @@ -864,15 +640,10 @@ def test_refresh_token_invalidation(self): |-> fourth_refresh (fails) """ - body = { - "type": "m.login.password", - "user": "test", - "password": self.user_pass, - "refresh_token": True, - } + body = {"type": "m.login.password", "user": "test", "password": self.user_pass} login_response = self.make_request( "POST", - "/_matrix/client/r0/login", + "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", body, ) self.assertEqual(login_response.code, 200, login_response.result) @@ -880,7 +651,7 @@ def test_refresh_token_invalidation(self): # This first refresh should work properly first_refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual( @@ -890,7 +661,7 @@ def test_refresh_token_invalidation(self): # This one as well, since the token in the first one was never used second_refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual( @@ -900,7 +671,7 @@ def test_refresh_token_invalidation(self): # This one should not, since the token from the first refresh is not valid anymore third_refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", {"refresh_token": first_refresh_response.json_body["refresh_token"]}, ) self.assertEqual( @@ -928,7 +699,7 @@ def test_refresh_token_invalidation(self): # Now that the access token from the last valid refresh was used once, refreshing with the N-1 token should fail fourth_refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual( @@ -938,7 +709,7 @@ def test_refresh_token_invalidation(self): # But refreshing from the last valid refresh token still works fifth_refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", {"refresh_token": second_refresh_response.json_body["refresh_token"]}, ) self.assertEqual( diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 397c12c2a6c5..eb10d4321793 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -19,7 +19,7 @@ from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin -from synapse.rest.client import login, register, relations, room, sync +from synapse.rest.client import login, register, relations, room from tests import unittest from tests.server import FakeChannel @@ -29,7 +29,6 @@ class RelationsTestCase(unittest.HomeserverTestCase): servlets = [ relations.register_servlets, room.register_servlets, - sync.register_servlets, login.register_servlets, register.register_servlets, admin.register_servlets_for_client_rest_resource, @@ -455,9 +454,11 @@ def test_aggregation_must_be_annotation(self): self.assertEquals(400, channel.code, 
channel.json_body) @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) - def test_bundled_aggregations(self): - """Test that annotations, references, and threads get correctly bundled.""" - # Setup by sending a variety of relations. + def test_aggregation_get_event(self): + """Test that annotations, references, and threads get correctly bundled when + getting the parent event. + """ + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") self.assertEquals(200, channel.code, channel.json_body) @@ -484,169 +485,43 @@ def test_bundled_aggregations(self): self.assertEquals(200, channel.code, channel.json_body) thread_2 = channel.json_body["event_id"] - def assert_bundle(actual): - """Assert the expected values of the bundled aggregations.""" - - # Ensure the fields are as expected. - self.assertCountEqual( - actual.keys(), - ( - RelationTypes.ANNOTATION, - RelationTypes.REFERENCE, - RelationTypes.THREAD, - ), - ) - - # Check the values of each field. - self.assertEquals( - { - "chunk": [ - {"type": "m.reaction", "key": "a", "count": 2}, - {"type": "m.reaction", "key": "b", "count": 1}, - ] - }, - actual[RelationTypes.ANNOTATION], - ) - - self.assertEquals( - {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]}, - actual[RelationTypes.REFERENCE], - ) - - self.assertEquals( - 2, - actual[RelationTypes.THREAD].get("count"), - ) - # The latest thread event has some fields that don't matter. - self.assert_dict( - { - "content": { - "m.relates_to": { - "event_id": self.parent_id, - "rel_type": RelationTypes.THREAD, - } - }, - "event_id": thread_2, - "room_id": self.room, - "sender": self.user_id, - "type": "m.room.test", - "user_id": self.user_id, - }, - actual[RelationTypes.THREAD].get("latest_event"), - ) - - def _find_and_assert_event(events): - """ - Find the parent event in a chunk of events and assert that it has the proper bundled aggregations. - """ - for event in events: - if event["event_id"] == self.parent_id: - break - else: - raise AssertionError(f"Event {self.parent_id} not found in chunk") - assert_bundle(event["unsigned"].get("m.relations")) - - # Request the event directly. channel = self.make_request( "GET", - f"/rooms/{self.room}/event/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - assert_bundle(channel.json_body["unsigned"].get("m.relations")) - - # Request the room messages. - channel = self.make_request( - "GET", - f"/rooms/{self.room}/messages?dir=b", - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - _find_and_assert_event(channel.json_body["chunk"]) - - # Request the room context. - channel = self.make_request( - "GET", - f"/rooms/{self.room}/context/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - assert_bundle(channel.json_body["event"]["unsigned"].get("m.relations")) - - # Request sync. - channel = self.make_request("GET", "/sync", access_token=self.user_token) - self.assertEquals(200, channel.code, channel.json_body) - room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] - self.assertTrue(room_timeline["limited"]) - _find_and_assert_event(room_timeline["events"]) - - # Note that /relations is tested separately in test_aggregation_get_event_for_thread - # since it needs different data configured. 
- - def test_aggregation_get_event_for_annotation(self): - """Test that annotations do not get bundled aggregations included - when directly requested. - """ - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEquals(200, channel.code, channel.json_body) - annotation_id = channel.json_body["event_id"] - - # Annotate the annotation. - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id - ) - self.assertEquals(200, channel.code, channel.json_body) - - channel = self.make_request( - "GET", - f"/rooms/{self.room}/event/{annotation_id}", + "/rooms/%s/event/%s" % (self.room, self.parent_id), access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) - self.assertIsNone(channel.json_body["unsigned"].get("m.relations")) - def test_aggregation_get_event_for_thread(self): - """Test that threads get bundled aggregations included when directly requested.""" - channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - self.assertEquals(200, channel.code, channel.json_body) - thread_id = channel.json_body["event_id"] - - # Annotate the annotation. - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id - ) - self.assertEquals(200, channel.code, channel.json_body) - - channel = self.make_request( - "GET", - f"/rooms/{self.room}/event/{thread_id}", - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) self.assertEquals( channel.json_body["unsigned"].get("m.relations"), { RelationTypes.ANNOTATION: { - "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] + "chunk": [ + {"type": "m.reaction", "key": "a", "count": 2}, + {"type": "m.reaction", "key": "b", "count": 1}, + ] }, - }, - ) - - # It should also be included when the entire thread is requested. - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - self.assertEqual(len(channel.json_body["chunk"]), 1) - - thread_message = channel.json_body["chunk"][0] - self.assertEquals( - thread_message["unsigned"].get("m.relations"), - { - RelationTypes.ANNOTATION: { - "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] + RelationTypes.REFERENCE: { + "chunk": [{"event_id": reply_1}, {"event_id": reply_2}] + }, + RelationTypes.THREAD: { + "count": 2, + "latest_event": { + "age": 100, + "content": { + "m.relates_to": { + "event_id": self.parent_id, + "rel_type": RelationTypes.THREAD, + } + }, + "event_id": thread_2, + "origin_server_ts": 1600, + "room_id": self.room, + "sender": self.user_id, + "type": "m.room.test", + "unsigned": {"age": 100}, + "user_id": self.user_id, + }, }, }, ) @@ -797,56 +672,6 @@ def test_edit_reply(self): {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict ) - def test_edit_edit(self): - """Test that an edit cannot be edited.""" - new_body = {"msgtype": "m.text", "body": "Initial edit"} - channel = self._send_relation( - RelationTypes.REPLACE, - "m.room.message", - content={ - "msgtype": "m.text", - "body": "Wibble", - "m.new_content": new_body, - }, - ) - self.assertEquals(200, channel.code, channel.json_body) - edit_event_id = channel.json_body["event_id"] - - # Edit the edit event. 
- channel = self._send_relation( - RelationTypes.REPLACE, - "m.room.message", - content={ - "msgtype": "m.text", - "body": "foo", - "m.new_content": {"msgtype": "m.text", "body": "Ignored edit"}, - }, - parent_id=edit_event_id, - ) - self.assertEquals(200, channel.code, channel.json_body) - - # Request the original event. - channel = self.make_request( - "GET", - "/rooms/%s/event/%s" % (self.room, self.parent_id), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - # The edit to the edit should be ignored. - self.assertEquals(channel.json_body["content"], new_body) - - # The relations information should not include the edit to the edit. - relations_dict = channel.json_body["unsigned"].get("m.relations") - self.assertIn(RelationTypes.REPLACE, relations_dict) - - m_replace_dict = relations_dict[RelationTypes.REPLACE] - for key in ["event_id", "sender", "origin_server_ts"]: - self.assertIn(key, m_replace_dict) - - self.assert_dict( - {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict - ) - def test_relations_redaction_redacts_edits(self): """Test that edits of an event are redacted when the original event is redacted. diff --git a/tests/rest/media/v1/test_filepath.py b/tests/rest/media/v1/test_filepath.py index 913bc530aac1..8fe94f7d853c 100644 --- a/tests/rest/media/v1/test_filepath.py +++ b/tests/rest/media/v1/test_filepath.py @@ -12,10 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect -import os from typing import Iterable -from synapse.rest.media.v1.filepath import MediaFilePaths, _wrap_with_jail_check +from synapse.rest.media.v1.filepath import MediaFilePaths from tests import unittest @@ -487,109 +486,3 @@ def _test_path_validation( f"{value!r} unexpectedly passed validation: " f"{method} returned {path_or_list!r}" ) - - -class MediaFilePathsJailTestCase(unittest.TestCase): - def _check_relative_path(self, filepaths: MediaFilePaths, path: str) -> None: - """Passes a relative path through the jail check. - - Args: - filepaths: The `MediaFilePaths` instance. - path: A path relative to the media store directory. - - Raises: - ValueError: If the jail check fails. - """ - - @_wrap_with_jail_check(relative=True) - def _make_relative_path(self: MediaFilePaths, path: str) -> str: - return path - - _make_relative_path(filepaths, path) - - def _check_absolute_path(self, filepaths: MediaFilePaths, path: str) -> None: - """Passes an absolute path through the jail check. - - Args: - filepaths: The `MediaFilePaths` instance. - path: A path relative to the media store directory. - - Raises: - ValueError: If the jail check fails. - """ - - @_wrap_with_jail_check(relative=False) - def _make_absolute_path(self: MediaFilePaths, path: str) -> str: - return os.path.join(self.base_path, path) - - _make_absolute_path(filepaths, path) - - def test_traversal_inside(self) -> None: - """Test the jail check for paths that stay within the media directory.""" - # Despite the `../`s, these paths still lie within the media directory and it's - # expected for the jail check to allow them through. - # These paths ought to trip the other checks in place and should never be - # returned. 
- filepaths = MediaFilePaths("/media_store") - path = "url_cache/2020-01-02/../../GerZNDnDZVjsOtar" - self._check_relative_path(filepaths, path) - self._check_absolute_path(filepaths, path) - - def test_traversal_outside(self) -> None: - """Test that the jail check fails for paths that escape the media directory.""" - filepaths = MediaFilePaths("/media_store") - path = "url_cache/2020-01-02/../../../GerZNDnDZVjsOtar" - with self.assertRaises(ValueError): - self._check_relative_path(filepaths, path) - with self.assertRaises(ValueError): - self._check_absolute_path(filepaths, path) - - def test_traversal_reentry(self) -> None: - """Test the jail check for paths that exit and re-enter the media directory.""" - # These paths lie outside the media directory if it is a symlink, and inside - # otherwise. Ideally the check should fail, but this proves difficult. - # This test documents the behaviour for this edge case. - # These paths ought to trip the other checks in place and should never be - # returned. - filepaths = MediaFilePaths("/media_store") - path = "url_cache/2020-01-02/../../../media_store/GerZNDnDZVjsOtar" - self._check_relative_path(filepaths, path) - self._check_absolute_path(filepaths, path) - - def test_symlink(self) -> None: - """Test that a symlink does not cause the jail check to fail.""" - media_store_path = self.mktemp() - - # symlink the media store directory - os.symlink("/mnt/synapse/media_store", media_store_path) - - # Test that relative and absolute paths don't trip the check - # NB: `media_store_path` is a relative path - filepaths = MediaFilePaths(media_store_path) - self._check_relative_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") - self._check_absolute_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") - - filepaths = MediaFilePaths(os.path.abspath(media_store_path)) - self._check_relative_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") - self._check_absolute_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") - - def test_symlink_subdirectory(self) -> None: - """Test that a symlinked subdirectory does not cause the jail check to fail.""" - media_store_path = self.mktemp() - os.mkdir(media_store_path) - - # symlink `url_cache/` - os.symlink( - "/mnt/synapse/media_store_url_cache", - os.path.join(media_store_path, "url_cache"), - ) - - # Test that relative and absolute paths don't trip the check - # NB: `media_store_path` is a relative path - filepaths = MediaFilePaths(media_store_path) - self._check_relative_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") - self._check_absolute_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") - - filepaths = MediaFilePaths(os.path.abspath(media_store_path)) - self._check_relative_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") - self._check_absolute_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index 5ae491ff5a52..a649e8c61872 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -12,24 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import json -from contextlib import contextmanager -from typing import Generator -from twisted.enterprise.adbapi import ConnectionPool -from twisted.internet.defer import ensureDeferred -from twisted.test.proto_helpers import MemoryReactor - -from synapse.api.room_versions import EventFormatVersions, RoomVersions from synapse.logging.context import LoggingContext from synapse.rest import admin from synapse.rest.client import login, room -from synapse.server import HomeServer -from synapse.storage.databases.main.events_worker import ( - EVENT_QUEUE_THREADS, - EventsWorkerStore, -) -from synapse.storage.types import Connection -from synapse.util import Clock +from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.util.async_helpers import yieldable_gather_results from tests import unittest @@ -157,127 +144,3 @@ def test_dedupe(self): # We should have fetched the event from the DB self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1) - - -class DatabaseOutageTestCase(unittest.HomeserverTestCase): - """Test event fetching during a database outage.""" - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): - self.store: EventsWorkerStore = hs.get_datastore() - - self.room_id = f"!room:{hs.hostname}" - self.event_ids = [f"event{i}" for i in range(20)] - - self._populate_events() - - def _populate_events(self) -> None: - """Ensure that there are test events in the database. - - When testing with the in-memory SQLite database, all the events are lost during - the simulated outage. - - To ensure consistency between `room_id`s and `event_id`s before and after the - outage, rows are built and inserted manually. - - Upserts are used to handle the non-SQLite case where events are not lost. - """ - self.get_success( - self.store.db_pool.simple_upsert( - "rooms", - {"room_id": self.room_id}, - {"room_version": RoomVersions.V4.identifier}, - ) - ) - - self.event_ids = [f"event{i}" for i in range(20)] - for idx, event_id in enumerate(self.event_ids): - self.get_success( - self.store.db_pool.simple_upsert( - "events", - {"event_id": event_id}, - { - "event_id": event_id, - "room_id": self.room_id, - "topological_ordering": idx, - "stream_ordering": idx, - "type": "test", - "processed": True, - "outlier": False, - }, - ) - ) - self.get_success( - self.store.db_pool.simple_upsert( - "event_json", - {"event_id": event_id}, - { - "room_id": self.room_id, - "json": json.dumps({"type": "test", "room_id": self.room_id}), - "internal_metadata": "{}", - "format_version": EventFormatVersions.V3, - }, - ) - ) - - @contextmanager - def _outage(self) -> Generator[None, None, None]: - """Simulate a database outage. - - Returns: - A context manager. While the context is active, any attempts to connect to - the database will fail. - """ - connection_pool = self.store.db_pool._db_pool - - # Close all connections and shut down the database `ThreadPool`. - connection_pool.close() - - # Restart the database `ThreadPool`. - connection_pool.start() - - original_connection_factory = connection_pool.connectionFactory - - def connection_factory(_pool: ConnectionPool) -> Connection: - raise Exception("Could not connect to the database.") - - connection_pool.connectionFactory = connection_factory # type: ignore[assignment] - try: - yield - finally: - connection_pool.connectionFactory = original_connection_factory - - # If the in-memory SQLite database is being used, all the events are gone. - # Restore the test data. 
- self._populate_events() - - def test_failure(self) -> None: - """Test that event fetches do not get stuck during a database outage.""" - with self._outage(): - failure = self.get_failure( - self.store.get_event(self.event_ids[0]), Exception - ) - self.assertEqual(str(failure.value), "Could not connect to the database.") - - def test_recovery(self) -> None: - """Test that event fetchers recover after a database outage.""" - with self._outage(): - # Kick off a bunch of event fetches but do not pump the reactor - event_deferreds = [] - for event_id in self.event_ids: - event_deferreds.append(ensureDeferred(self.store.get_event(event_id))) - - # We should have maxed out on event fetcher threads - self.assertEqual(self.store._event_fetch_ongoing, EVENT_QUEUE_THREADS) - - # All the event fetchers will fail - self.pump() - self.assertEqual(self.store._event_fetch_ongoing, 0) - - for event_deferred in event_deferreds: - failure = self.get_failure(event_deferred, Exception) - self.assertEqual( - str(failure.value), "Could not connect to the database." - ) - - # This next event fetch should succeed - self.get_success(self.store.get_event(self.event_ids[0])) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 329490caad53..f26d5acf9c29 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -14,37 +14,35 @@ import json import os import tempfile -from typing import List, Optional, cast from unittest.mock import Mock import yaml from twisted.internet import defer -from twisted.test.proto_helpers import MemoryReactor from synapse.appservice import ApplicationService, ApplicationServiceState from synapse.config._base import ConfigError -from synapse.events import EventBase -from synapse.server import HomeServer from synapse.storage.database import DatabasePool, make_conn from synapse.storage.databases.main.appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore, ) -from synapse.util import Clock from tests import unittest from tests.test_utils import make_awaitable +from tests.utils import setup_test_homeserver -class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase): +class ApplicationServiceStoreTestCase(unittest.TestCase): + @defer.inlineCallbacks def setUp(self): - super(ApplicationServiceStoreTestCase, self).setUp() - - self.as_yaml_files: List[str] = [] + self.as_yaml_files = [] + hs = yield setup_test_homeserver( + self.addCleanup, federation_sender=Mock(), federation_client=Mock() + ) - self.hs.config.appservice.app_service_config_files = self.as_yaml_files - self.hs.config.caches.event_cache_size = 1 + hs.config.appservice.app_service_config_files = self.as_yaml_files + hs.config.caches.event_cache_size = 1 self.as_token = "token1" self.as_url = "some_url" @@ -55,14 +53,12 @@ def setUp(self): self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob") self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob") # must be done after inserts - database = self.hs.get_datastores().databases[0] + database = hs.get_datastores().databases[0] self.store = ApplicationServiceStore( - database, - make_conn(database._database_config, database.engine, "test"), - self.hs, + database, make_conn(database._database_config, database.engine, "test"), hs ) - def tearDown(self) -> None: + def tearDown(self): # TODO: suboptimal that we need to create files for tests! 
for f in self.as_yaml_files: try: @@ -70,9 +66,7 @@ def tearDown(self) -> None: except Exception: pass - super(ApplicationServiceStoreTestCase, self).tearDown() - - def _add_appservice(self, as_token, id, url, hs_token, sender) -> None: + def _add_appservice(self, as_token, id, url, hs_token, sender): as_yaml = { "url": url, "as_token": as_token, @@ -86,13 +80,12 @@ def _add_appservice(self, as_token, id, url, hs_token, sender) -> None: outfile.write(yaml.dump(as_yaml)) self.as_yaml_files.append(as_token) - def test_retrieve_unknown_service_token(self) -> None: + def test_retrieve_unknown_service_token(self): service = self.store.get_app_service_by_token("invalid_token") self.assertEquals(service, None) - def test_retrieval_of_service(self) -> None: + def test_retrieval_of_service(self): stored_service = self.store.get_app_service_by_token(self.as_token) - assert stored_service is not None self.assertEquals(stored_service.token, self.as_token) self.assertEquals(stored_service.id, self.as_id) self.assertEquals(stored_service.url, self.as_url) @@ -100,18 +93,22 @@ def test_retrieval_of_service(self) -> None: self.assertEquals(stored_service.namespaces[ApplicationService.NS_ROOMS], []) self.assertEquals(stored_service.namespaces[ApplicationService.NS_USERS], []) - def test_retrieval_of_all_services(self) -> None: + def test_retrieval_of_all_services(self): services = self.store.get_app_services() self.assertEquals(len(services), 3) -class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): - def setUp(self) -> None: - super(ApplicationServiceTransactionStoreTestCase, self).setUp() - self.as_yaml_files: List[str] = [] +class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + self.as_yaml_files = [] + + hs = yield setup_test_homeserver( + self.addCleanup, federation_sender=Mock(), federation_client=Mock() + ) - self.hs.config.appservice.app_service_config_files = self.as_yaml_files - self.hs.config.caches.event_cache_size = 1 + hs.config.appservice.app_service_config_files = self.as_yaml_files + hs.config.caches.event_cache_size = 1 self.as_list = [ {"token": "token1", "url": "https://matrix-as.org", "id": "id_1"}, @@ -120,21 +117,21 @@ def setUp(self) -> None: {"token": "gamma_tok", "url": "https://gamma.com", "id": "id_gamma"}, ] for s in self.as_list: - self._add_service(s["url"], s["token"], s["id"]) + yield self._add_service(s["url"], s["token"], s["id"]) self.as_yaml_files = [] # We assume there is only one database in these tests - database = self.hs.get_datastores().databases[0] + database = hs.get_datastores().databases[0] self.db_pool = database._db_pool self.engine = database.engine - db_config = self.hs.config.database.get_single_database() + db_config = hs.config.database.get_single_database() self.store = TestTransactionStore( - database, make_conn(db_config, self.engine, "test"), self.hs + database, make_conn(db_config, self.engine, "test"), hs ) - def _add_service(self, url, as_token, id) -> None: + def _add_service(self, url, as_token, id): as_yaml = { "url": url, "as_token": as_token, @@ -148,15 +145,13 @@ def _add_service(self, url, as_token, id) -> None: outfile.write(yaml.dump(as_yaml)) self.as_yaml_files.append(as_token) - def _set_state( - self, id: str, state: ApplicationServiceState, txn: Optional[int] = None - ): + def _set_state(self, id, state, txn=None): return self.db_pool.runOperation( self.engine.convert_param_style( "INSERT INTO application_services_state(as_id, state, last_txn) " 
"VALUES(?,?,?)" ), - (id, state.value, txn), + (id, state, txn), ) def _insert_txn(self, as_id, txn_id, events): @@ -174,277 +169,234 @@ def _set_last_txn(self, as_id, txn_id): "INSERT INTO application_services_state(as_id, last_txn, state) " "VALUES(?,?,?)" ), - (as_id, txn_id, ApplicationServiceState.UP.value), + (as_id, txn_id, ApplicationServiceState.UP), ) - def test_get_appservice_state_none( - self, - ) -> None: + @defer.inlineCallbacks + def test_get_appservice_state_none(self): service = Mock(id="999") - state = self.get_success(self.store.get_appservice_state(service)) + state = yield defer.ensureDeferred(self.store.get_appservice_state(service)) self.assertEquals(None, state) - def test_get_appservice_state_up( - self, - ) -> None: - self.get_success( - self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP) - ) + @defer.inlineCallbacks + def test_get_appservice_state_up(self): + yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP) service = Mock(id=self.as_list[0]["id"]) - state = self.get_success( - defer.ensureDeferred(self.store.get_appservice_state(service)) - ) + state = yield defer.ensureDeferred(self.store.get_appservice_state(service)) self.assertEquals(ApplicationServiceState.UP, state) - def test_get_appservice_state_down( - self, - ) -> None: - self.get_success( - self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP) - ) - self.get_success( - self._set_state(self.as_list[1]["id"], ApplicationServiceState.DOWN) - ) - self.get_success( - self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN) - ) + @defer.inlineCallbacks + def test_get_appservice_state_down(self): + yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP) + yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.DOWN) + yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN) service = Mock(id=self.as_list[1]["id"]) - state = self.get_success(self.store.get_appservice_state(service)) + state = yield defer.ensureDeferred(self.store.get_appservice_state(service)) self.assertEquals(ApplicationServiceState.DOWN, state) - def test_get_appservices_by_state_none( - self, - ) -> None: - services = self.get_success( + @defer.inlineCallbacks + def test_get_appservices_by_state_none(self): + services = yield defer.ensureDeferred( self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) self.assertEquals(0, len(services)) - def test_set_appservices_state_down( - self, - ) -> None: + @defer.inlineCallbacks + def test_set_appservices_state_down(self): service = Mock(id=self.as_list[1]["id"]) - self.get_success( + yield defer.ensureDeferred( self.store.set_appservice_state(service, ApplicationServiceState.DOWN) ) - rows = self.get_success( - self.db_pool.runQuery( - self.engine.convert_param_style( - "SELECT as_id FROM application_services_state WHERE state=?" - ), - (ApplicationServiceState.DOWN.value,), - ) + rows = yield self.db_pool.runQuery( + self.engine.convert_param_style( + "SELECT as_id FROM application_services_state WHERE state=?" 
+ ), + (ApplicationServiceState.DOWN,), ) self.assertEquals(service.id, rows[0][0]) - def test_set_appservices_state_multiple_up( - self, - ) -> None: + @defer.inlineCallbacks + def test_set_appservices_state_multiple_up(self): service = Mock(id=self.as_list[1]["id"]) - self.get_success( + yield defer.ensureDeferred( self.store.set_appservice_state(service, ApplicationServiceState.UP) ) - self.get_success( + yield defer.ensureDeferred( self.store.set_appservice_state(service, ApplicationServiceState.DOWN) ) - self.get_success( + yield defer.ensureDeferred( self.store.set_appservice_state(service, ApplicationServiceState.UP) ) - rows = self.get_success( - self.db_pool.runQuery( - self.engine.convert_param_style( - "SELECT as_id FROM application_services_state WHERE state=?" - ), - (ApplicationServiceState.UP.value,), - ) + rows = yield self.db_pool.runQuery( + self.engine.convert_param_style( + "SELECT as_id FROM application_services_state WHERE state=?" + ), + (ApplicationServiceState.UP,), ) self.assertEquals(service.id, rows[0][0]) - def test_create_appservice_txn_first( - self, - ) -> None: + @defer.inlineCallbacks + def test_create_appservice_txn_first(self): service = Mock(id=self.as_list[0]["id"]) - events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) - txn = self.get_success( - defer.ensureDeferred(self.store.create_appservice_txn(service, events, [])) + events = [Mock(event_id="e1"), Mock(event_id="e2")] + txn = yield defer.ensureDeferred( + self.store.create_appservice_txn(service, events, []) ) self.assertEquals(txn.id, 1) self.assertEquals(txn.events, events) self.assertEquals(txn.service, service) - def test_create_appservice_txn_older_last_txn( - self, - ) -> None: + @defer.inlineCallbacks + def test_create_appservice_txn_older_last_txn(self): service = Mock(id=self.as_list[0]["id"]) - events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) - self.get_success(self._set_last_txn(service.id, 9643)) # AS is falling behind - self.get_success(self._insert_txn(service.id, 9644, events)) - self.get_success(self._insert_txn(service.id, 9645, events)) - txn = self.get_success(self.store.create_appservice_txn(service, events, [])) + events = [Mock(event_id="e1"), Mock(event_id="e2")] + yield self._set_last_txn(service.id, 9643) # AS is falling behind + yield self._insert_txn(service.id, 9644, events) + yield self._insert_txn(service.id, 9645, events) + txn = yield defer.ensureDeferred( + self.store.create_appservice_txn(service, events, []) + ) self.assertEquals(txn.id, 9646) self.assertEquals(txn.events, events) self.assertEquals(txn.service, service) - def test_create_appservice_txn_up_to_date_last_txn( - self, - ) -> None: + @defer.inlineCallbacks + def test_create_appservice_txn_up_to_date_last_txn(self): service = Mock(id=self.as_list[0]["id"]) - events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) - self.get_success(self._set_last_txn(service.id, 9643)) - txn = self.get_success(self.store.create_appservice_txn(service, events, [])) + events = [Mock(event_id="e1"), Mock(event_id="e2")] + yield self._set_last_txn(service.id, 9643) + txn = yield defer.ensureDeferred( + self.store.create_appservice_txn(service, events, []) + ) self.assertEquals(txn.id, 9644) self.assertEquals(txn.events, events) self.assertEquals(txn.service, service) - def test_create_appservice_txn_up_fuzzing( - self, - ) -> None: + @defer.inlineCallbacks + def test_create_appservice_txn_up_fuzzing(self): service = Mock(id=self.as_list[0]["id"]) - events 
= cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) - self.get_success(self._set_last_txn(service.id, 9643)) + events = [Mock(event_id="e1"), Mock(event_id="e2")] + yield self._set_last_txn(service.id, 9643) # dump in rows with higher IDs to make sure the queries aren't wrong. - self.get_success(self._set_last_txn(self.as_list[1]["id"], 119643)) - self.get_success(self._set_last_txn(self.as_list[2]["id"], 9)) - self.get_success(self._set_last_txn(self.as_list[3]["id"], 9643)) - self.get_success(self._insert_txn(self.as_list[1]["id"], 119644, events)) - self.get_success(self._insert_txn(self.as_list[1]["id"], 119645, events)) - self.get_success(self._insert_txn(self.as_list[1]["id"], 119646, events)) - self.get_success(self._insert_txn(self.as_list[2]["id"], 10, events)) - self.get_success(self._insert_txn(self.as_list[3]["id"], 9643, events)) - - txn = self.get_success(self.store.create_appservice_txn(service, events, [])) + yield self._set_last_txn(self.as_list[1]["id"], 119643) + yield self._set_last_txn(self.as_list[2]["id"], 9) + yield self._set_last_txn(self.as_list[3]["id"], 9643) + yield self._insert_txn(self.as_list[1]["id"], 119644, events) + yield self._insert_txn(self.as_list[1]["id"], 119645, events) + yield self._insert_txn(self.as_list[1]["id"], 119646, events) + yield self._insert_txn(self.as_list[2]["id"], 10, events) + yield self._insert_txn(self.as_list[3]["id"], 9643, events) + + txn = yield defer.ensureDeferred( + self.store.create_appservice_txn(service, events, []) + ) self.assertEquals(txn.id, 9644) self.assertEquals(txn.events, events) self.assertEquals(txn.service, service) - def test_complete_appservice_txn_first_txn( - self, - ) -> None: + @defer.inlineCallbacks + def test_complete_appservice_txn_first_txn(self): service = Mock(id=self.as_list[0]["id"]) events = [Mock(event_id="e1"), Mock(event_id="e2")] txn_id = 1 - self.get_success(self._insert_txn(service.id, txn_id, events)) - self.get_success( + yield self._insert_txn(service.id, txn_id, events) + yield defer.ensureDeferred( self.store.complete_appservice_txn(txn_id=txn_id, service=service) ) - res = self.get_success( - self.db_pool.runQuery( - self.engine.convert_param_style( - "SELECT last_txn FROM application_services_state WHERE as_id=?" - ), - (service.id,), - ) + res = yield self.db_pool.runQuery( + self.engine.convert_param_style( + "SELECT last_txn FROM application_services_state WHERE as_id=?" + ), + (service.id,), ) self.assertEquals(1, len(res)) self.assertEquals(txn_id, res[0][0]) - res = self.get_success( - self.db_pool.runQuery( - self.engine.convert_param_style( - "SELECT * FROM application_services_txns WHERE txn_id=?" - ), - (txn_id,), - ) + res = yield self.db_pool.runQuery( + self.engine.convert_param_style( + "SELECT * FROM application_services_txns WHERE txn_id=?" 
+ ), + (txn_id,), ) self.assertEquals(0, len(res)) - def test_complete_appservice_txn_existing_in_state_table( - self, - ) -> None: + @defer.inlineCallbacks + def test_complete_appservice_txn_existing_in_state_table(self): service = Mock(id=self.as_list[0]["id"]) events = [Mock(event_id="e1"), Mock(event_id="e2")] txn_id = 5 - self.get_success(self._set_last_txn(service.id, 4)) - self.get_success(self._insert_txn(service.id, txn_id, events)) - self.get_success( + yield self._set_last_txn(service.id, 4) + yield self._insert_txn(service.id, txn_id, events) + yield defer.ensureDeferred( self.store.complete_appservice_txn(txn_id=txn_id, service=service) ) - res = self.get_success( - self.db_pool.runQuery( - self.engine.convert_param_style( - "SELECT last_txn, state FROM application_services_state WHERE as_id=?" - ), - (service.id,), - ) + res = yield self.db_pool.runQuery( + self.engine.convert_param_style( + "SELECT last_txn, state FROM application_services_state WHERE as_id=?" + ), + (service.id,), ) self.assertEquals(1, len(res)) self.assertEquals(txn_id, res[0][0]) - self.assertEquals(ApplicationServiceState.UP.value, res[0][1]) - - res = self.get_success( - self.db_pool.runQuery( - self.engine.convert_param_style( - "SELECT * FROM application_services_txns WHERE txn_id=?" - ), - (txn_id,), - ) + self.assertEquals(ApplicationServiceState.UP, res[0][1]) + + res = yield self.db_pool.runQuery( + self.engine.convert_param_style( + "SELECT * FROM application_services_txns WHERE txn_id=?" + ), + (txn_id,), ) self.assertEquals(0, len(res)) - def test_get_oldest_unsent_txn_none( - self, - ) -> None: + @defer.inlineCallbacks + def test_get_oldest_unsent_txn_none(self): service = Mock(id=self.as_list[0]["id"]) - txn = self.get_success(self.store.get_oldest_unsent_txn(service)) + txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service)) self.assertEquals(None, txn) - def test_get_oldest_unsent_txn(self) -> None: + @defer.inlineCallbacks + def test_get_oldest_unsent_txn(self): service = Mock(id=self.as_list[0]["id"]) events = [Mock(event_id="e1"), Mock(event_id="e2")] other_events = [Mock(event_id="e5"), Mock(event_id="e6")] # we aren't testing store._base stuff here, so mock this out - # (ignore needed because Mypy won't allow us to assign to a method otherwise) - self.store.get_events_as_list = Mock(return_value=make_awaitable(events)) # type: ignore[assignment] + self.store.get_events_as_list = Mock(return_value=make_awaitable(events)) - self.get_success(self._insert_txn(self.as_list[1]["id"], 9, other_events)) - self.get_success(self._insert_txn(service.id, 10, events)) - self.get_success(self._insert_txn(service.id, 11, other_events)) - self.get_success(self._insert_txn(service.id, 12, other_events)) + yield self._insert_txn(self.as_list[1]["id"], 9, other_events) + yield self._insert_txn(service.id, 10, events) + yield self._insert_txn(service.id, 11, other_events) + yield self._insert_txn(service.id, 12, other_events) - txn = self.get_success(self.store.get_oldest_unsent_txn(service)) + txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service)) self.assertEquals(service, txn.service) self.assertEquals(10, txn.id) self.assertEquals(events, txn.events) - def test_get_appservices_by_state_single( - self, - ) -> None: - self.get_success( - self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN) - ) - self.get_success( - self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP) - ) + @defer.inlineCallbacks + def 
test_get_appservices_by_state_single(self): + yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN) + yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP) - services = self.get_success( + services = yield defer.ensureDeferred( self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) self.assertEquals(1, len(services)) self.assertEquals(self.as_list[0]["id"], services[0].id) - def test_get_appservices_by_state_multiple( - self, - ) -> None: - self.get_success( - self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN) - ) - self.get_success( - self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP) - ) - self.get_success( - self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN) - ) - self.get_success( - self._set_state(self.as_list[3]["id"], ApplicationServiceState.UP) - ) + @defer.inlineCallbacks + def test_get_appservices_by_state_multiple(self): + yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN) + yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP) + yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN) + yield self._set_state(self.as_list[3]["id"], ApplicationServiceState.UP) - services = self.get_success( + services = yield defer.ensureDeferred( self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) self.assertEquals(2, len(services)) @@ -455,16 +407,16 @@ def test_get_appservices_by_state_multiple( class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase): - def prepare( - self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer - ) -> None: + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver() + return hs + + def prepare(self, hs, reactor, clock): self.service = Mock(id="foo") self.store = self.hs.get_datastore() - self.get_success( - self.store.set_appservice_state(self.service, ApplicationServiceState.UP) - ) + self.get_success(self.store.set_appservice_state(self.service, "up")) - def test_get_type_stream_id_for_appservice_no_value(self) -> None: + def test_get_type_stream_id_for_appservice_no_value(self): value = self.get_success( self.store.get_type_stream_id_for_appservice(self.service, "read_receipt") ) @@ -475,13 +427,13 @@ def test_get_type_stream_id_for_appservice_no_value(self) -> None: ) self.assertEquals(value, 0) - def test_get_type_stream_id_for_appservice_invalid_type(self) -> None: + def test_get_type_stream_id_for_appservice_invalid_type(self): self.get_failure( self.store.get_type_stream_id_for_appservice(self.service, "foobar"), ValueError, ) - def test_set_type_stream_id_for_appservice(self) -> None: + def test_set_type_stream_id_for_appservice(self): read_receipt_value = 1024 self.get_success( self.store.set_type_stream_id_for_appservice( @@ -503,7 +455,7 @@ def test_set_type_stream_id_for_appservice(self) -> None: ) self.assertEqual(result, read_receipt_value) - def test_set_type_stream_id_for_appservice_invalid_type(self) -> None: + def test_set_type_stream_id_for_appservice_invalid_type(self): self.get_failure( self.store.set_type_stream_id_for_appservice(self.service, "foobar", 1024), ValueError, @@ -512,12 +464,12 @@ def test_set_type_stream_id_for_appservice_invalid_type(self) -> None: # required for ApplicationServiceTransactionStoreTestCase tests class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore): - def __init__(self, database: DatabasePool, db_conn, hs) -> None: + def __init__(self, database: 
DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) -class ApplicationServiceStoreConfigTestCase(unittest.HomeserverTestCase): - def _write_config(self, suffix, **kwargs) -> str: +class ApplicationServiceStoreConfigTestCase(unittest.TestCase): + def _write_config(self, suffix, **kwargs): vals = { "id": "id" + suffix, "url": "url" + suffix, @@ -533,33 +485,41 @@ def _write_config(self, suffix, **kwargs) -> str: f.write(yaml.dump(vals)) return path - def test_unique_works(self) -> None: + @defer.inlineCallbacks + def test_unique_works(self): f1 = self._write_config(suffix="1") f2 = self._write_config(suffix="2") - self.hs.config.appservice.app_service_config_files = [f1, f2] - self.hs.config.caches.event_cache_size = 1 + hs = yield setup_test_homeserver( + self.addCleanup, federation_sender=Mock(), federation_client=Mock() + ) + + hs.config.appservice.app_service_config_files = [f1, f2] + hs.config.caches.event_cache_size = 1 - database = self.hs.get_datastores().databases[0] + database = hs.get_datastores().databases[0] ApplicationServiceStore( - database, - make_conn(database._database_config, database.engine, "test"), - self.hs, + database, make_conn(database._database_config, database.engine, "test"), hs ) - def test_duplicate_ids(self) -> None: + @defer.inlineCallbacks + def test_duplicate_ids(self): f1 = self._write_config(id="id", suffix="1") f2 = self._write_config(id="id", suffix="2") - self.hs.config.appservice.app_service_config_files = [f1, f2] - self.hs.config.caches.event_cache_size = 1 + hs = yield setup_test_homeserver( + self.addCleanup, federation_sender=Mock(), federation_client=Mock() + ) + + hs.config.appservice.app_service_config_files = [f1, f2] + hs.config.caches.event_cache_size = 1 with self.assertRaises(ConfigError) as cm: - database = self.hs.get_datastores().databases[0] + database = hs.get_datastores().databases[0] ApplicationServiceStore( database, make_conn(database._database_config, database.engine, "test"), - self.hs, + hs, ) e = cm.exception @@ -567,19 +527,24 @@ def test_duplicate_ids(self) -> None: self.assertIn(f2, str(e)) self.assertIn("id", str(e)) - def test_duplicate_as_tokens(self) -> None: + @defer.inlineCallbacks + def test_duplicate_as_tokens(self): f1 = self._write_config(as_token="as_token", suffix="1") f2 = self._write_config(as_token="as_token", suffix="2") - self.hs.config.appservice.app_service_config_files = [f1, f2] - self.hs.config.caches.event_cache_size = 1 + hs = yield setup_test_homeserver( + self.addCleanup, federation_sender=Mock(), federation_client=Mock() + ) + + hs.config.appservice.app_service_config_files = [f1, f2] + hs.config.caches.event_cache_size = 1 with self.assertRaises(ConfigError) as cm: - database = self.hs.get_datastores().databases[0] + database = hs.get_datastores().databases[0] ApplicationServiceStore( database, make_conn(database._database_config, database.engine, "test"), - self.hs, + hs, ) e = cm.exception diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index d77c001506c6..a5f5ebad410f 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -1,26 +1,8 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Use backported mock for AsyncMock support on Python 3.6. -from mock import Mock - -from twisted.internet.defer import Deferred, ensureDeferred +from unittest.mock import Mock from synapse.storage.background_updates import BackgroundUpdater from tests import unittest -from tests.test_utils import make_awaitable class BackgroundUpdateTestCase(unittest.HomeserverTestCase): @@ -38,10 +20,10 @@ def prepare(self, reactor, clock, homeserver): def test_do_background_update(self): # the time we claim it takes to update one item when running the update - duration_ms = 10 + duration_ms = 4200 # the target runtime for each bg update - target_background_update_duration_ms = 100 + target_background_update_duration_ms = 5000000 store = self.hs.get_datastore() self.get_success( @@ -66,8 +48,10 @@ async def update(progress, count): self.update_handler.side_effect = update self.update_handler.reset_mock() res = self.get_success( - self.updates.do_next_background_update(False), - by=0.01, + self.updates.do_next_background_update( + target_background_update_duration_ms + ), + by=0.1, ) self.assertFalse(res) @@ -90,93 +74,16 @@ async def update(progress, count): self.update_handler.side_effect = update self.update_handler.reset_mock() - result = self.get_success(self.updates.do_next_background_update(False)) + result = self.get_success( + self.updates.do_next_background_update(target_background_update_duration_ms) + ) self.assertFalse(result) self.update_handler.assert_called_once() # third step: we don't expect to be called any more self.update_handler.reset_mock() - result = self.get_success(self.updates.do_next_background_update(False)) + result = self.get_success( + self.updates.do_next_background_update(target_background_update_duration_ms) + ) self.assertTrue(result) self.assertFalse(self.update_handler.called) - - -class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, homeserver): - self.updates: BackgroundUpdater = self.hs.get_datastore().db_pool.updates - # the base test class should have run the real bg updates for us - self.assertTrue( - self.get_success(self.updates.has_completed_background_updates()) - ) - - self.update_deferred = Deferred() - self.update_handler = Mock(return_value=self.update_deferred) - self.updates.register_background_update_handler( - "test_update", self.update_handler - ) - - # Mock out the AsyncContextManager - self._update_ctx_manager = Mock(spec=["__aenter__", "__aexit__"]) - self._update_ctx_manager.__aenter__ = Mock( - return_value=make_awaitable(None), - ) - self._update_ctx_manager.__aexit__ = Mock(return_value=make_awaitable(None)) - - # Mock out the `update_handler` callback - self._on_update = Mock(return_value=self._update_ctx_manager) - - # Define a default batch size value that's not the same as the internal default - # value (100). 
- self._default_batch_size = 500 - - # Register the callbacks with more mocks - self.hs.get_module_api().register_background_update_controller_callbacks( - on_update=self._on_update, - min_batch_size=Mock(return_value=make_awaitable(self._default_batch_size)), - default_batch_size=Mock( - return_value=make_awaitable(self._default_batch_size), - ), - ) - - def test_controller(self): - store = self.hs.get_datastore() - self.get_success( - store.db_pool.simple_insert( - "background_updates", - values={"update_name": "test_update", "progress_json": "{}"}, - ) - ) - - # Set the return value for the context manager. - enter_defer = Deferred() - self._update_ctx_manager.__aenter__ = Mock(return_value=enter_defer) - - # Start the background update. - do_update_d = ensureDeferred(self.updates.do_next_background_update(True)) - - self.pump() - - # `run_update` should have been called, but the update handler won't be - # called until the `enter_defer` (returned by `__aenter__`) is resolved. - self._on_update.assert_called_once_with( - "test_update", - "master", - False, - ) - self.assertFalse(do_update_d.called) - self.assertFalse(self.update_deferred.called) - - # Resolving the `enter_defer` should call the update handler, which then - # blocks. - enter_defer.callback(100) - self.pump() - self.update_handler.assert_called_once_with({}, self._default_batch_size) - self.assertFalse(self.update_deferred.called) - self._update_ctx_manager.__aexit__.assert_not_called() - - # Resolving the update handler deferred should cause the - # `do_next_background_update` to finish and return - self.update_deferred.callback(100) - self.pump() - self._update_ctx_manager.__aexit__.assert_called() - self.get_success(do_update_d) diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index 7b7f6c349e1c..b31c5eb5ecc6 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -664,7 +664,7 @@ def test_background_update_single_large_room(self): ): iterations += 1 self.get_success( - self.store.db_pool.updates.do_next_background_update(False), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) # Ensure that we did actually take multiple iterations to process the @@ -723,7 +723,7 @@ def test_background_update_multiple_large_room(self): ): iterations += 1 self.get_success( - self.store.db_pool.updates.do_next_background_update(False), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) # Ensure that we did actually take multiple iterations to process the diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py index f8d11bac4ec5..d2b7b8995200 100644 --- a/tests/storage/test_main.py +++ b/tests/storage/test_main.py @@ -13,35 +13,42 @@ # limitations under the License. 
+from twisted.internet import defer + from synapse.types import UserID from tests import unittest +from tests.utils import setup_test_homeserver -class DataStoreTestCase(unittest.HomeserverTestCase): - def setUp(self) -> None: - super(DataStoreTestCase, self).setUp() +class DataStoreTestCase(unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + hs = yield setup_test_homeserver(self.addCleanup) - self.store = self.hs.get_datastore() + self.store = hs.get_datastore() self.user = UserID.from_string("@abcde:test") self.displayname = "Frank" - def test_get_users_paginate(self) -> None: - self.get_success(self.store.register_user(self.user.to_string(), "pass")) - self.get_success(self.store.create_profile(self.user.localpart)) - self.get_success( + @defer.inlineCallbacks + def test_get_users_paginate(self): + yield defer.ensureDeferred( + self.store.register_user(self.user.to_string(), "pass") + ) + yield defer.ensureDeferred(self.store.create_profile(self.user.localpart)) + yield defer.ensureDeferred( self.store.set_profile_displayname(self.user.localpart, self.displayname) ) - users, total = self.get_success( + users, total = yield defer.ensureDeferred( self.store.get_users_paginate(0, 10, name="bc", guests=False) ) self.assertEquals(1, total) self.assertEquals(self.displayname, users.pop()["displayname"]) - users, total = self.get_success( + users, total = yield defer.ensureDeferred( self.store.get_users_paginate(0, 10, name="BC", guests=False) ) diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 7f5b28aed8c4..37cf7bb232f9 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -23,7 +23,6 @@ from synapse.rest.client import login, register, room from synapse.server import HomeServer from synapse.storage import DataStore -from synapse.storage.background_updates import _BackgroundUpdateHandler from synapse.storage.roommember import ProfileInfo from synapse.util import Clock @@ -392,9 +391,7 @@ async def mocked_process_users(*args: Any, **kwargs: Any) -> int: with mock.patch.dict( self.store.db_pool.updates._background_update_handlers, - populate_user_directory_process_users=_BackgroundUpdateHandler( - mocked_process_users, - ), + populate_user_directory_process_users=mocked_process_users, ): self._purge_and_rebuild_user_dir() diff --git a/tests/test_visibility.py b/tests/test_visibility.py index e0b08d67d435..94b19788d737 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -13,30 +13,35 @@ # limitations under the License. 
import logging from typing import Optional +from unittest.mock import Mock + +from twisted.internet import defer +from twisted.internet.defer import succeed from synapse.api.room_versions import RoomVersions -from synapse.events import EventBase -from synapse.types import JsonDict +from synapse.events import FrozenEvent from synapse.visibility import filter_events_for_server -from tests import unittest -from tests.utils import create_room +import tests.unittest +from tests.utils import create_room, setup_test_homeserver logger = logging.getLogger(__name__) TEST_ROOM_ID = "!TEST:ROOM" -class FilterEventsForServerTestCase(unittest.HomeserverTestCase): - def setUp(self) -> None: - super(FilterEventsForServerTestCase, self).setUp() +class FilterEventsForServerTestCase(tests.unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + self.hs = yield setup_test_homeserver(self.addCleanup) self.event_creation_handler = self.hs.get_event_creation_handler() self.event_builder_factory = self.hs.get_event_builder_factory() self.storage = self.hs.get_storage() - self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")) + yield defer.ensureDeferred(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")) - def test_filtering(self) -> None: + @defer.inlineCallbacks + def test_filtering(self): # # The events to be filtered consist of 10 membership events (it doesn't # really matter if they are joins or leaves, so let's make them joins). @@ -46,20 +51,18 @@ def test_filtering(self) -> None: # # before we do that, we persist some other events to act as state. - self.get_success(self._inject_visibility("@admin:hs", "joined")) + yield self.inject_visibility("@admin:hs", "joined") for i in range(0, 10): - self.get_success(self._inject_room_member("@resident%i:hs" % i)) + yield self.inject_room_member("@resident%i:hs" % i) events_to_filter = [] for i in range(0, 10): user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server") - evt = self.get_success( - self._inject_room_member(user, extra_content={"a": "b"}) - ) + evt = yield self.inject_room_member(user, extra_content={"a": "b"}) events_to_filter.append(evt) - filtered = self.get_success( + filtered = yield defer.ensureDeferred( filter_events_for_server(self.storage, "test_server", events_to_filter) ) @@ -72,31 +75,34 @@ def test_filtering(self) -> None: self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id) self.assertEqual(filtered[i].content["a"], "b") - def test_erased_user(self) -> None: + @defer.inlineCallbacks + def test_erased_user(self): # 4 message events, from erased and unerased users, with a membership # change in the middle of them. 
events_to_filter = [] - evt = self.get_success(self._inject_message("@unerased:local_hs")) + evt = yield self.inject_message("@unerased:local_hs") events_to_filter.append(evt) - evt = self.get_success(self._inject_message("@erased:local_hs")) + evt = yield self.inject_message("@erased:local_hs") events_to_filter.append(evt) - evt = self.get_success(self._inject_room_member("@joiner:remote_hs")) + evt = yield self.inject_room_member("@joiner:remote_hs") events_to_filter.append(evt) - evt = self.get_success(self._inject_message("@unerased:local_hs")) + evt = yield self.inject_message("@unerased:local_hs") events_to_filter.append(evt) - evt = self.get_success(self._inject_message("@erased:local_hs")) + evt = yield self.inject_message("@erased:local_hs") events_to_filter.append(evt) # the erasey user gets erased - self.get_success(self.hs.get_datastore().mark_user_erased("@erased:local_hs")) + yield defer.ensureDeferred( + self.hs.get_datastore().mark_user_erased("@erased:local_hs") + ) # ... and the filtering happens. - filtered = self.get_success( + filtered = yield defer.ensureDeferred( filter_events_for_server(self.storage, "test_server", events_to_filter) ) @@ -117,7 +123,8 @@ def test_erased_user(self) -> None: for i in (1, 4): self.assertNotIn("body", filtered[i].content) - def _inject_visibility(self, user_id: str, visibility: str) -> EventBase: + @defer.inlineCallbacks + def inject_visibility(self, user_id, visibility): content = {"history_visibility": visibility} builder = self.event_builder_factory.for_room_version( RoomVersions.V1, @@ -130,18 +137,18 @@ def _inject_visibility(self, user_id: str, visibility: str) -> EventBase: }, ) - event, context = self.get_success( + event, context = yield defer.ensureDeferred( self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.storage.persistence.persist_event(event, context)) + yield defer.ensureDeferred( + self.storage.persistence.persist_event(event, context) + ) return event - def _inject_room_member( - self, - user_id: str, - membership: str = "join", - extra_content: Optional[JsonDict] = None, - ) -> EventBase: + @defer.inlineCallbacks + def inject_room_member( + self, user_id, membership="join", extra_content: Optional[dict] = None + ): content = {"membership": membership} content.update(extra_content or {}) builder = self.event_builder_factory.for_room_version( @@ -155,16 +162,17 @@ def _inject_room_member( }, ) - event, context = self.get_success( + event, context = yield defer.ensureDeferred( self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.storage.persistence.persist_event(event, context)) + yield defer.ensureDeferred( + self.storage.persistence.persist_event(event, context) + ) return event - def _inject_message( - self, user_id: str, content: Optional[JsonDict] = None - ) -> EventBase: + @defer.inlineCallbacks + def inject_message(self, user_id, content=None): if content is None: content = {"body": "testytest", "msgtype": "m.text"} builder = self.event_builder_factory.for_room_version( @@ -177,9 +185,164 @@ def _inject_message( }, ) - event, context = self.get_success( + event, context = yield defer.ensureDeferred( self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.storage.persistence.persist_event(event, context)) + yield defer.ensureDeferred( + self.storage.persistence.persist_event(event, context) + ) return event + + @defer.inlineCallbacks + def test_large_room(self): + # see what happens when we have a large room with 
hundreds of thousands + # of membership events + + # As above, the events to be filtered consist of 10 membership events, + # where one of them is for a user on the server we are filtering for. + + import cProfile + import pstats + import time + + # we stub out the store, because building up all that state the normal + # way is very slow. + test_store = _TestStore() + + # our initial state is 100000 membership events and one + # history_visibility event. + room_state = [] + + history_visibility_evt = FrozenEvent( + { + "event_id": "$history_vis", + "type": "m.room.history_visibility", + "sender": "@resident_user_0:test.com", + "state_key": "", + "room_id": TEST_ROOM_ID, + "content": {"history_visibility": "joined"}, + } + ) + room_state.append(history_visibility_evt) + test_store.add_event(history_visibility_evt) + + for i in range(0, 100000): + user = "@resident_user_%i:test.com" % (i,) + evt = FrozenEvent( + { + "event_id": "$res_event_%i" % (i,), + "type": "m.room.member", + "state_key": user, + "sender": user, + "room_id": TEST_ROOM_ID, + "content": {"membership": "join", "extra": "zzz,"}, + } + ) + room_state.append(evt) + test_store.add_event(evt) + + events_to_filter = [] + for i in range(0, 10): + user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server") + evt = FrozenEvent( + { + "event_id": "$evt%i" % (i,), + "type": "m.room.member", + "state_key": user, + "sender": user, + "room_id": TEST_ROOM_ID, + "content": {"membership": "join", "extra": "zzz"}, + } + ) + events_to_filter.append(evt) + room_state.append(evt) + + test_store.add_event(evt) + test_store.set_state_ids_for_event( + evt, {(e.type, e.state_key): e.event_id for e in room_state} + ) + + pr = cProfile.Profile() + pr.enable() + + logger.info("Starting filtering") + start = time.time() + + storage = Mock() + storage.main = test_store + storage.state = test_store + + filtered = yield defer.ensureDeferred( + filter_events_for_server(test_store, "test_server", events_to_filter) + ) + logger.info("Filtering took %f seconds", time.time() - start) + + pr.disable() + with open("filter_events_for_server.profile", "w+") as f: + ps = pstats.Stats(pr, stream=f).sort_stats("cumulative") + ps.print_stats() + + # the result should be 5 redacted events, and 5 unredacted events. 
+ for i in range(0, 5): + self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id) + self.assertNotIn("extra", filtered[i].content) + + for i in range(5, 10): + self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id) + self.assertEqual(filtered[i].content["extra"], "zzz") + + test_large_room.skip = "Disabled by default because it's slow" + + +class _TestStore: + """Implements a few methods of the DataStore, so that we can test + filter_events_for_server + + """ + + def __init__(self): + # data for get_events: a map from event_id to event + self.events = {} + + # data for get_state_ids_for_events mock: a map from event_id to + # a map from (type_state_key) -> event_id for the state at that + # event + self.state_ids_for_events = {} + + def add_event(self, event): + self.events[event.event_id] = event + + def set_state_ids_for_event(self, event, state): + self.state_ids_for_events[event.event_id] = state + + def get_state_ids_for_events(self, events, types): + res = {} + include_memberships = False + for (type, state_key) in types: + if type == "m.room.history_visibility": + continue + if type != "m.room.member" or state_key is not None: + raise RuntimeError( + "Unimplemented: get_state_ids with type (%s, %s)" + % (type, state_key) + ) + include_memberships = True + + if include_memberships: + for event_id in events: + res[event_id] = self.state_ids_for_events[event_id] + + else: + k = ("m.room.history_visibility", "") + for event_id in events: + hve = self.state_ids_for_events[event_id][k] + res[event_id] = {k: hve} + + return succeed(res) + + def get_events(self, events): + return succeed({event_id: self.events[event_id] for event_id in events}) + + def are_users_erased(self, users): + return succeed({u: False for u in users}) diff --git a/tests/unittest.py b/tests/unittest.py index eea0903f0574..165aafc57453 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -331,16 +331,17 @@ def wait_on_thread(self, deferred, timeout=10): time.sleep(0.01) def wait_for_background_updates(self) -> None: - """Block until all background database updates have completed. + """ + Block until all background database updates have completed. - Note that callers must ensure there's a store property created on the + Note that callers must ensure that's a store property created on the testcase. """ while not self.get_success( self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db_pool.updates.do_next_background_update(False), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) def make_homeserver(self, reactor, clock): @@ -499,7 +500,8 @@ def setup_test_homeserver(self, *args: Any, **kwargs: Any) -> HomeServer: async def run_bg_updates(): with LoggingContext("run_bg_updates"): - self.get_success(stor.db_pool.updates.run_background_updates(False)) + while not await stor.db_pool.updates.has_completed_background_updates(): + await stor.db_pool.updates.do_next_background_update(1) hs = setup_test_homeserver(self.addCleanup, *args, **kwargs) stor = hs.get_datastore() diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index 291644eb7dd0..6578f3411e27 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -13,7 +13,6 @@ # limitations under the License. 
-from typing import List from unittest.mock import Mock from synapse.util.caches.lrucache import LruCache, setup_expire_lru_cache_entries @@ -262,17 +261,6 @@ def test_evict(self): self.assertEquals(cache["key4"], [4]) self.assertEquals(cache["key5"], [5, 6]) - def test_zero_size_drop_from_cache(self) -> None: - """Test that `drop_from_cache` works correctly with 0-sized entries.""" - cache: LruCache[str, List[int]] = LruCache(5, size_callback=lambda x: 0) - cache["key1"] = [] - - self.assertEqual(len(cache), 0) - cache.cache["key1"].drop_from_cache() - self.assertIsNone( - cache.pop("key1"), "Cache entry should have been evicted but wasn't" - ) - class TimeEvictionTestCase(unittest.HomeserverTestCase): """Test that time based eviction works correctly.""" From d6fb96e056f79de220d8d59429d89a61498e9af3 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 7 Dec 2021 16:51:53 +0000 Subject: [PATCH 002/157] Fix case in `wait_for_background_updates` where `self.store` does not exist (#11331) Pull the DataStore from the HomeServer instance, which always exists. --- changelog.d/11331.misc | 1 + tests/unittest.py | 12 ++++-------- 2 files changed, 5 insertions(+), 8 deletions(-) create mode 100644 changelog.d/11331.misc diff --git a/changelog.d/11331.misc b/changelog.d/11331.misc new file mode 100644 index 000000000000..1ab3a6a97591 --- /dev/null +++ b/changelog.d/11331.misc @@ -0,0 +1 @@ +A test helper (`wait_for_background_updates`) no longer depends on classes defining a `store` property. diff --git a/tests/unittest.py b/tests/unittest.py index eea0903f0574..14318483674e 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -331,17 +331,13 @@ def wait_on_thread(self, deferred, timeout=10): time.sleep(0.01) def wait_for_background_updates(self) -> None: - """ - Block until all background database updates have completed. - - Note that callers must ensure that's a store property created on the - testcase. - """ + """Block until all background database updates have completed.""" + store = self.hs.get_datastore() while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() + store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 + store.db_pool.updates.do_next_background_update(100), by=0.1 ) def make_homeserver(self, reactor, clock): From 8541809cb952ebf0da2a95dd93eccd5644dab49d Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Wed, 8 Dec 2021 05:01:38 -0500 Subject: [PATCH 003/157] Send and handle cross-signing messages using the stable prefix. (#10520) --- changelog.d/10520.misc | 1 + synapse/handlers/e2e_keys.py | 8 ++++++-- synapse/storage/databases/main/devices.py | 4 +++- tests/federation/test_federation_sender.py | 5 +++-- 4 files changed, 13 insertions(+), 5 deletions(-) create mode 100644 changelog.d/10520.misc diff --git a/changelog.d/10520.misc b/changelog.d/10520.misc new file mode 100644 index 000000000000..a911e165da80 --- /dev/null +++ b/changelog.d/10520.misc @@ -0,0 +1 @@ +Send and handle cross-signing messages using the stable prefix.
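The shape of the change below is uniform: everywhere the signing-key-update EDU is registered or emitted, the stable `m.signing_key_update` name is used, with the unstable `org.matrix.signing_key_update` name kept alongside it for servers that have not yet upgraded. A minimal sketch of that dual-registration pattern (the registry and updater names follow the diff; the helper itself is illustrative, not the patch's literal code):

```python
# Illustrative sketch only: register one handler under the stable EDU name
# and under the legacy unstable alias, as the patch below does inline.
STABLE_EDU_TYPE = "m.signing_key_update"
UNSTABLE_EDU_TYPE = "org.matrix.signing_key_update"


def register_signing_key_update_handlers(federation_registry, edu_updater) -> None:
    # Both names route to the same handler; the unstable alias can be dropped
    # once enough servers have upgraded (see the FIXME in the diff).
    for edu_type in (STABLE_EDU_TYPE, UNSTABLE_EDU_TYPE):
        federation_registry.register_edu_handler(
            edu_type, edu_updater.incoming_signing_key_update
        )
```

The sending side mirrors this: the device-update stream appends the same payload once per EDU name, which is why the test now expects two EDUs where it previously expected one.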
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 60c11e3d2128..b2554bda045a 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -65,8 +65,12 @@ def __init__(self, hs: "HomeServer"): else: # Only register this edu handler on master as it requires writing # device updates to the db - # - # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec + federation_registry.register_edu_handler( + "m.signing_key_update", + self._edu_updater.incoming_signing_key_update, + ) + # also handle the unstable version + # FIXME: remove this when enough servers have upgraded federation_registry.register_edu_handler( "org.matrix.signing_key_update", self._edu_updater.incoming_signing_key_update, diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index d5a4a661cd1a..838a2a6a3dd0 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -274,7 +274,9 @@ async def get_device_updates_by_remote( # add the updated cross-signing keys to the results list for user_id, result in cross_signing_keys_by_user.items(): result["user_id"] = user_id - # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec + results.append(("m.signing_key_update", result)) + # also send the unstable version + # FIXME: remove this when enough servers have upgraded results.append(("org.matrix.signing_key_update", result)) return now_stream_id, results diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index b457dad6d263..b2376e2db925 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -266,7 +266,8 @@ def test_upload_signatures(self): ) # expect signing key update edu - self.assertEqual(len(self.edus), 1) + self.assertEqual(len(self.edus), 2) + self.assertEqual(self.edus.pop(0)["edu_type"], "m.signing_key_update") self.assertEqual(self.edus.pop(0)["edu_type"], "org.matrix.signing_key_update") # sign the devices @@ -491,7 +492,7 @@ def check_signing_key_update_txn( ) -> None: """Check that the txn has an EDU with a signing key update.""" edus = txn["edus"] - self.assertEqual(len(edus), 1) + self.assertEqual(len(edus), 2) def generate_and_upload_device_signing_key( self, user_id: str, device_id: str From ff7cc17b5706ff3f386b15fda668511c0502ab9c Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 8 Dec 2021 14:15:14 +0000 Subject: [PATCH 004/157] Improve log messages for stream ids (#11536) Somehow I'd managed to get my database in a pickle with stream ids. These changes were useful to debug. --- changelog.d/11536.misc | 1 + synapse/storage/databases/main/state_deltas.py | 4 +++- synapse/storage/util/id_generators.py | 6 +++--- 3 files changed, 7 insertions(+), 4 deletions(-) create mode 100644 changelog.d/11536.misc diff --git a/changelog.d/11536.misc b/changelog.d/11536.misc new file mode 100644 index 000000000000..b9191c111b18 --- /dev/null +++ b/changelog.d/11536.misc @@ -0,0 +1 @@ +Improvements to log messages around handling stream ids. 
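Two small changes follow: the bare `assert` comparing stream positions gains a message that captures both values, and the stream-generator start-up log is emitted after the current id has been computed, so it records the result rather than just the table name. A self-contained sketch of the assertion pattern, with hypothetical values:

```python
# Hypothetical stream positions; in the patch these come from the
# state-delta stream.
prev_stream_id, max_stream_id = 100, 105

# Putting both positions in the message means a corrupted stream id can be
# diagnosed straight from the traceback when the assertion fires.
assert (
    prev_stream_id <= max_stream_id
), f"New stream id {max_stream_id} is smaller than prev stream id {prev_stream_id}"
```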
diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py index 7f3624b12872..188afec332dd 100644 --- a/synapse/storage/databases/main/state_deltas.py +++ b/synapse/storage/databases/main/state_deltas.py @@ -56,7 +56,9 @@ async def get_current_state_deltas( prev_stream_id = int(prev_stream_id) # check we're not going backwards - assert prev_stream_id <= max_stream_id + assert ( + prev_stream_id <= max_stream_id + ), f"New stream id {max_stream_id} is smaller than prev stream id {prev_stream_id}" if not self._curr_state_delta_stream_cache.has_any_entity_changed( prev_stream_id diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 4ff3013908a7..b8112e1c0551 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -74,8 +74,6 @@ def get_next(self) -> int: def _load_current_id( db_conn: LoggingDatabaseConnection, table: str, column: str, step: int = 1 ) -> int: - # debug logging for https://github.com/matrix-org/synapse/issues/7968 - logger.info("initialising stream generator for %s(%s)", table, column) cur = db_conn.cursor(txn_name="_load_current_id") if step == 1: cur.execute("SELECT MAX(%s) FROM %s" % (column, table)) @@ -86,7 +84,9 @@ def _load_current_id( (val,) = result cur.close() current_id = int(val) if val else step - return (max if step > 0 else min)(current_id, step) + res = (max if step > 0 else min)(current_id, step) + logger.info("Initialising stream generator for %s(%s): %i", table, column, res) + return res class AbstractStreamIdTracker(metaclass=abc.ABCMeta): From 365e9482fe18b293f55f78e5f5d2d1107a1d95e1 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Wed, 8 Dec 2021 14:54:47 +0000 Subject: [PATCH 005/157] Use HTTPStatus constants in place of literals in `tests.rest.client.test_auth`. (#11520) --- changelog.d/11520.misc | 1 + tests/rest/client/test_auth.py | 134 +++++++++++++++++++++------------ 2 files changed, 88 insertions(+), 47 deletions(-) create mode 100644 changelog.d/11520.misc diff --git a/changelog.d/11520.misc b/changelog.d/11520.misc new file mode 100644 index 000000000000..2d84120e19e8 --- /dev/null +++ b/changelog.d/11520.misc @@ -0,0 +1 @@ +Use HTTPStatus constants in place of literals in `tests.rest.client.test_auth`. \ No newline at end of file diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index 72bbc87b4a0c..27cb856b0acd 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -85,7 +85,7 @@ def recaptcha( channel = self.make_request( "GET", "auth/m.login.recaptcha/fallback/web?session=" + session ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) channel = self.make_request( "POST", @@ -104,7 +104,7 @@ def test_fallback_captcha(self): """Ensure that fallback auth via a captcha works.""" # Returns a 401 as per the spec channel = self.register( - 401, + HTTPStatus.UNAUTHORIZED, {"username": "user", "type": "m.login.password", "password": "bar"}, ) @@ -116,15 +116,17 @@ def test_fallback_captcha(self): ) # Complete the recaptcha step. 
- self.recaptcha(session, 200) + self.recaptcha(session, HTTPStatus.OK) # also complete the dummy auth - self.register(200, {"auth": {"session": session, "type": "m.login.dummy"}}) + self.register( + HTTPStatus.OK, {"auth": {"session": session, "type": "m.login.dummy"}} + ) # Now we should have fulfilled a complete auth flow, including # the recaptcha fallback step, we can then send a # request to the register API with the session in the authdict. - channel = self.register(200, {"auth": {"session": session}}) + channel = self.register(HTTPStatus.OK, {"auth": {"session": session}}) # We're given a registered user. self.assertEqual(channel.json_body["user_id"], "@user:test") @@ -137,7 +139,8 @@ def test_complete_operation_unknown_session(self): # will be used.) # Returns a 401 as per the spec channel = self.register( - 401, {"username": "user", "type": "m.login.password", "password": "bar"} + HTTPStatus.UNAUTHORIZED, + {"username": "user", "type": "m.login.password", "password": "bar"}, ) # Grab the session @@ -231,7 +234,9 @@ def test_ui_auth(self): """ # Attempt to delete this device. # Returns a 401 as per the spec - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) # Grab the session session = channel.json_body["session"] @@ -242,7 +247,7 @@ def test_ui_auth(self): self.delete_device( self.user_tok, self.device_id, - 200, + HTTPStatus.OK, { "auth": { "type": "m.login.password", @@ -260,14 +265,16 @@ def test_grandfathered_identifier(self): UIA - check that still works. """ - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) session = channel.json_body["session"] # Make another request providing the UI auth flow. self.delete_device( self.user_tok, self.device_id, - 200, + HTTPStatus.OK, { "auth": { "type": "m.login.password", @@ -293,7 +300,9 @@ def test_can_change_body(self): # Attempt to delete the first device. # Returns a 401 as per the spec - channel = self.delete_devices(401, {"devices": [self.device_id]}) + channel = self.delete_devices( + HTTPStatus.UNAUTHORIZED, {"devices": [self.device_id]} + ) # Grab the session session = channel.json_body["session"] @@ -303,7 +312,7 @@ def test_can_change_body(self): # Make another request providing the UI auth flow, but try to delete the # second device. self.delete_devices( - 200, + HTTPStatus.OK, { "devices": ["dev2"], "auth": { @@ -324,7 +333,9 @@ def test_cannot_change_uri(self): # Attempt to delete the first device. # Returns a 401 as per the spec - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) # Grab the session session = channel.json_body["session"] @@ -338,7 +349,7 @@ def test_cannot_change_uri(self): self.delete_device( self.user_tok, "dev2", - 403, + HTTPStatus.FORBIDDEN, { "auth": { "type": "m.login.password", @@ -361,13 +372,13 @@ def test_can_reuse_session(self): self.login("test", self.user_pass, "dev3") # Attempt to delete a device. This works since the user just logged in. - self.delete_device(self.user_tok, "dev2", 200) + self.delete_device(self.user_tok, "dev2", HTTPStatus.OK) # Move the clock forward past the validation timeout. self.reactor.advance(6) # Deleting another devices throws the user into UI auth. 
- channel = self.delete_device(self.user_tok, "dev3", 401) + channel = self.delete_device(self.user_tok, "dev3", HTTPStatus.UNAUTHORIZED) # Grab the session session = channel.json_body["session"] @@ -378,7 +389,7 @@ def test_can_reuse_session(self): self.delete_device( self.user_tok, "dev3", - 200, + HTTPStatus.OK, { "auth": { "type": "m.login.password", @@ -393,7 +404,7 @@ def test_can_reuse_session(self): # due to re-using the previous session. # # Note that *no auth* information is provided, not even a session iD! - self.delete_device(self.user_tok, self.device_id, 200) + self.delete_device(self.user_tok, self.device_id, HTTPStatus.OK) @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) @@ -413,7 +424,9 @@ def test_ui_auth_via_sso(self): self.assertEqual(login_resp["user_id"], self.user) # initiate a UI Auth process by attempting to delete the device - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) # check that SSO is offered flows = channel.json_body["flows"] @@ -426,13 +439,13 @@ def test_ui_auth_via_sso(self): ) # that should serve a confirmation page - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) # and now the delete request should succeed. self.delete_device( self.user_tok, self.device_id, - 200, + HTTPStatus.OK, body={"auth": {"session": session_id}}, ) @@ -445,13 +458,15 @@ def test_does_not_offer_password_for_sso_user(self): # now call the device deletion API: we should get the option to auth with SSO # and not password. - channel = self.delete_device(user_tok, device_id, 401) + channel = self.delete_device(user_tok, device_id, HTTPStatus.UNAUTHORIZED) flows = channel.json_body["flows"] self.assertEqual(flows, [{"stages": ["m.login.sso"]}]) def test_does_not_offer_sso_for_password_user(self): - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) flows = channel.json_body["flows"] self.assertEqual(flows, [{"stages": ["m.login.password"]}]) @@ -463,7 +478,9 @@ def test_offers_both_flows_for_upgraded_user(self): login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) self.assertEqual(login_resp["user_id"], self.user) - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) flows = channel.json_body["flows"] # we have no particular expectations of ordering here @@ -480,7 +497,9 @@ def test_ui_auth_fails_for_incorrect_sso_user(self): self.assertEqual(login_resp["user_id"], self.user) # start a UI Auth flow by attempting to delete a device - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) flows = channel.json_body["flows"] self.assertIn({"stages": ["m.login.sso"]}, flows) @@ -496,7 +515,10 @@ def test_ui_auth_fails_for_incorrect_sso_user(self): # ... 
and the delete op should now fail with a 403 self.delete_device( - self.user_tok, self.device_id, 403, body={"auth": {"session": session_id}} + self.user_tok, + self.device_id, + HTTPStatus.FORBIDDEN, + body={"auth": {"session": session_id}}, ) @@ -551,7 +573,9 @@ def test_login_issue_refresh_token(self): login_without_refresh = self.make_request( "POST", "/_matrix/client/r0/login", body ) - self.assertEqual(login_without_refresh.code, 200, login_without_refresh.result) + self.assertEqual( + login_without_refresh.code, HTTPStatus.OK, login_without_refresh.result + ) self.assertNotIn("refresh_token", login_without_refresh.json_body) login_with_refresh = self.make_request( @@ -559,7 +583,9 @@ def test_login_issue_refresh_token(self): "/_matrix/client/r0/login", {"refresh_token": True, **body}, ) - self.assertEqual(login_with_refresh.code, 200, login_with_refresh.result) + self.assertEqual( + login_with_refresh.code, HTTPStatus.OK, login_with_refresh.result + ) self.assertIn("refresh_token", login_with_refresh.json_body) self.assertIn("expires_in_ms", login_with_refresh.json_body) @@ -577,7 +603,9 @@ def test_register_issue_refresh_token(self): }, ) self.assertEqual( - register_without_refresh.code, 200, register_without_refresh.result + register_without_refresh.code, + HTTPStatus.OK, + register_without_refresh.result, ) self.assertNotIn("refresh_token", register_without_refresh.json_body) @@ -591,7 +619,9 @@ def test_register_issue_refresh_token(self): "refresh_token": True, }, ) - self.assertEqual(register_with_refresh.code, 200, register_with_refresh.result) + self.assertEqual( + register_with_refresh.code, HTTPStatus.OK, register_with_refresh.result + ) self.assertIn("refresh_token", register_with_refresh.json_body) self.assertIn("expires_in_ms", register_with_refresh.json_body) @@ -610,14 +640,14 @@ def test_token_refresh(self): "/_matrix/client/r0/login", body, ) - self.assertEqual(login_response.code, 200, login_response.result) + self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result) refresh_response = self.make_request( "POST", "/_matrix/client/v1/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) - self.assertEqual(refresh_response.code, 200, refresh_response.result) + self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result) self.assertIn("access_token", refresh_response.json_body) self.assertIn("refresh_token", refresh_response.json_body) self.assertIn("expires_in_ms", refresh_response.json_body) @@ -648,7 +678,7 @@ def test_refreshable_access_token_expiration(self): "/_matrix/client/r0/login", body, ) - self.assertEqual(login_response.code, 200, login_response.result) + self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result) self.assertApproximates( login_response.json_body["expires_in_ms"], 60 * 1000, 100 ) @@ -658,7 +688,7 @@ def test_refreshable_access_token_expiration(self): "/_matrix/client/v1/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) - self.assertEqual(refresh_response.code, 200, refresh_response.result) + self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result) self.assertApproximates( refresh_response.json_body["expires_in_ms"], 60 * 1000, 100 ) @@ -705,7 +735,7 @@ def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens(self) "/_matrix/client/r0/login", {"refresh_token": True, **body}, ) - self.assertEqual(login_response1.code, 200, login_response1.result) + self.assertEqual(login_response1.code, HTTPStatus.OK, 
login_response1.result) self.assertApproximates( login_response1.json_body["expires_in_ms"], 60 * 1000, 100 ) @@ -716,7 +746,7 @@ def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens(self) "/_matrix/client/r0/login", body, ) - self.assertEqual(login_response2.code, 200, login_response2.result) + self.assertEqual(login_response2.code, HTTPStatus.OK, login_response2.result) nonrefreshable_access_token = login_response2.json_body["access_token"] # Advance 59 seconds in the future (just shy of 1 minute, the time of expiry) @@ -818,7 +848,7 @@ def test_ultimate_session_expiry(self): "/_matrix/client/r0/login", body, ) - self.assertEqual(login_response.code, 200, login_response.result) + self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result) refresh_token = login_response.json_body["refresh_token"] # Advance shy of 2 minutes into the future @@ -826,7 +856,7 @@ def test_ultimate_session_expiry(self): # Refresh our session. The refresh token should still be valid right now. refresh_response = self.use_refresh_token(refresh_token) - self.assertEqual(refresh_response.code, 200, refresh_response.result) + self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result) self.assertIn( "refresh_token", refresh_response.json_body, @@ -846,7 +876,9 @@ def test_ultimate_session_expiry(self): # This should fail because the refresh token's lifetime has also been # diminished as our session expired. refresh_response = self.use_refresh_token(refresh_token) - self.assertEqual(refresh_response.code, 403, refresh_response.result) + self.assertEqual( + refresh_response.code, HTTPStatus.FORBIDDEN, refresh_response.result + ) def test_refresh_token_invalidation(self): """Refresh tokens are invalidated after first use of the next token. 
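Every hunk in this patch applies the same mechanical rule: a bare status-code literal becomes the equivalent `HTTPStatus` constant. Since `http.HTTPStatus` is a standard-library `IntEnum`, each constant compares equal to the integer it replaces, so the rewrite cannot change test behaviour; a quick demonstration:

```python
from http import HTTPStatus

# IntEnum members compare equal to plain integers, so these all hold.
assert HTTPStatus.OK == 200
assert HTTPStatus.UNAUTHORIZED == 401
assert HTTPStatus.FORBIDDEN == 403
```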
@@ -875,7 +907,7 @@ def test_refresh_token_invalidation(self): "/_matrix/client/r0/login", body, ) - self.assertEqual(login_response.code, 200, login_response.result) + self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result) # This first refresh should work properly first_refresh_response = self.make_request( @@ -884,7 +916,7 @@ def test_refresh_token_invalidation(self): {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual( - first_refresh_response.code, 200, first_refresh_response.result + first_refresh_response.code, HTTPStatus.OK, first_refresh_response.result ) # This one as well, since the token in the first one was never used @@ -894,7 +926,7 @@ def test_refresh_token_invalidation(self): {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual( - second_refresh_response.code, 200, second_refresh_response.result + second_refresh_response.code, HTTPStatus.OK, second_refresh_response.result ) # This one should not, since the token from the first refresh is not valid anymore @@ -904,7 +936,9 @@ def test_refresh_token_invalidation(self): {"refresh_token": first_refresh_response.json_body["refresh_token"]}, ) self.assertEqual( - third_refresh_response.code, 401, third_refresh_response.result + third_refresh_response.code, + HTTPStatus.UNAUTHORIZED, + third_refresh_response.result, ) # The associated access token should also be invalid @@ -913,7 +947,9 @@ def test_refresh_token_invalidation(self): "/_matrix/client/r0/account/whoami", access_token=first_refresh_response.json_body["access_token"], ) - self.assertEqual(whoami_response.code, 401, whoami_response.result) + self.assertEqual( + whoami_response.code, HTTPStatus.UNAUTHORIZED, whoami_response.result + ) # But all other tokens should work (they will expire after some time) for access_token in [ @@ -923,7 +959,9 @@ def test_refresh_token_invalidation(self): whoami_response = self.make_request( "GET", "/_matrix/client/r0/account/whoami", access_token=access_token ) - self.assertEqual(whoami_response.code, 200, whoami_response.result) + self.assertEqual( + whoami_response.code, HTTPStatus.OK, whoami_response.result + ) # Now that the access token from the last valid refresh was used once, refreshing with the N-1 token should fail fourth_refresh_response = self.make_request( @@ -932,7 +970,9 @@ def test_refresh_token_invalidation(self): {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual( - fourth_refresh_response.code, 403, fourth_refresh_response.result + fourth_refresh_response.code, + HTTPStatus.FORBIDDEN, + fourth_refresh_response.result, ) # But refreshing from the last valid refresh token still works @@ -942,5 +982,5 @@ def test_refresh_token_invalidation(self): {"refresh_token": second_refresh_response.json_body["refresh_token"]}, ) self.assertEqual( - fifth_refresh_response.code, 200, fifth_refresh_response.result + fifth_refresh_response.code, HTTPStatus.OK, fifth_refresh_response.result ) From 83a74d9350e731cc0a7f119cf89aa1bd87638b84 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Wed, 8 Dec 2021 15:31:17 +0000 Subject: [PATCH 006/157] Document the usage of refresh tokens. 
(#11427) Co-authored-by: David Robertson --- changelog.d/11427.doc | 1 + docs/SUMMARY.md | 1 + .../user_authentication/refresh_tokens.md | 139 ++++++++++++++++++ 3 files changed, 141 insertions(+) create mode 100644 changelog.d/11427.doc create mode 100644 docs/usage/configuration/user_authentication/refresh_tokens.md diff --git a/changelog.d/11427.doc b/changelog.d/11427.doc new file mode 100644 index 000000000000..01cdfcf2b7e4 --- /dev/null +++ b/changelog.d/11427.doc @@ -0,0 +1 @@ +Document the usage of refresh tokens. \ No newline at end of file diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index b05af6d69051..11f597b3edb8 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -30,6 +30,7 @@ - [SSO Mapping Providers](sso_mapping_providers.md) - [Password Auth Providers](password_auth_providers.md) - [JSON Web Tokens](jwt.md) + - [Refresh Tokens](usage/configuration/user_authentication/refresh_tokens.md) - [Registration Captcha](CAPTCHA_SETUP.md) - [Application Services](application_services.md) - [Server Notices](server_notices.md) diff --git a/docs/usage/configuration/user_authentication/refresh_tokens.md b/docs/usage/configuration/user_authentication/refresh_tokens.md new file mode 100644 index 000000000000..23b3cddae054 --- /dev/null +++ b/docs/usage/configuration/user_authentication/refresh_tokens.md @@ -0,0 +1,139 @@ +# Refresh Tokens + +Synapse supports refresh tokens since version 1.49 (some earlier versions had support for an earlier, experimental draft of [MSC2918] which is not compatible). + + +[MSC2918]: https://github.com/matrix-org/matrix-doc/blob/main/proposals/2918-refreshtokens.md#msc2918-refresh-tokens + + +## Background and motivation + +Synapse users' sessions are identified by **access tokens**; access tokens are +issued to users on login. Each session gets a unique access token which identifies +it; the access token must be kept secret as it grants access to the user's account. + +Traditionally, these access tokens were eternally valid (at least until the user +explicitly chose to log out). + +In some cases, it may be desirable for these access tokens to expire so that the +potential damage caused by leaking an access token is reduced. +On the other hand, forcing a user to re-authenticate (log in again) often might +be too much of an inconvenience. + +**Refresh tokens** are a mechanism to avoid some of this inconvenience whilst +still getting most of the benefits of short access token lifetimes. +Refresh tokens are also a concept present in OAuth 2 — further reading is available +[here](https://datatracker.ietf.org/doc/html/rfc6749#section-1.5). + +When refresh tokens are in use, both an access token and a refresh token will be +issued to users on login. The access token will expire after a predetermined amount +of time, but otherwise works in the same way as before. When the access token is +close to expiring (or has expired), the user's client should present the homeserver +(Synapse) with the refresh token. + +The homeserver will then generate a new access token and refresh token for the user +and return them. The old refresh token is invalidated and can not be used again*. + +Finally, refresh tokens also make it possible for sessions to be logged out if they +are inactive for too long, before the session naturally ends; see the configuration +guide below. + + +*To prevent issues if clients lose connection half-way through refreshing a token, +the refresh token is only invalidated once the new access token has been used at +least once. 
+For most practical purposes, the simplification above is sufficient.
+
+## Caveats
+
+There are some caveats:
+
+* If a third party gets both your access token and refresh token, they will be able to
+  retain access to your session.
+  * This is still an improvement because you (the user) will notice when *your*
+    session expires and you're not able to use your refresh token.
+    That would be a giveaway that someone else has compromised your session.
+    You would be able to log in again and terminate that session.
+    Previously (with long-lived access tokens), a third party that has your access
+    token could go undetected for a very long time.
+* Clients need to implement support for refresh tokens in order for them to be a
+  useful mechanism.
+  * It is up to homeserver administrators whether they want to issue long-lived access
+    tokens to clients not implementing refresh tokens.
+  * For compatibility, it is likely that they should, at least until client support
+    is widespread.
+  * Users with clients that support refresh tokens will still benefit from the
+    added security; it's not possible to downgrade a session to using long-lived
+    access tokens, so this effectively gives users the choice.
+  * In a closed environment where all users use known clients, this may not be
+    an issue, as the homeserver administrator can know whether the clients have refresh
+    token support. In that case, the non-refreshable access token lifetime
+    may be set to a short duration so that a similar level of security is provided.
+
+## Configuration Guide
+
+The following configuration options, in the `registration` section, are related:
+
+* `session_lifetime`: maximum length of a session, even if it's refreshed.
+  In other words, the client must log in again after this time period.
+  In most cases, this can be unset (infinite) or set to a long time (years or months).
+* `refreshable_access_token_lifetime`: lifetime of access tokens that are created
+  by clients supporting refresh tokens.
+  This should be short; a good value might be 5 minutes (`5m`).
+* `nonrefreshable_access_token_lifetime`: lifetime of access tokens that are created
+  by clients which don't support refresh tokens.
+  Make this short if you want to effectively force use of refresh tokens.
+  Make this long if you don't want to inconvenience users of clients which don't
+  support refresh tokens (by forcing them to frequently re-authenticate using
+  login credentials).
+* `refresh_token_lifetime`: lifetime of refresh tokens.
+  In other words, the client must refresh within this time period to maintain its session.
+  Unless you want to log inactive sessions out, it is often fine to use a long
+  value here or even leave it unset (infinite).
+  Beware that making it too short will inconvenience clients that do not connect
+  very often, including mobile clients and clients of infrequent users (by making
+  it more difficult for them to refresh in time, which may force them to
+  re-authenticate using login credentials).
+
+**Note:** All four options above apply only when tokens are created (by logging in or refreshing).
+Changes to these settings do not apply retroactively.
+
+### Using refresh token expiry to log out inactive sessions
+
+If you'd like to force sessions to be logged out upon inactivity, you can enable
+refreshable access token expiry and refresh token expiry.
+
+This works because a client must refresh at least once within a period of
+`refresh_token_lifetime` in order to maintain valid credentials to access the
+account.
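To make the arithmetic concrete, here is a minimal sketch of the mechanism just described; the two lifetimes are made-up illustrative values, not Synapse defaults:

```python
# Hedged sketch of the inactivity logout described above; the lifetimes
# below are arbitrary example values, not defaults.
REFRESHABLE_ACCESS_TOKEN_LIFETIME = 5 * 60  # 5 minutes, in seconds
REFRESH_TOKEN_LIFETIME = 24 * 60 * 60       # 1 day, in seconds

def session_still_refreshable(last_refresh: int, now: int) -> bool:
    """A client that last refreshed at `last_refresh` can obtain fresh
    credentials only while its refresh token remains valid."""
    return now - last_refresh <= REFRESH_TOKEN_LIFETIME

# Inactive for an hour: the client can still refresh and keep its session.
assert session_still_refreshable(last_refresh=0, now=60 * 60)
# Inactive for two days: the refresh window has passed, so once the
# short-lived access token also expires the session is effectively dead.
assert not session_still_refreshable(last_refresh=0, now=2 * 24 * 60 * 60)
```

Refreshing earlier than required only pushes the deadline further out; it never shortens it, which is why the model below can treat "refresh as late as possible" as the worst case.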
+
+(It's suggested that `refresh_token_lifetime` should be longer than
+`refreshable_access_token_lifetime`, and this section assumes that to be the case
+for simplicity.)
+
+Note: this will only affect sessions using refresh tokens. You may wish to
+set a short `nonrefreshable_access_token_lifetime` to prevent this being bypassed
+by clients that do not support refresh tokens.
+
+#### Choosing values that guarantee some inactivity is permitted
+
+It may be desirable to permit some short periods of inactivity, for example to
+accommodate brief outages in client connectivity.
+
+The following model aims to provide guidance for choosing `refresh_token_lifetime`
+and `refreshable_access_token_lifetime` to satisfy requirements of the form:
+
+1. inactivity longer than `L` **MUST** cause the session to be logged out; and
+2. inactivity shorter than `S` **MUST NOT** cause the session to be logged out.
+
+This model makes the conservative assumption that all active clients will refresh only as
+needed to maintain an active access token, and no sooner.
+*In reality, clients may refresh more often than this model assumes, but the
+above requirements will still hold.*
+
+To satisfy the above model:
+* `refresh_token_lifetime` should be set to `L`; and
+* `refreshable_access_token_lifetime` should be set to `L - S`.

From 7ecaa3b976b04dc5b2c6786aa18845016b80dd01 Mon Sep 17 00:00:00 2001
From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com>
Date: Wed, 8 Dec 2021 17:59:40 +0100
Subject: [PATCH 007/157] Clean up `synapse.rest.admin` (#11535)

---
 changelog.d/11535.misc | 1 +
 synapse/rest/admin/__init__.py | 4 +-
 synapse/rest/admin/background_updates.py | 16 ++---
 synapse/rest/admin/devices.py | 20 +++---
 synapse/rest/admin/event_reports.py | 2 -
 synapse/rest/admin/federation.py | 2 +-
 synapse/rest/admin/groups.py | 2 +-
 synapse/rest/admin/media.py | 60 ++++++------------
 synapse/rest/admin/registration_tokens.py | 3 -
 synapse/rest/admin/rooms.py | 70 +++++++--------------
 synapse/rest/admin/server_notice_servlet.py | 4 +-
 synapse/rest/admin/statistics.py | 22 +++----
 synapse/rest/admin/username_available.py | 2 +-
 synapse/rest/admin/users.py | 51 +++++++--------
 tests/rest/admin/test_statistics.py | 2 +-
 15 files changed, 96 insertions(+), 165 deletions(-)
 create mode 100644 changelog.d/11535.misc

diff --git a/changelog.d/11535.misc b/changelog.d/11535.misc
new file mode 100644
index 000000000000..580ac354ab7e
--- /dev/null
+++ b/changelog.d/11535.misc
@@ -0,0 +1 @@
+Clean up `synapse.rest.admin`.
\ No newline at end of file

diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index c499afd4be57..701c609c1208 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -108,7 +108,7 @@ def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
 class PurgeHistoryRestServlet(RestServlet):
     PATTERNS = admin_patterns(
-        "/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?"
+ "/purge_history/(?P[^/]*)(/(?P[^/]*))?$" ) def __init__(self, hs: "HomeServer"): @@ -195,7 +195,7 @@ async def on_POST( class PurgeHistoryStatusRestServlet(RestServlet): - PATTERNS = admin_patterns("/purge_history_status/(?P[^/]+)") + PATTERNS = admin_patterns("/purge_history_status/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): self.pagination_handler = hs.get_pagination_handler() diff --git a/synapse/rest/admin/background_updates.py b/synapse/rest/admin/background_updates.py index 479672d4d568..6ec00ce0b9a8 100644 --- a/synapse/rest/admin/background_updates.py +++ b/synapse/rest/admin/background_updates.py @@ -22,7 +22,7 @@ parse_json_object_from_request, ) from synapse.http.site import SynapseRequest -from synapse.rest.admin._base import admin_patterns, assert_user_is_admin +from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin from synapse.types import JsonDict if TYPE_CHECKING: @@ -41,8 +41,7 @@ def __init__(self, hs: "HomeServer"): self._data_stores = hs.get_datastores() async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request) - await assert_user_is_admin(self._auth, requester.user) + await assert_requester_is_admin(self._auth, request) # We need to check that all configured databases have updates enabled. # (They *should* all be in sync.) @@ -51,8 +50,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: return HTTPStatus.OK, {"enabled": enabled} async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request) - await assert_user_is_admin(self._auth, requester.user) + await assert_requester_is_admin(self._auth, request) body = parse_json_object_from_request(request) @@ -84,8 +82,7 @@ def __init__(self, hs: "HomeServer"): self._data_stores = hs.get_datastores() async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request) - await assert_user_is_admin(self._auth, requester.user) + await assert_requester_is_admin(self._auth, request) # We need to check that all configured databases have updates enabled. # (They *should* all be in sync.) 
@@ -111,15 +108,14 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: class BackgroundUpdateStartJobRestServlet(RestServlet): """Allows to start specific background updates""" - PATTERNS = admin_patterns("/background_updates/start_job") + PATTERNS = admin_patterns("/background_updates/start_job$") def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() self._store = hs.get_datastore() async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request) - await assert_user_is_admin(self._auth, requester.user) + await assert_requester_is_admin(self._auth, request) body = parse_json_object_from_request(request) assert_params_in_dict(body, ["job_name"]) diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py index 2e5a6600d337..062a33d28d15 100644 --- a/synapse/rest/admin/devices.py +++ b/synapse/rest/admin/devices.py @@ -42,10 +42,10 @@ class DeviceRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() self.store = hs.get_datastore() + self.is_mine = hs.is_mine async def on_GET( self, request: SynapseRequest, user_id: str, device_id: str @@ -53,7 +53,7 @@ async def on_GET( await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) @@ -71,7 +71,7 @@ async def on_DELETE( await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) @@ -87,7 +87,7 @@ async def on_PUT( await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) @@ -109,14 +109,10 @@ class DevicesRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P[^/]*)/devices$", "v2") def __init__(self, hs: "HomeServer"): - """ - Args: - hs: server - """ - self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() self.store = hs.get_datastore() + self.is_mine = hs.is_mine async def on_GET( self, request: SynapseRequest, user_id: str @@ -124,7 +120,7 @@ async def on_GET( await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) @@ -144,10 +140,10 @@ class DeleteDevicesRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P[^/]*)/delete_devices$", "v2") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() self.store = hs.get_datastore() + self.is_mine = hs.is_mine async def on_POST( self, request: SynapseRequest, user_id: str @@ -155,7 +151,7 @@ async def on_POST( await assert_requester_is_admin(self.auth, request) target_user = 
UserID.from_string(user_id) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py index 5ee8b11110e0..38477f8eadeb 100644 --- a/synapse/rest/admin/event_reports.py +++ b/synapse/rest/admin/event_reports.py @@ -52,7 +52,6 @@ class EventReportsRestServlet(RestServlet): PATTERNS = admin_patterns("/event_reports$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -115,7 +114,6 @@ class EventReportDetailRestServlet(RestServlet): PATTERNS = admin_patterns("/event_reports/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py index 744687be35fc..50d88c91091b 100644 --- a/synapse/rest/admin/federation.py +++ b/synapse/rest/admin/federation.py @@ -100,7 +100,7 @@ class DestinationsRestServlet(RestServlet): 200 OK with details of a destination if success otherwise an error. """ - PATTERNS = admin_patterns("/federation/destinations/(?P[^/]+)$") + PATTERNS = admin_patterns("/federation/destinations/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() diff --git a/synapse/rest/admin/groups.py b/synapse/rest/admin/groups.py index a27110388f4f..cd697e180ef6 100644 --- a/synapse/rest/admin/groups.py +++ b/synapse/rest/admin/groups.py @@ -30,7 +30,7 @@ class DeleteGroupAdminRestServlet(RestServlet): """Allows deleting of local groups""" - PATTERNS = admin_patterns("/delete_group/(?P[^/]*)") + PATTERNS = admin_patterns("/delete_group/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): self.group_server = hs.get_groups_server_handler() diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index 9e23e2d8fc00..7236e4027fa7 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -17,7 +17,7 @@ from http import HTTPStatus from typing import TYPE_CHECKING, Tuple -from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError +from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.site import SynapseRequest @@ -41,9 +41,9 @@ class QuarantineMediaInRoom(RestServlet): """ PATTERNS = [ - *admin_patterns("/room/(?P[^/]+)/media/quarantine$"), + *admin_patterns("/room/(?P[^/]*)/media/quarantine$"), # This path kept around for legacy reasons - *admin_patterns("/quarantine_media/(?P[^/]+)"), + *admin_patterns("/quarantine_media/(?P[^/]*)$"), ] def __init__(self, hs: "HomeServer"): @@ -71,7 +71,7 @@ class QuarantineMediaByUser(RestServlet): this server. 
""" - PATTERNS = admin_patterns("/user/(?P[^/]+)/media/quarantine$") + PATTERNS = admin_patterns("/user/(?P[^/]*)/media/quarantine$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -99,7 +99,7 @@ class QuarantineMediaByID(RestServlet): """ PATTERNS = admin_patterns( - "/media/quarantine/(?P[^/]+)/(?P[^/]+)" + "/media/quarantine/(?P[^/]*)/(?P[^/]*)$" ) def __init__(self, hs: "HomeServer"): @@ -128,7 +128,7 @@ class UnquarantineMediaByID(RestServlet): """ PATTERNS = admin_patterns( - "/media/unquarantine/(?P[^/]+)/(?P[^/]+)" + "/media/unquarantine/(?P[^/]*)/(?P[^/]*)$" ) def __init__(self, hs: "HomeServer"): @@ -138,8 +138,7 @@ def __init__(self, hs: "HomeServer"): async def on_POST( self, request: SynapseRequest, server_name: str, media_id: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) logging.info( "Remove from quarantine local media by ID: %s/%s", server_name, media_id @@ -154,7 +153,7 @@ async def on_POST( class ProtectMediaByID(RestServlet): """Protect local media from being quarantined.""" - PATTERNS = admin_patterns("/media/protect/(?P[^/]+)") + PATTERNS = admin_patterns("/media/protect/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -163,8 +162,7 @@ def __init__(self, hs: "HomeServer"): async def on_POST( self, request: SynapseRequest, media_id: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) logging.info("Protecting local media by ID: %s", media_id) @@ -177,7 +175,7 @@ async def on_POST( class UnprotectMediaByID(RestServlet): """Unprotect local media from being quarantined.""" - PATTERNS = admin_patterns("/media/unprotect/(?P[^/]+)") + PATTERNS = admin_patterns("/media/unprotect/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -186,8 +184,7 @@ def __init__(self, hs: "HomeServer"): async def on_POST( self, request: SynapseRequest, media_id: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) logging.info("Unprotecting local media by ID: %s", media_id) @@ -200,7 +197,7 @@ async def on_POST( class ListMediaInRoom(RestServlet): """Lists all of the media in a given room.""" - PATTERNS = admin_patterns("/room/(?P[^/]+)/media$") + PATTERNS = admin_patterns("/room/(?P[^/]*)/media$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -209,10 +206,7 @@ def __init__(self, hs: "HomeServer"): async def on_GET( self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - is_admin = await self.auth.is_server_admin(requester.user) - if not is_admin: - raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin") + await assert_requester_is_admin(self.auth, request) local_mxcs, remote_mxcs = await self.store.get_media_mxcs_in_room(room_id) @@ -254,7 +248,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: class DeleteMediaByID(RestServlet): """Delete local media by a given ID. 
Removes it from this server.""" - PATTERNS = admin_patterns("/media/(?P[^/]+)/(?P[^/]+)") + PATTERNS = admin_patterns("/media/(?P[^/]*)/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -286,7 +280,7 @@ class DeleteMediaByDateSize(RestServlet): timestamp and size. """ - PATTERNS = admin_patterns("/media/(?P[^/]+)/delete$") + PATTERNS = admin_patterns("/media/(?P[^/]*)/delete$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -353,7 +347,7 @@ class UserMediaRestServlet(RestServlet): media that exist given for this user """ - PATTERNS = admin_patterns("/users/(?P[^/]+)/media$") + PATTERNS = admin_patterns("/users/(?P[^/]*)/media$") def __init__(self, hs: "HomeServer"): self.is_mine = hs.is_mine @@ -403,16 +397,7 @@ async def on_GET( request, "order_by", default=MediaSortOrder.CREATED_TS.value, - allowed_values=( - MediaSortOrder.MEDIA_ID.value, - MediaSortOrder.UPLOAD_NAME.value, - MediaSortOrder.CREATED_TS.value, - MediaSortOrder.LAST_ACCESS_TS.value, - MediaSortOrder.MEDIA_LENGTH.value, - MediaSortOrder.MEDIA_TYPE.value, - MediaSortOrder.QUARANTINED_BY.value, - MediaSortOrder.SAFE_FROM_QUARANTINE.value, - ), + allowed_values=[sort_order.value for sort_order in MediaSortOrder], ) direction = parse_string( request, "dir", default="f", allowed_values=("f", "b") @@ -470,16 +455,7 @@ async def on_DELETE( request, "order_by", default=MediaSortOrder.CREATED_TS.value, - allowed_values=( - MediaSortOrder.MEDIA_ID.value, - MediaSortOrder.UPLOAD_NAME.value, - MediaSortOrder.CREATED_TS.value, - MediaSortOrder.LAST_ACCESS_TS.value, - MediaSortOrder.MEDIA_LENGTH.value, - MediaSortOrder.MEDIA_TYPE.value, - MediaSortOrder.QUARANTINED_BY.value, - MediaSortOrder.SAFE_FROM_QUARANTINE.value, - ), + allowed_values=[sort_order.value for sort_order in MediaSortOrder], ) direction = parse_string( request, "dir", default="f", allowed_values=("f", "b") diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py index 891b98c0888a..04948b640834 100644 --- a/synapse/rest/admin/registration_tokens.py +++ b/synapse/rest/admin/registration_tokens.py @@ -70,7 +70,6 @@ class ListRegistrationTokensRestServlet(RestServlet): PATTERNS = admin_patterns("/registration_tokens$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -109,7 +108,6 @@ class NewRegistrationTokenRestServlet(RestServlet): PATTERNS = admin_patterns("/registration_tokens/new$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() self.clock = hs.get_clock() @@ -260,7 +258,6 @@ class RegistrationTokenRestServlet(RestServlet): PATTERNS = admin_patterns("/registration_tokens/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.clock = hs.get_clock() self.auth = hs.get_auth() self.store = hs.get_datastore() diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 829e86675aba..17c6df1cc8c7 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -61,7 +61,7 @@ class RoomRestV2Servlet(RestServlet): If 'purge' is true, it will remove all traces of a room from the database. 
""" - PATTERNS = admin_patterns("/rooms/(?P[^/]+)$", "v2") + PATTERNS = admin_patterns("/rooms/(?P[^/]*)$", "v2") def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() @@ -123,7 +123,7 @@ async def on_DELETE( class DeleteRoomStatusByRoomIdRestServlet(RestServlet): """Get the status of the delete room background task.""" - PATTERNS = admin_patterns("/rooms/(?P[^/]+)/delete_status$", "v2") + PATTERNS = admin_patterns("/rooms/(?P[^/]*)/delete_status$", "v2") def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() @@ -160,7 +160,7 @@ async def on_GET( class DeleteRoomStatusByDeleteIdRestServlet(RestServlet): """Get the status of the delete room background task.""" - PATTERNS = admin_patterns("/rooms/delete_status/(?P[^/]+)$", "v2") + PATTERNS = admin_patterns("/rooms/delete_status/(?P[^/]*)$", "v2") def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() @@ -193,35 +193,17 @@ def __init__(self, hs: "HomeServer"): self.admin_handler = hs.get_admin_handler() async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) # Extract query parameters start = parse_integer(request, "from", default=0) limit = parse_integer(request, "limit", default=100) - order_by = parse_string(request, "order_by", default=RoomSortOrder.NAME.value) - if order_by not in ( - RoomSortOrder.ALPHABETICAL.value, - RoomSortOrder.SIZE.value, - RoomSortOrder.NAME.value, - RoomSortOrder.CANONICAL_ALIAS.value, - RoomSortOrder.JOINED_MEMBERS.value, - RoomSortOrder.JOINED_LOCAL_MEMBERS.value, - RoomSortOrder.VERSION.value, - RoomSortOrder.CREATOR.value, - RoomSortOrder.ENCRYPTION.value, - RoomSortOrder.FEDERATABLE.value, - RoomSortOrder.PUBLIC.value, - RoomSortOrder.JOIN_RULES.value, - RoomSortOrder.GUEST_ACCESS.value, - RoomSortOrder.HISTORY_VISIBILITY.value, - RoomSortOrder.STATE_EVENTS.value, - ): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Unknown value for order_by: %s" % (order_by,), - errcode=Codes.INVALID_PARAM, - ) + order_by = parse_string( + request, + "order_by", + default=RoomSortOrder.NAME.value, + allowed_values=[sort_order.value for sort_order in RoomSortOrder], + ) search_term = parse_string(request, "search_term", encoding="utf-8") if search_term == "": @@ -292,10 +274,9 @@ class RoomRestServlet(RestServlet): TODO: Add on_POST to allow room creation without joining the room """ - PATTERNS = admin_patterns("/rooms/(?P[^/]+)$") + PATTERNS = admin_patterns("/rooms/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() self.room_shutdown_handler = hs.get_room_shutdown_handler() @@ -397,10 +378,9 @@ class RoomMembersRestServlet(RestServlet): Get members list of a room. """ - PATTERNS = admin_patterns("/rooms/(?P[^/]+)/members") + PATTERNS = admin_patterns("/rooms/(?P[^/]*)/members$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -424,10 +404,9 @@ class RoomStateRestServlet(RestServlet): Get full state within a room. 
""" - PATTERNS = admin_patterns("/rooms/(?P[^/]+)/state") + PATTERNS = admin_patterns("/rooms/(?P[^/]*)/state$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() self.clock = hs.get_clock() @@ -436,8 +415,7 @@ def __init__(self, hs: "HomeServer"): async def on_GET( self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) ret = await self.store.get_room(room_id) if not ret: @@ -454,14 +432,14 @@ async def on_GET( class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet): - PATTERNS = admin_patterns("/join/(?P[^/]*)") + PATTERNS = admin_patterns("/join/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.hs = hs self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() self.state_handler = hs.get_state_handler() + self.is_mine = hs.is_mine async def on_POST( self, request: SynapseRequest, room_identifier: str @@ -477,7 +455,7 @@ async def on_POST( assert_params_in_dict(content, ["user_id"]) target_user = UserID.from_string(content["user_id"]) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError( HTTPStatus.BAD_REQUEST, "This endpoint can only be used with local users", @@ -542,11 +520,10 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): } """ - PATTERNS = admin_patterns("/rooms/(?P[^/]*)/make_room_admin") + PATTERNS = admin_patterns("/rooms/(?P[^/]*)/make_room_admin$") def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() self.event_creation_handler = hs.get_event_creation_handler() @@ -688,19 +665,17 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet): GET /_synapse/admin/v1/rooms//forward_extremities """ - PATTERNS = admin_patterns("/rooms/(?P[^/]*)/forward_extremities") + PATTERNS = admin_patterns("/rooms/(?P[^/]*)/forward_extremities$") def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() async def on_DELETE( self, request: SynapseRequest, room_identifier: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) room_id, _ = await self.resolve_room_id(room_identifier) @@ -710,8 +685,7 @@ async def on_DELETE( async def on_GET( self, request: SynapseRequest, room_identifier: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) room_id, _ = await self.resolve_room_id(room_identifier) @@ -793,7 +767,7 @@ class BlockRoomRestServlet(RestServlet): On GET: Get blocking status of room and user who has blocked this room. 
""" - PATTERNS = admin_patterns("/rooms/(?P[^/]+)/block$") + PATTERNS = admin_patterns("/rooms/(?P[^/]*)/block$") def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py index b295fb078bc7..15da9cd88153 100644 --- a/synapse/rest/admin/server_notice_servlet.py +++ b/synapse/rest/admin/server_notice_servlet.py @@ -52,11 +52,11 @@ class SendServerNoticeServlet(RestServlet): """ def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.server_notices_manager = hs.get_server_notices_manager() self.admin_handler = hs.get_admin_handler() self.txns = HttpTransactionCache(hs) + self.is_mine = hs.is_mine def register(self, json_resource: HttpServer) -> None: PATTERN = "/send_server_notice" @@ -88,7 +88,7 @@ async def on_POST( ) target_user = UserID.from_string(body["user_id"]) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError( HTTPStatus.BAD_REQUEST, "Server notices can only be sent to local users" ) diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py index ca41fd45f2bd..7a6546372eef 100644 --- a/synapse/rest/admin/statistics.py +++ b/synapse/rest/admin/statistics.py @@ -37,7 +37,6 @@ class UserMediaStatisticsRestServlet(RestServlet): PATTERNS = admin_patterns("/statistics/users/media$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -45,19 +44,16 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) order_by = parse_string( - request, "order_by", default=UserSortOrder.USER_ID.value + request, + "order_by", + default=UserSortOrder.USER_ID.value, + allowed_values=( + UserSortOrder.MEDIA_LENGTH.value, + UserSortOrder.MEDIA_COUNT.value, + UserSortOrder.USER_ID.value, + UserSortOrder.DISPLAYNAME.value, + ), ) - if order_by not in ( - UserSortOrder.MEDIA_LENGTH.value, - UserSortOrder.MEDIA_COUNT.value, - UserSortOrder.USER_ID.value, - UserSortOrder.DISPLAYNAME.value, - ): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Unknown value for order_by: %s" % (order_by,), - errcode=Codes.INVALID_PARAM, - ) start = parse_integer(request, "from", default=0) if start < 0: diff --git a/synapse/rest/admin/username_available.py b/synapse/rest/admin/username_available.py index 2bf1472967dd..5353dc368235 100644 --- a/synapse/rest/admin/username_available.py +++ b/synapse/rest/admin/username_available.py @@ -37,7 +37,7 @@ class UsernameAvailableRestServlet(RestServlet): } """ - PATTERNS = admin_patterns("/username_available") + PATTERNS = admin_patterns("/username_available$") def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 2a60b602b1f8..db678da4cf14 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -66,7 +66,6 @@ class UsersRestServletV2(RestServlet): """ def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() @@ -126,7 +125,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: class UserRestServletV2(RestServlet): - PATTERNS = admin_patterns("/users/(?P[^/]+)$", "v2") + PATTERNS = admin_patterns("/users/(?P[^/]*)$", "v2") """Get request to list user details. This needs user to have administrator access in Synapse. 
@@ -414,7 +413,7 @@ class UserRegisterServlet(RestServlet): nonce to the time it was generated, in int seconds. """ - PATTERNS = admin_patterns("/register") + PATTERNS = admin_patterns("/register$") NONCE_TIMEOUT = 60 def __init__(self, hs: "HomeServer"): @@ -561,9 +560,9 @@ class WhoisRestServlet(RestServlet): ] def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() + self.is_mine = hs.is_mine async def on_GET( self, request: SynapseRequest, user_id: str @@ -575,7 +574,7 @@ async def on_GET( if target_user != auth_user: await assert_user_is_admin(self.auth, auth_user) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only whois a local user") ret = await self.admin_handler.get_whois(target_user) @@ -584,7 +583,7 @@ async def on_GET( class DeactivateAccountRestServlet(RestServlet): - PATTERNS = admin_patterns("/deactivate/(?P[^/]*)") + PATTERNS = admin_patterns("/deactivate/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): self._deactivate_account_handler = hs.get_deactivate_account_handler() @@ -630,7 +629,6 @@ class AccountValidityRenewServlet(RestServlet): PATTERNS = admin_patterns("/account_validity/validity$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.account_activity_handler = hs.get_account_validity_handler() self.auth = hs.get_auth() @@ -674,11 +672,10 @@ class ResetPasswordRestServlet(RestServlet): 200 OK with empty object if success otherwise an error. """ - PATTERNS = admin_patterns("/reset_password/(?P[^/]*)") + PATTERNS = admin_patterns("/reset_password/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() - self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() self._set_password_handler = hs.get_set_password_handler() @@ -718,12 +715,12 @@ class SearchUsersRestServlet(RestServlet): 200 OK with json object {list[dict[str, Any]], count} or empty object. 
""" - PATTERNS = admin_patterns("/search_users/(?P[^/]*)") + PATTERNS = admin_patterns("/search_users/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() + self.is_mine = hs.is_mine async def on_GET( self, request: SynapseRequest, target_user_id: str @@ -740,7 +737,7 @@ async def on_GET( # if not is_admin and target_user != auth_user: # raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin") - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only users a local user") term = parse_string(request, "term", required=True) @@ -779,9 +776,9 @@ class UserAdminServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P[^/]*)/admin$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() + self.is_mine = hs.is_mine async def on_GET( self, request: SynapseRequest, user_id: str @@ -790,7 +787,7 @@ async def on_GET( target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError( HTTPStatus.BAD_REQUEST, "Only local users can be admins of this homeserver", @@ -813,7 +810,7 @@ async def on_PUT( assert_params_in_dict(body, ["admin"]) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError( HTTPStatus.BAD_REQUEST, "Only local users can be admins of this homeserver", @@ -834,7 +831,7 @@ class UserMembershipRestServlet(RestServlet): Get room list of an user. """ - PATTERNS = admin_patterns("/users/(?P[^/]+)/joined_rooms$") + PATTERNS = admin_patterns("/users/(?P[^/]*)/joined_rooms$") def __init__(self, hs: "HomeServer"): self.is_mine = hs.is_mine @@ -909,10 +906,10 @@ class UserTokenRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P[^/]*)/login$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() + self.is_mine_id = hs.is_mine_id async def on_POST( self, request: SynapseRequest, user_id: str @@ -921,7 +918,7 @@ async def on_POST( await assert_user_is_admin(self.auth, requester.user) auth_user = requester.user - if not self.hs.is_mine_id(user_id): + if not self.is_mine_id(user_id): raise SynapseError( HTTPStatus.BAD_REQUEST, "Only local users can be logged in as" ) @@ -975,19 +972,19 @@ class ShadowBanRestServlet(RestServlet): {} """ - PATTERNS = admin_patterns("/users/(?P[^/]*)/shadow_ban") + PATTERNS = admin_patterns("/users/(?P[^/]*)/shadow_ban$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() + self.is_mine_id = hs.is_mine_id async def on_POST( self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - if not self.hs.is_mine_id(user_id): + if not self.is_mine_id(user_id): raise SynapseError( HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned" ) @@ -1001,7 +998,7 @@ async def on_DELETE( ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - if not self.hs.is_mine_id(user_id): + if not self.is_mine_id(user_id): raise SynapseError( HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned" ) @@ -1027,19 +1024,19 @@ class RateLimitRestServlet(RestServlet): } """ - PATTERNS = admin_patterns("/users/(?P[^/]*)/override_ratelimit") + PATTERNS = admin_patterns("/users/(?P[^/]*)/override_ratelimit$") def 
__init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() + self.is_mine_id = hs.is_mine_id async def on_GET( self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - if not self.hs.is_mine_id(user_id): + if not self.is_mine_id(user_id): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users") if not await self.store.get_user_by_id(user_id): @@ -1068,7 +1065,7 @@ async def on_POST( ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - if not self.hs.is_mine_id(user_id): + if not self.is_mine_id(user_id): raise SynapseError( HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited" ) @@ -1113,7 +1110,7 @@ async def on_DELETE( ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - if not self.hs.is_mine_id(user_id): + if not self.is_mine_id(user_id): raise SynapseError( HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited" ) diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py index 7cb8ec57bad9..f6e85fdaadcd 100644 --- a/tests/rest/admin/test_statistics.py +++ b/tests/rest/admin/test_statistics.py @@ -92,7 +92,7 @@ def test_invalid_parameter(self) -> None: channel.code, msg=channel.json_body, ) - self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # negative from channel = self.make_request( From d93362d87fbbf4941da06ade65eaf5df1672bccb Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 8 Dec 2021 12:26:29 -0500 Subject: [PATCH 008/157] Add a constant for receipt types (m.read). (#11531) And expand some type hints in the receipts storage module. --- changelog.d/11531.misc | 1 + synapse/api/constants.py | 4 + synapse/handlers/receipts.py | 6 +- synapse/handlers/sync.py | 4 +- synapse/push/push_tools.py | 3 +- synapse/rest/client/notifications.py | 3 +- synapse/rest/client/read_marker.py | 6 +- synapse/rest/client/receipts.py | 4 +- synapse/storage/databases/main/receipts.py | 101 ++++++++++++++------- 9 files changed, 87 insertions(+), 45 deletions(-) create mode 100644 changelog.d/11531.misc diff --git a/changelog.d/11531.misc b/changelog.d/11531.misc new file mode 100644 index 000000000000..ed6ef3bb3e56 --- /dev/null +++ b/changelog.d/11531.misc @@ -0,0 +1 @@ +Add a receipt types constant for `m.read`. 
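The hunks that follow replace scattered `"m.read"` string literals with a shared constant. A minimal sketch of the pattern; the `ReceiptTypes` class matches the `constants.py` hunk below, while the helper function is a hypothetical illustration rather than part of the patch:

```python
from typing import Final


class ReceiptTypes:
    READ: Final = "m.read"


def is_read_receipt(receipt_type: str) -> bool:
    # Comparing against the constant means a typo surfaces as an
    # AttributeError at import time rather than a silently-false match.
    return receipt_type == ReceiptTypes.READ


assert is_read_receipt("m.read")
```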
diff --git a/synapse/api/constants.py b/synapse/api/constants.py index f7d29b431936..52c083a20b9c 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -253,5 +253,9 @@ class GuestAccess: FORBIDDEN: Final = "forbidden" +class ReceiptTypes: + READ: Final = "m.read" + + class ReadReceiptEventFields: MSC2285_HIDDEN: Final = "org.matrix.msc2285.hidden" diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 4911a1153519..5cb1ff749d92 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -14,7 +14,7 @@ import logging from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple -from synapse.api.constants import ReadReceiptEventFields +from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes from synapse.appservice import ApplicationService from synapse.streams import EventSource from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id @@ -178,7 +178,7 @@ def filter_out_hidden(events: List[JsonDict], user_id: str) -> List[JsonDict]: for event_id in content.keys(): event_content = content.get(event_id, {}) - m_read = event_content.get("m.read", {}) + m_read = event_content.get(ReceiptTypes.READ, {}) # If m_read is missing copy over the original event_content as there is nothing to process here if not m_read: @@ -206,7 +206,7 @@ def filter_out_hidden(events: List[JsonDict], user_id: str) -> List[JsonDict]: # Set new users unless empty if len(new_users.keys()) > 0: - new_event["content"][event_id] = {"m.read": new_users} + new_event["content"][event_id] = {ReceiptTypes.READ: new_users} # Append new_event to visible_events unless empty if len(new_event["content"].keys()) > 0: diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index f3039c3c3fb7..96f37e9f4204 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -28,7 +28,7 @@ import attr from prometheus_client import Counter -from synapse.api.constants import AccountDataTypes, EventTypes, Membership +from synapse.api.constants import AccountDataTypes, EventTypes, Membership, ReceiptTypes from synapse.api.filtering import FilterCollection from synapse.api.presence import UserPresenceState from synapse.api.room_versions import KNOWN_ROOM_VERSIONS @@ -1046,7 +1046,7 @@ async def unread_notifs_for_room_id( last_unread_event_id = await self.store.get_last_receipt_event_id_for_user( user_id=sync_config.user.to_string(), room_id=room_id, - receipt_type="m.read", + receipt_type=ReceiptTypes.READ, ) notifs = await self.store.get_unread_event_push_actions_by_room_for_user( diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index 9c85200c0fb4..da641aca477c 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -13,6 +13,7 @@ # limitations under the License. 
from typing import Dict +from synapse.api.constants import ReceiptTypes from synapse.events import EventBase from synapse.push.presentable_names import calculate_room_name, name_from_member_event from synapse.storage import Storage @@ -23,7 +24,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) - invites = await store.get_invited_rooms_for_local_user(user_id) joins = await store.get_rooms_for_user(user_id) - my_receipts_by_room = await store.get_receipts_for_user(user_id, "m.read") + my_receipts_by_room = await store.get_receipts_for_user(user_id, ReceiptTypes.READ) badge = len(invites) diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py index d1d8a984c630..b12a332776e4 100644 --- a/synapse/rest/client/notifications.py +++ b/synapse/rest/client/notifications.py @@ -15,6 +15,7 @@ import logging from typing import TYPE_CHECKING, Tuple +from synapse.api.constants import ReceiptTypes from synapse.events.utils import format_event_for_client_v2_without_room_id from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_integer, parse_string @@ -54,7 +55,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: ) receipts_by_room = await self.store.get_receipts_for_user_with_orderings( - user_id, "m.read" + user_id, ReceiptTypes.READ ) notif_event_ids = [pa["event_id"] for pa in push_actions] diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py index 43c04fac6fdb..f51be511d1f4 100644 --- a/synapse/rest/client/read_marker.py +++ b/synapse/rest/client/read_marker.py @@ -15,7 +15,7 @@ import logging from typing import TYPE_CHECKING, Tuple -from synapse.api.constants import ReadReceiptEventFields +from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes from synapse.api.errors import Codes, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -48,7 +48,7 @@ async def on_POST( await self.presence_handler.bump_presence_active_time(requester.user) body = parse_json_object_from_request(request) - read_event_id = body.get("m.read", None) + read_event_id = body.get(ReceiptTypes.READ, None) hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False) if not isinstance(hidden, bool): @@ -62,7 +62,7 @@ async def on_POST( if read_event_id: await self.receipts_handler.received_client_receipt( room_id, - "m.read", + ReceiptTypes.READ, user_id=requester.user.to_string(), event_id=read_event_id, hidden=hidden, diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index 2b25b9aad6a3..b24ad2d1be13 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -16,7 +16,7 @@ import re from typing import TYPE_CHECKING, Tuple -from synapse.api.constants import ReadReceiptEventFields +from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes from synapse.api.errors import Codes, SynapseError from synapse.http import get_request_user_agent from synapse.http.server import HttpServer @@ -53,7 +53,7 @@ async def on_POST( ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - if receipt_type != "m.read": + if receipt_type != ReceiptTypes.READ: raise SynapseError(400, "Receipt type must be 'm.read'") # Do not allow older SchildiChat and Element Android clients (prior to Element/1.[012].x) to send an empty body. 
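The storage diff that follows is largely about tightening type hints on transaction callbacks. A self-contained sketch of the convention it adopts; the stub class, table name, and query here are illustrative stand-ins, not Synapse's real implementation:

```python
from typing import Any, Dict, List


class LoggingTransaction:
    """Minimal stand-in for synapse.storage.database.LoggingTransaction."""

    def execute(self, sql: str, args: tuple = ()) -> None:
        ...

    def fetchall(self) -> List[tuple]:
        return []


# The convention: inner transaction functions take a LoggingTransaction and
# declare an explicit return type, instead of the untyped `def f(txn):`.
def get_receipts_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]:
    txn.execute("SELECT room_id, event_id FROM receipts_linearized")
    return [{"room_id": row[0], "event_id": row[1]} for row in txn.fetchall()]
```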
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index c99f8aebdbdd..9c5625c8bbb8 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -14,14 +14,25 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Collection, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, +) from twisted.internet import defer +from synapse.api.constants import ReceiptTypes from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import ReceiptsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.types import JsonDict @@ -78,17 +89,13 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id() ) - def get_max_receipt_stream_id(self): - """Get the current max stream ID for receipts stream - - Returns: - int - """ + def get_max_receipt_stream_id(self) -> int: + """Get the current max stream ID for receipts stream""" return self._receipts_id_gen.get_current_token() @cached() - async def get_users_with_read_receipts_in_room(self, room_id): - receipts = await self.get_receipts_for_room(room_id, "m.read") + async def get_users_with_read_receipts_in_room(self, room_id: str) -> Set[str]: + receipts = await self.get_receipts_for_room(room_id, ReceiptTypes.READ) return {r["user_id"] for r in receipts} @cached(num_args=2) @@ -119,7 +126,9 @@ async def get_last_receipt_event_id_for_user( ) @cached(num_args=2) - async def get_receipts_for_user(self, user_id, receipt_type): + async def get_receipts_for_user( + self, user_id: str, receipt_type: str + ) -> Dict[str, str]: rows = await self.db_pool.simple_select_list( table="receipts_linearized", keyvalues={"user_id": user_id, "receipt_type": receipt_type}, @@ -129,8 +138,10 @@ async def get_receipts_for_user(self, user_id, receipt_type): return {row["room_id"]: row["event_id"] for row in rows} - async def get_receipts_for_user_with_orderings(self, user_id, receipt_type): - def f(txn): + async def get_receipts_for_user_with_orderings( + self, user_id: str, receipt_type: str + ) -> JsonDict: + def f(txn: LoggingTransaction) -> List[Tuple[str, str, int, int]]: sql = ( "SELECT rl.room_id, rl.event_id," " e.topological_ordering, e.stream_ordering" @@ -209,10 +220,10 @@ async def get_linearized_receipts_for_room( @cached(num_args=3, tree=True) async def _get_linearized_receipts_for_room( self, room_id: str, to_key: int, from_key: Optional[int] = None - ) -> List[dict]: + ) -> List[JsonDict]: """See get_linearized_receipts_for_room""" - def f(txn): + def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: if from_key: sql = ( "SELECT * FROM receipts_linearized WHERE" @@ -250,11 +261,13 @@ def f(txn): list_name="room_ids", num_args=3, ) - async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): + async def _get_linearized_receipts_for_rooms( + self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None + ) -> Dict[str, List[JsonDict]]: if not room_ids: return {} - def 
f(txn): + def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: if from_key: sql = """ SELECT * FROM receipts_linearized WHERE @@ -323,7 +336,7 @@ async def get_linearized_receipts_for_all_rooms( A dictionary of roomids to a list of receipts. """ - def f(txn): + def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: if from_key: sql = """ SELECT * FROM receipts_linearized WHERE @@ -379,7 +392,7 @@ async def get_users_sent_receipts_between( if last_id == current_id: return defer.succeed([]) - def _get_users_sent_receipts_between_txn(txn): + def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]: sql = """ SELECT DISTINCT user_id FROM receipts_linearized WHERE ? < stream_id AND stream_id <= ? @@ -419,7 +432,9 @@ async def get_all_updated_receipts( if last_id == current_id: return [], current_id, False - def get_all_updated_receipts_txn(txn): + def get_all_updated_receipts_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[int, list]], int, bool]: sql = """ SELECT stream_id, room_id, receipt_type, user_id, event_id, data FROM receipts_linearized @@ -446,8 +461,8 @@ def get_all_updated_receipts_txn(txn): def _invalidate_get_users_with_receipts_in_room( self, room_id: str, receipt_type: str, user_id: str - ): - if receipt_type != "m.read": + ) -> None: + if receipt_type != ReceiptTypes.READ: return res = self.get_users_with_read_receipts_in_room.cache.get_immediate( @@ -461,7 +476,9 @@ def _invalidate_get_users_with_receipts_in_room( self.get_users_with_read_receipts_in_room.invalidate((room_id,)) - def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): + def invalidate_caches_for_receipt( + self, room_id: str, receipt_type: str, user_id: str + ) -> None: self.get_receipts_for_user.invalidate((user_id, receipt_type)) self._get_linearized_receipts_for_room.invalidate((room_id,)) self.get_last_receipt_event_id_for_user.invalidate( @@ -482,11 +499,18 @@ def process_replication_rows(self, stream_name, instance_name, token, rows): return super().process_replication_rows(stream_name, instance_name, token, rows) def insert_linearized_receipt_txn( - self, txn, room_id, receipt_type, user_id, event_id, data, stream_id - ): + self, + txn: LoggingTransaction, + room_id: str, + receipt_type: str, + user_id: str, + event_id: str, + data: JsonDict, + stream_id: int, + ) -> Optional[int]: """Inserts a read-receipt into the database if it's newer than the current RR - Returns: int|None + Returns: None if the RR is older than the current RR otherwise, the rx timestamp of the event that the RR corresponds to (or 0 if the event is unknown) @@ -550,7 +574,7 @@ def insert_linearized_receipt_txn( lock=False, ) - if receipt_type == "m.read" and stream_ordering is not None: + if receipt_type == ReceiptTypes.READ and stream_ordering is not None: self._remove_old_push_actions_before_txn( txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering ) @@ -580,7 +604,7 @@ async def insert_receipt( else: # we need to points in graph -> linearized form. # TODO: Make this better. 
- def graph_to_linear(txn): + def graph_to_linear(txn: LoggingTransaction) -> str: clause, args = make_in_list_sql_clause( self.database_engine, "event_id", event_ids ) @@ -634,11 +658,16 @@ def graph_to_linear(txn): return stream_id, max_persisted_id async def insert_graph_receipt( - self, room_id, receipt_type, user_id, event_ids, data - ): + self, + room_id: str, + receipt_type: str, + user_id: str, + event_ids: List[str], + data: JsonDict, + ) -> None: assert self._can_write_to_receipts - return await self.db_pool.runInteraction( + await self.db_pool.runInteraction( "insert_graph_receipt", self.insert_graph_receipt_txn, room_id, @@ -649,8 +678,14 @@ async def insert_graph_receipt( ) def insert_graph_receipt_txn( - self, txn, room_id, receipt_type, user_id, event_ids, data - ): + self, + txn: LoggingTransaction, + room_id: str, + receipt_type: str, + user_id: str, + event_ids: List[str], + data: JsonDict, + ) -> None: assert self._can_write_to_receipts txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type)) From afa0a5e4fc4e6badd6f3bc8393fe4d6abb2d834c Mon Sep 17 00:00:00 2001 From: Robert Long Date: Thu, 9 Dec 2021 03:02:05 -0800 Subject: [PATCH 009/157] Allow guests to send state events (#11378) --- changelog.d/11378.feature | 1 + synapse/rest/client/room.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/11378.feature diff --git a/changelog.d/11378.feature b/changelog.d/11378.feature new file mode 100644 index 000000000000..524bf84f323e --- /dev/null +++ b/changelog.d/11378.feature @@ -0,0 +1 @@ +Allow guests to send state events per [MSC3419](https://github.com/matrix-org/matrix-doc/pull/3419). \ No newline at end of file diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index f48e2e6ca248..60719331b640 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -187,7 +187,7 @@ async def on_PUT( state_key: str, txn_id: Optional[str] = None, ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request, allow_guest=True) if txn_id: set_tag("txn_id", txn_id) From b3bcacf3c1c72bfadeb46fe4d0198ca155a8c615 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 9 Dec 2021 12:23:34 +0100 Subject: [PATCH 010/157] Add missing `errcode` to `parse_string` and `parse_boolean` (#11542) --- changelog.d/11542.misc | 1 + synapse/http/servlet.py | 4 ++-- tests/rest/admin/test_federation.py | 4 ++-- tests/rest/admin/test_media.py | 2 +- tests/rest/admin/test_statistics.py | 2 +- tests/rest/admin/test_user.py | 12 ++++++------ 6 files changed, 13 insertions(+), 12 deletions(-) create mode 100644 changelog.d/11542.misc diff --git a/changelog.d/11542.misc b/changelog.d/11542.misc new file mode 100644 index 000000000000..f61416503766 --- /dev/null +++ b/changelog.d/11542.misc @@ -0,0 +1 @@ +Add missing `errcode` to `parse_string` and `parse_boolean`. 
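The `servlet.py` change below is small but affects every endpoint that validates query parameters. A condensed sketch of the error path it fixes, not the full `parse_string` implementation:

```python
from synapse.api.errors import Codes, SynapseError


def check_allowed(name: str, value: str, allowed_values: list) -> str:
    """Condensed sketch: a disallowed value now raises with
    Codes.INVALID_PARAM, so clients receive "errcode": "M_INVALID_PARAM"
    instead of the default "M_UNKNOWN" (which the tests below assert)."""
    if value not in allowed_values:
        message = "Query parameter %r must be one of [%s]" % (
            name,
            ", ".join(repr(v) for v in allowed_values),
        )
        raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
    return value
```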
\ No newline at end of file diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 6dd9b9ad0358..1627225f286d 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -246,7 +246,7 @@ def parse_boolean_from_args( message = ( "Boolean query parameter %r must be one of ['true', 'false']" ) % (name,) - raise SynapseError(400, message) + raise SynapseError(400, message, errcode=Codes.INVALID_PARAM) else: if required: message = "Missing boolean query parameter %r" % (name,) @@ -414,7 +414,7 @@ def _parse_string_value( name, ", ".join(repr(v) for v in allowed_values), ) - raise SynapseError(400, message) + raise SynapseError(400, message, errcode=Codes.INVALID_PARAM) else: return value_str diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py index 5188499ef2d6..d1cd5b075157 100644 --- a/tests/rest/admin/test_federation.py +++ b/tests/rest/admin/test_federation.py @@ -95,7 +95,7 @@ def test_invalid_parameter(self): ) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) - self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid search order channel = self.make_request( @@ -105,7 +105,7 @@ def test_invalid_parameter(self): ) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) - self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid destination channel = self.make_request( diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index 81e578fd26c1..3f727788cec8 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -360,7 +360,7 @@ def test_invalid_parameter(self) -> None: channel.code, msg=channel.json_body, ) - self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "Boolean query parameter 'keep_profiles' must be one of ['true', 'false']", channel.json_body["error"], diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py index f6e85fdaadcd..7cb8ec57bad9 100644 --- a/tests/rest/admin/test_statistics.py +++ b/tests/rest/admin/test_statistics.py @@ -92,7 +92,7 @@ def test_invalid_parameter(self) -> None: channel.code, msg=channel.json_body, ) - self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from channel = self.make_request( diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 4fedd5fd0851..294d429ce12b 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -608,7 +608,7 @@ def test_invalid_parameter(self): ) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) - self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid deactivated channel = self.make_request( @@ -618,7 +618,7 @@ def test_invalid_parameter(self): ) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) - self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # unkown order_by channel = self.make_request( @@ -628,7 +628,7 @@ def test_invalid_parameter(self): ) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, 
msg=channel.json_body) - self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid search order channel = self.make_request( @@ -638,7 +638,7 @@ def test_invalid_parameter(self): ) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) - self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) def test_limit(self): """ @@ -2896,7 +2896,7 @@ def test_invalid_parameter(self, method: str): ) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) - self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid search order channel = self.make_request( @@ -2906,7 +2906,7 @@ def test_invalid_parameter(self, method: str): ) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) - self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative limit channel = self.make_request( From b47d10dc46e4644c432f017d5b2129ff7a349166 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Thu, 9 Dec 2021 06:41:27 -0500 Subject: [PATCH 011/157] Support unprefixed versions of fallback key property names. (#11541) --- changelog.d/11541.misc | 1 + synapse/handlers/e2e_keys.py | 4 +++- synapse/rest/client/sync.py | 3 +++ tests/handlers/test_e2e_keys.py | 30 +++++++++++++++++++++++++----- 4 files changed, 32 insertions(+), 6 deletions(-) create mode 100644 changelog.d/11541.misc diff --git a/changelog.d/11541.misc b/changelog.d/11541.misc new file mode 100644 index 000000000000..31c72c2a20d0 --- /dev/null +++ b/changelog.d/11541.misc @@ -0,0 +1 @@ +Support unprefixed versions of fallback key property names. 
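Before the diff below, the shape of this change in brief: the handler now prefers the stable `fallback_keys` property name and only falls back to the MSC2732 unstable prefix. A small sketch with illustrative key data:

    # `keys` stands in for the body of a /keys/upload request.
    keys = {"org.matrix.msc2732.fallback_keys": {"alg1:k1": "fallback_key1"}}

    # The same lookup the handler now performs: stable name first, then the
    # unstable prefix for older clients.
    fallback_keys = keys.get("fallback_keys") or keys.get(
        "org.matrix.msc2732.fallback_keys"
    )
    assert fallback_keys == {"alg1:k1": "fallback_key1"}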
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index b2554bda045a..14360b4e4045 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -580,7 +580,9 @@ async def upload_keys_for_user( log_kv( {"message": "Did not update one_time_keys", "reason": "no keys given"} ) - fallback_keys = keys.get("org.matrix.msc2732.fallback_keys", None) + fallback_keys = keys.get("fallback_keys") or keys.get( + "org.matrix.msc2732.fallback_keys" + ) if fallback_keys and isinstance(fallback_keys, dict): log_kv( { diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 88e4f5e0630f..dd90ffa12397 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -293,6 +293,9 @@ async def encode_response( response[ "org.matrix.msc2732.device_unused_fallback_key_types" ] = sync_result.device_unused_fallback_key_types + response[ + "device_unused_fallback_key_types" + ] = sync_result.device_unused_fallback_key_types if joined: response["rooms"][Membership.JOIN] = joined diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index f0723892e416..ddcf3ee34886 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -161,8 +161,9 @@ def test_claim_one_time_key(self): def test_fallback_key(self): local_user = "@boris:" + self.hs.hostname device_id = "xyz" - fallback_key = {"alg1:k1": "key1"} - fallback_key2 = {"alg1:k2": "key2"} + fallback_key = {"alg1:k1": "fallback_key1"} + fallback_key2 = {"alg1:k2": "fallback_key2"} + fallback_key3 = {"alg1:k2": "fallback_key3"} otk = {"alg1:k2": "key2"} # we shouldn't have any unused fallback keys yet @@ -175,7 +176,7 @@ def test_fallback_key(self): self.handler.upload_keys_for_user( local_user, device_id, - {"org.matrix.msc2732.fallback_keys": fallback_key}, + {"fallback_keys": fallback_key}, ) ) @@ -220,7 +221,7 @@ def test_fallback_key(self): self.handler.upload_keys_for_user( local_user, device_id, - {"org.matrix.msc2732.fallback_keys": fallback_key}, + {"fallback_keys": fallback_key}, ) ) @@ -234,7 +235,7 @@ def test_fallback_key(self): self.handler.upload_keys_for_user( local_user, device_id, - {"org.matrix.msc2732.fallback_keys": fallback_key2}, + {"fallback_keys": fallback_key2}, ) ) @@ -271,6 +272,25 @@ def test_fallback_key(self): {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key2}}}, ) + # using the unstable prefix should also set the fallback key + self.get_success( + self.handler.upload_keys_for_user( + local_user, + device_id, + {"org.matrix.msc2732.fallback_keys": fallback_key3}, + ) + ) + + res = self.get_success( + self.handler.claim_one_time_keys( + {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + ) + ) + self.assertEqual( + res, + {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}}, + ) + def test_replace_master_key(self): """uploading a new signing key should make the old signing key unavailable""" local_user = "@boris:" + self.hs.hostname From 941ebe49ffc32c6d67b487094a6f8e1c290e93bc Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 9 Dec 2021 12:58:25 +0100 Subject: [PATCH 012/157] Use HTTPStatus constants in place of literals in `synapse.http` (#11543) --- changelog.d/11543.misc | 1 + synapse/http/client.py | 15 +++++--- synapse/http/matrixfederationclient.py | 3 +- synapse/http/servlet.py | 47 ++++++++++++++++++-------- 4 files changed, 47 insertions(+), 19 deletions(-) create mode 100644 changelog.d/11543.misc 
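One point worth noting about this refactor: `http.HTTPStatus` members are `IntEnum` values, so they compare equal to the bare integers they replace and the change is behaviourally neutral. A quick self-contained check:

    from http import HTTPStatus

    # Each named constant is interchangeable with the literal it replaces.
    assert HTTPStatus.BAD_REQUEST == 400
    assert HTTPStatus.FORBIDDEN == 403
    assert HTTPStatus.BAD_GATEWAY == 502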
diff --git a/changelog.d/11543.misc b/changelog.d/11543.misc new file mode 100644 index 000000000000..99817d71a433 --- /dev/null +++ b/changelog.d/11543.misc @@ -0,0 +1 @@ +Use HTTPStatus constants in place of literals in `synapse.http`. \ No newline at end of file diff --git a/synapse/http/client.py b/synapse/http/client.py index b5a2d333a6ce..fbbeceabeb6a 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -14,6 +14,7 @@ # limitations under the License. import logging import urllib.parse +from http import HTTPStatus from io import BytesIO from typing import ( TYPE_CHECKING, @@ -280,7 +281,9 @@ def request( ip_address, self._ip_whitelist, self._ip_blacklist ): logger.info("Blocking access to %s due to blacklist" % (ip_address,)) - e = SynapseError(403, "IP address blocked by IP blacklist entry") + e = SynapseError( + HTTPStatus.FORBIDDEN, "IP address blocked by IP blacklist entry" + ) return defer.fail(Failure(e)) return self._agent.request( @@ -719,7 +722,9 @@ async def get_file( if response.code > 299: logger.warning("Got %d when downloading %s" % (response.code, url)) - raise SynapseError(502, "Got error %d" % (response.code,), Codes.UNKNOWN) + raise SynapseError( + HTTPStatus.BAD_GATEWAY, "Got error %d" % (response.code,), Codes.UNKNOWN + ) # TODO: if our Content-Type is HTML or something, just read the first # N bytes into RAM rather than saving it all to disk only to read it @@ -731,12 +736,14 @@ async def get_file( ) except BodyExceededMaxSize: raise SynapseError( - 502, + HTTPStatus.BAD_GATEWAY, "Requested file is too large > %r bytes" % (max_size,), Codes.TOO_LARGE, ) except Exception as e: - raise SynapseError(502, ("Failed to download remote body: %s" % e)) from e + raise SynapseError( + HTTPStatus.BAD_GATEWAY, ("Failed to download remote body: %s" % e) + ) from e return ( length, diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 203d723d4120..deedde0b5b37 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -19,6 +19,7 @@ import sys import typing import urllib.parse +from http import HTTPStatus from io import BytesIO, StringIO from typing import ( TYPE_CHECKING, @@ -1154,7 +1155,7 @@ async def get_file( request.destination, msg, ) - raise SynapseError(502, msg, Codes.TOO_LARGE) + raise SynapseError(HTTPStatus.BAD_GATEWAY, msg, Codes.TOO_LARGE) except defer.TimeoutError as e: logger.warning( "{%s} [%s] Timed out reading response - %s %s", diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 1627225f286d..e543cc6e01e8 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -14,6 +14,7 @@ """ This module contains base REST classes for constructing REST servlets. 
""" import logging +from http import HTTPStatus from typing import ( TYPE_CHECKING, Iterable, @@ -137,11 +138,15 @@ def parse_integer_from_args( return int(args[name_bytes][0]) except Exception: message = "Query parameter %r must be an integer" % (name,) - raise SynapseError(400, message, errcode=Codes.INVALID_PARAM) + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM + ) else: if required: message = "Missing integer query parameter %r" % (name,) - raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM + ) else: return default @@ -246,11 +251,15 @@ def parse_boolean_from_args( message = ( "Boolean query parameter %r must be one of ['true', 'false']" ) % (name,) - raise SynapseError(400, message, errcode=Codes.INVALID_PARAM) + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM + ) else: if required: message = "Missing boolean query parameter %r" % (name,) - raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM + ) else: return default @@ -313,7 +322,7 @@ def parse_bytes_from_args( return args[name_bytes][0] elif required: message = "Missing string query parameter %s" % (name,) - raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) + raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM) return default @@ -407,14 +416,16 @@ def _parse_string_value( try: value_str = value.decode(encoding) except ValueError: - raise SynapseError(400, "Query parameter %r must be %s" % (name, encoding)) + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Query parameter %r must be %s" % (name, encoding) + ) if allowed_values is not None and value_str not in allowed_values: message = "Query parameter %r must be one of [%s]" % ( name, ", ".join(repr(v) for v in allowed_values), ) - raise SynapseError(400, message, errcode=Codes.INVALID_PARAM) + raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM) else: return value_str @@ -510,7 +521,9 @@ def parse_strings_from_args( else: if required: message = "Missing string query parameter %r" % (name,) - raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM + ) return default @@ -638,7 +651,7 @@ def parse_json_value_from_request( try: content_bytes = request.content.read() # type: ignore except Exception: - raise SynapseError(400, "Error reading JSON content.") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Error reading JSON content.") if not content_bytes and allow_empty_body: return None @@ -647,7 +660,9 @@ def parse_json_value_from_request( content = json_decoder.decode(content_bytes.decode("utf-8")) except Exception as e: logger.warning("Unable to parse JSON: %s (%s)", e, content_bytes) - raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Content not JSON.", errcode=Codes.NOT_JSON + ) return content @@ -673,7 +688,7 @@ def parse_json_object_from_request( if not isinstance(content, dict): message = "Content must be a JSON object." 
- raise SynapseError(400, message, errcode=Codes.BAD_JSON) + raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.BAD_JSON) return content @@ -685,7 +700,9 @@ def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None: absent.append(k) if len(absent) > 0: - raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Missing params: %r" % absent, Codes.MISSING_PARAM + ) class RestServlet: @@ -758,10 +775,12 @@ async def resolve_room_id( resolved_room_id = room_id.to_string() else: raise SynapseError( - 400, "%s was not legal room ID or room alias" % (room_identifier,) + HTTPStatus.BAD_REQUEST, + "%s was not legal room ID or room alias" % (room_identifier,), ) if not resolved_room_id: raise SynapseError( - 400, "Unknown room ID or room alias %s" % room_identifier + HTTPStatus.BAD_REQUEST, + "Unknown room ID or room alias %s" % room_identifier, ) return resolved_room_id, remote_room_hosts From 0cc3bf97b4399234cf20f52ae4c09d03661225ff Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 9 Dec 2021 11:15:46 -0500 Subject: [PATCH 013/157] Additional type hints for the config module, part 2. (#11480) --- changelog.d/11480.misc | 1 + synapse/config/key.py | 36 +++++++++++++++++++++--------------- synapse/config/metrics.py | 6 ++++-- synapse/config/server.py | 2 +- synapse/config/tls.py | 2 +- 5 files changed, 28 insertions(+), 19 deletions(-) create mode 100644 changelog.d/11480.misc diff --git a/changelog.d/11480.misc b/changelog.d/11480.misc new file mode 100644 index 000000000000..aadc938b2be3 --- /dev/null +++ b/changelog.d/11480.misc @@ -0,0 +1 @@ +Add missing type hints to `synapse.config` module. diff --git a/synapse/config/key.py b/synapse/config/key.py index 035ee2416bd6..ee83c6c06b7f 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -16,12 +16,14 @@ import hashlib import logging import os -from typing import Any, Dict +from typing import Any, Dict, Iterator, List, Optional import attr import jsonschema from signedjson.key import ( NACL_ED25519, + SigningKey, + VerifyKey, decode_signing_key_base64, decode_verify_key_bytes, generate_signing_key, @@ -31,6 +33,7 @@ ) from unpaddedbase64 import decode_base64 +from synapse.types import JsonDict from synapse.util.stringutils import random_string, random_string_with_symbols from ._base import Config, ConfigError @@ -81,14 +84,13 @@ logger = logging.getLogger(__name__) -@attr.s +@attr.s(slots=True, auto_attribs=True) class TrustedKeyServer: - # string: name of the server. - server_name = attr.ib() + # name of the server. + server_name: str - # dict[str,VerifyKey]|None: map from key id to key object, or None to disable - # signature verification. - verify_keys = attr.ib(default=None) + # map from key id to key object, or None to disable signature verification. + verify_keys: Optional[Dict[str, VerifyKey]] = None class KeyConfig(Config): @@ -279,15 +281,15 @@ def generate_config_section( % locals() ) - def read_signing_keys(self, signing_key_path, name): + def read_signing_keys(self, signing_key_path: str, name: str) -> List[SigningKey]: """Read the signing keys in the given path. Args: - signing_key_path (str) - name (str): Associated config key name + signing_key_path + name: Associated config key name Returns: - list[SigningKey] + The signing keys read from the given path. 
""" signing_keys = self.read_file(signing_key_path, name) @@ -296,7 +298,9 @@ def read_signing_keys(self, signing_key_path, name): except Exception as e: raise ConfigError("Error reading %s: %s" % (name, str(e))) - def read_old_signing_keys(self, old_signing_keys): + def read_old_signing_keys( + self, old_signing_keys: Optional[JsonDict] + ) -> Dict[str, VerifyKey]: if old_signing_keys is None: return {} keys = {} @@ -340,7 +344,7 @@ def generate_files(self, config: Dict[str, Any], config_dir_path: str) -> None: write_signing_keys(signing_key_file, (key,)) -def _perspectives_to_key_servers(config): +def _perspectives_to_key_servers(config: JsonDict) -> Iterator[JsonDict]: """Convert old-style 'perspectives' configs into new-style 'trusted_key_servers' Returns an iterable of entries to add to trusted_key_servers. @@ -402,7 +406,9 @@ def _perspectives_to_key_servers(config): } -def _parse_key_servers(key_servers, federation_verify_certificates): +def _parse_key_servers( + key_servers: List[Any], federation_verify_certificates: bool +) -> Iterator[TrustedKeyServer]: try: jsonschema.validate(key_servers, TRUSTED_KEY_SERVERS_SCHEMA) except jsonschema.ValidationError as e: @@ -444,7 +450,7 @@ def _parse_key_servers(key_servers, federation_verify_certificates): yield result -def _assert_keyserver_has_verify_keys(trusted_key_server): +def _assert_keyserver_has_verify_keys(trusted_key_server: TrustedKeyServer) -> None: if not trusted_key_server.verify_keys: raise ConfigError(INSECURE_NOTARY_ERROR) diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py index 7ac82edb0ed1..1cc26e757812 100644 --- a/synapse/config/metrics.py +++ b/synapse/config/metrics.py @@ -22,10 +22,12 @@ @attr.s class MetricsFlags: - known_servers = attr.ib(default=False, validator=attr.validators.instance_of(bool)) + known_servers: bool = attr.ib( + default=False, validator=attr.validators.instance_of(bool) + ) @classmethod - def all_off(cls): + def all_off(cls) -> "MetricsFlags": """ Instantiate the flags with all options set to off. """ diff --git a/synapse/config/server.py b/synapse/config/server.py index ba5b95426338..1de2dea9b024 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -1257,7 +1257,7 @@ def add_arguments(parser: argparse.ArgumentParser) -> None: help="Turn on the twisted telnet manhole service on the given port.", ) - def read_gc_intervals(self, durations) -> Optional[Tuple[float, float, float]]: + def read_gc_intervals(self, durations: Any) -> Optional[Tuple[float, float, float]]: """Reads the three durations for the GC min interval option, returning seconds.""" if durations is None: return None diff --git a/synapse/config/tls.py b/synapse/config/tls.py index 4ca111618fe9..ffb316e4c011 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -132,7 +132,7 @@ def read_config(self, config: dict, config_dir_path: str, **kwargs): self.tls_certificate: Optional[crypto.X509] = None self.tls_private_key: Optional[crypto.PKey] = None - def read_certificate_from_disk(self): + def read_certificate_from_disk(self) -> None: """ Read the certificates and private key from disk. """ From 3b8872299aac25a7e3ee5a9e00564105aa6de237 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 9 Dec 2021 13:16:01 -0500 Subject: [PATCH 014/157] Do not allow cross-room relations, per MSC2674. 
(#11516) --- changelog.d/11516.bugfix | 1 + synapse/events/utils.py | 11 +- synapse/rest/client/relations.py | 7 +- synapse/storage/databases/main/events.py | 8 +- synapse/storage/databases/main/relations.py | 36 ++++-- tests/rest/client/test_relations.py | 115 ++++++++++++++++++++ 6 files changed, 161 insertions(+), 17 deletions(-) create mode 100644 changelog.d/11516.bugfix diff --git a/changelog.d/11516.bugfix b/changelog.d/11516.bugfix new file mode 100644 index 000000000000..22bba93671d7 --- /dev/null +++ b/changelog.d/11516.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where relations from other rooms could be included in the bundled aggregations of an event. diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 84ef69df679b..3da432c1df58 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -454,23 +454,26 @@ async def _injected_bundled_aggregations( return event_id = event.event_id + room_id = event.room_id # The bundled aggregations to include. aggregations = {} - annotations = await self.store.get_aggregation_groups_for_event(event_id) + annotations = await self.store.get_aggregation_groups_for_event( + event_id, room_id + ) if annotations.chunk: aggregations[RelationTypes.ANNOTATION] = annotations.to_dict() references = await self.store.get_relations_for_event( - event_id, RelationTypes.REFERENCE, direction="f" + event_id, room_id, RelationTypes.REFERENCE, direction="f" ) if references.chunk: aggregations[RelationTypes.REFERENCE] = references.to_dict() edit = None if event.type == EventTypes.Message: - edit = await self.store.get_applicable_edit(event_id) + edit = await self.store.get_applicable_edit(event_id, room_id) if edit: # If there is an edit replace the content, preserving existing @@ -503,7 +506,7 @@ async def _injected_bundled_aggregations( ( thread_count, latest_thread_event, - ) = await self.store.get_thread_summary(event_id) + ) = await self.store.get_thread_summary(event_id, room_id) if latest_thread_event: aggregations[RelationTypes.THREAD] = { # Don't bundle aggregations as this could recurse forever. diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index fc4e6921c5e6..ffa37ef06c89 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -212,6 +212,7 @@ async def on_GET( pagination_chunk = await self.store.get_relations_for_event( event_id=parent_id, + room_id=room_id, relation_type=relation_type, event_type=event_type, limit=limit, @@ -317,6 +318,7 @@ async def on_GET( pagination_chunk = await self.store.get_aggregation_groups_for_event( event_id=parent_id, + room_id=room_id, event_type=event_type, limit=limit, from_token=from_token, @@ -383,7 +385,9 @@ async def on_GET( # This checks that a) the event exists and b) the user is allowed to # view it. 
- await self.event_handler.get_event(requester.user, room_id, parent_id) + event = await self.event_handler.get_event(requester.user, room_id, parent_id) + if event is None: + raise SynapseError(404, "Unknown parent event.") if relation_type != RelationTypes.ANNOTATION: raise SynapseError(400, "Relation type must be 'annotation'") @@ -402,6 +406,7 @@ async def on_GET( result = await self.store.get_relations_for_event( event_id=parent_id, + room_id=room_id, relation_type=relation_type, event_type=event_type, aggregation_key=key, diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 4e528612eab7..f1f4ce5e0765 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1780,10 +1780,14 @@ def _handle_event_relations( ) if rel_type == RelationTypes.REPLACE: - txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,)) + txn.call_after( + self.store.get_applicable_edit.invalidate, (parent_id, event.room_id) + ) if rel_type == RelationTypes.THREAD: - txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,)) + txn.call_after( + self.store.get_thread_summary.invalidate, (parent_id, event.room_id) + ) def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase): """Handles keeping track of insertion events and edges/connections. diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 0a43acda07bb..3368a8b08488 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -37,6 +37,7 @@ class RelationsWorkerStore(SQLBaseStore): async def get_relations_for_event( self, event_id: str, + room_id: str, relation_type: Optional[str] = None, event_type: Optional[str] = None, aggregation_key: Optional[str] = None, @@ -49,6 +50,7 @@ async def get_relations_for_event( Args: event_id: Fetch events that relate to this event ID. + room_id: The room the event belongs to. relation_type: Only fetch events with this relation type, if given. event_type: Only fetch events with this event type, if given. aggregation_key: Only fetch events with this aggregation key, if given. @@ -63,8 +65,8 @@ async def get_relations_for_event( the form `{"event_id": "..."}`. """ - where_clause = ["relates_to_id = ?"] - where_args: List[Union[str, int]] = [event_id] + where_clause = ["relates_to_id = ?", "room_id = ?"] + where_args: List[Union[str, int]] = [event_id, room_id] if relation_type is not None: where_clause.append("relation_type = ?") @@ -199,6 +201,7 @@ async def event_is_target_of_relation(self, parent_id: str) -> bool: async def get_aggregation_groups_for_event( self, event_id: str, + room_id: str, event_type: Optional[str] = None, limit: int = 5, direction: str = "b", @@ -213,6 +216,7 @@ async def get_aggregation_groups_for_event( Args: event_id: Fetch events that relate to this event ID. + room_id: The room the event belongs to. event_type: Only fetch events with this event type, if given. limit: Only fetch the `limit` groups. direction: Whether to fetch the highest count first (`"b"`) or @@ -225,8 +229,12 @@ async def get_aggregation_groups_for_event( `type`, `key` and `count` fields. 
""" - where_clause = ["relates_to_id = ?", "relation_type = ?"] - where_args: List[Union[str, int]] = [event_id, RelationTypes.ANNOTATION] + where_clause = ["relates_to_id = ?", "room_id = ?", "relation_type = ?"] + where_args: List[Union[str, int]] = [ + event_id, + room_id, + RelationTypes.ANNOTATION, + ] if event_type: where_clause.append("type = ?") @@ -288,7 +296,9 @@ def _get_aggregation_groups_for_event_txn( ) @cached() - async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]: + async def get_applicable_edit( + self, event_id: str, room_id: str + ) -> Optional[EventBase]: """Get the most recent edit (if any) that has happened for the given event. @@ -296,6 +306,7 @@ async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]: Args: event_id: The original event ID + room_id: The original event's room ID Returns: The most recent edit, if any. @@ -317,13 +328,14 @@ async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]: WHERE relates_to_id = ? AND relation_type = ? + AND edit.room_id = ? AND edit.type = 'm.room.message' ORDER by edit.origin_server_ts DESC, edit.event_id DESC LIMIT 1 """ def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]: - txn.execute(sql, (event_id, RelationTypes.REPLACE)) + txn.execute(sql, (event_id, RelationTypes.REPLACE, room_id)) row = txn.fetchone() if row: return row[0] @@ -340,13 +352,14 @@ def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]: @cached() async def get_thread_summary( - self, event_id: str + self, event_id: str, room_id: str ) -> Tuple[int, Optional[EventBase]]: """Get the number of threaded replies, the senders of those replies, and the latest reply (if any) for the given event. Args: - event_id: The original event ID + event_id: Summarize the thread related to this event ID. + room_id: The room the event belongs to. Returns: The number of items in the thread and the most recent response, if any. @@ -363,12 +376,13 @@ def _get_thread_summary_txn( INNER JOIN events USING (event_id) WHERE relates_to_id = ? + AND room_id = ? AND relation_type = ? ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT 1 """ - txn.execute(sql, (event_id, RelationTypes.THREAD)) + txn.execute(sql, (event_id, room_id, RelationTypes.THREAD)) row = txn.fetchone() if row is None: return 0, None @@ -378,11 +392,13 @@ def _get_thread_summary_txn( sql = """ SELECT COALESCE(COUNT(event_id), 0) FROM event_relations + INNER JOIN events USING (event_id) WHERE relates_to_id = ? + AND room_id = ? AND relation_type = ? 
""" - txn.execute(sql, (event_id, RelationTypes.THREAD)) + txn.execute(sql, (event_id, room_id, RelationTypes.THREAD)) count = txn.fetchone()[0] # type: ignore[index] return count, latest_event_id diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 397c12c2a6c5..55f4f0b1d035 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -16,6 +16,7 @@ import itertools import urllib.parse from typing import Dict, List, Optional, Tuple +from unittest.mock import patch from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin @@ -23,6 +24,8 @@ from tests import unittest from tests.server import FakeChannel +from tests.test_utils import make_awaitable +from tests.test_utils.event_injection import inject_event class RelationsTestCase(unittest.HomeserverTestCase): @@ -651,6 +654,118 @@ def test_aggregation_get_event_for_thread(self): }, ) + @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) + def test_ignore_invalid_room(self): + """Test that we ignore invalid relations over federation.""" + # Create another room and send a message in it. + room2 = self.helper.create_room_as(self.user_id, tok=self.user_token) + res = self.helper.send(room2, body="Hi!", tok=self.user_token) + parent_id = res["event_id"] + + # Disable the validation to pretend this came over federation. + with patch( + "synapse.handlers.message.EventCreationHandler._validate_event_relation", + new=lambda self, event: make_awaitable(None), + ): + # Generate a various relations from a different room. + self.get_success( + inject_event( + self.hs, + room_id=self.room, + type="m.reaction", + sender=self.user_id, + content={ + "m.relates_to": { + "rel_type": RelationTypes.ANNOTATION, + "event_id": parent_id, + "key": "A", + } + }, + ) + ) + + self.get_success( + inject_event( + self.hs, + room_id=self.room, + type="m.room.message", + sender=self.user_id, + content={ + "body": "foo", + "msgtype": "m.text", + "m.relates_to": { + "rel_type": RelationTypes.REFERENCE, + "event_id": parent_id, + }, + }, + ) + ) + + self.get_success( + inject_event( + self.hs, + room_id=self.room, + type="m.room.message", + sender=self.user_id, + content={ + "body": "foo", + "msgtype": "m.text", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": parent_id, + }, + }, + ) + ) + + self.get_success( + inject_event( + self.hs, + room_id=self.room, + type="m.room.message", + sender=self.user_id, + content={ + "body": "foo", + "msgtype": "m.text", + "new_content": { + "body": "new content", + "msgtype": "m.text", + }, + "m.relates_to": { + "rel_type": RelationTypes.REPLACE, + "event_id": parent_id, + }, + }, + ) + ) + + # They should be ignored when fetching relations. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{room2}/relations/{parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(channel.json_body["chunk"], []) + + # And when fetching aggregations. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{room2}/aggregations/{parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(channel.json_body["chunk"], []) + + # And for bundled aggregations. 
+ channel = self.make_request( + "GET", + f"/rooms/{room2}/event/{parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + self.assertNotIn("m.relations", channel.json_body["unsigned"]) + def test_edit(self): """Test that a simple edit works.""" From 9562f0c2f1bd3489bfbb64fddbbd21ed657b44dd Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 10 Dec 2021 07:17:28 -0500 Subject: [PATCH 015/157] Ensure emails are canonicalized before fetching associated user. (#11547) This should fix pushers when an email in non-canonical form is used as the pushkey. --- changelog.d/11547.bugfix | 1 + synapse/push/pusherpool.py | 5 ++++- synapse/storage/databases/main/monthly_active_users.py | 3 ++- synapse/storage/databases/main/registration.py | 3 ++- tests/rest/admin/test_user.py | 3 ++- 5 files changed, 11 insertions(+), 4 deletions(-) create mode 100644 changelog.d/11547.bugfix diff --git a/changelog.d/11547.bugfix b/changelog.d/11547.bugfix new file mode 100644 index 000000000000..3950c4c8d30c --- /dev/null +++ b/changelog.d/11547.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.17.0 where a pusher created for an email with capital letters would fail to be created. diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 26735447a6f1..7912311d2401 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -27,6 +27,7 @@ from synapse.replication.http.push import ReplicationRemovePusherRestServlet from synapse.types import JsonDict, RoomStreamToken from synapse.util.async_helpers import concurrently_execute +from synapse.util.threepids import canonicalise_email if TYPE_CHECKING: from synapse.server import HomeServer @@ -113,7 +114,9 @@ async def add_pusher( """ if kind == "email": - email_owner = await self.store.get_user_id_by_threepid("email", pushkey) + email_owner = await self.store.get_user_id_by_threepid( + "email", canonicalise_email(pushkey) + ) if email_owner != user_id: raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index b5284e4f6783..3c98ef876f8c 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -18,6 +18,7 @@ from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool, make_in_list_sql_clause from synapse.util.caches.descriptors import cached +from synapse.util.threepids import canonicalise_email if TYPE_CHECKING: from synapse.server import HomeServer @@ -103,7 +104,7 @@ async def get_registered_reserved_users(self) -> List[str]: : self.hs.config.server.max_mau_value ]: user_id = await self.hs.get_datastore().get_user_id_by_threepid( - tp["medium"], tp["address"] + tp["medium"], canonicalise_email(tp["address"]) ) if user_id: users.append(user_id) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index e1ddf0691646..86c34257168c 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -856,7 +856,8 @@ async def get_user_id_by_threepid(self, medium: str, address: str) -> Optional[s Args: medium: threepid medium e.g. email - address: threepid address e.g. me@example.com + address: threepid address e.g. me@example.com. This must already be + in canonical form.
Returns: The user ID or None if no user id/threepid mapping exists diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 294d429ce12b..eea675991cbc 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -1550,7 +1550,8 @@ def test_create_user_email_notif_for_new_users(self): # Create user body = { "password": "abc123", - "threepids": [{"medium": "email", "address": "bob@bob.bob"}], + # Note that the given email is not in canonical form. + "threepids": [{"medium": "email", "address": "Bob@bob.bob"}], } channel = self.make_request( From 86e7a6d16ee9ffe8f5e783ec8150405b13f878fa Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 10 Dec 2021 14:13:23 +0000 Subject: [PATCH 016/157] Stop populating `state_events.prev_state` (#11558) this field is never read, so we may as well stop populating it. --- changelog.d/11558.misc | 1 + synapse/storage/databases/main/events.py | 4 ---- synapse/storage/schema/__init__.py | 5 ++++- 3 files changed, 5 insertions(+), 5 deletions(-) create mode 100644 changelog.d/11558.misc diff --git a/changelog.d/11558.misc b/changelog.d/11558.misc new file mode 100644 index 000000000000..7c334f17e007 --- /dev/null +++ b/changelog.d/11558.misc @@ -0,0 +1 @@ +Stop populating unused database column `state_events.prev_state`. diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index f1f4ce5e0765..eed453d8360b 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1410,10 +1410,6 @@ def event_dict(event): "state_key": event.state_key, } - # TODO: How does this work with backfilling? - if hasattr(event, "replaces_state"): - vals["prev_state"] = event.replaces_state - state_values.append(vals) self.db_pool.simple_insert_many_txn( diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 50d08094d52c..2a3d47185ae5 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -SCHEMA_VERSION = 66 # remember to update the list below when updating +SCHEMA_VERSION = 67 # remember to update the list below when updating """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the @@ -50,6 +50,9 @@ Changes in SCHEMA_VERSION = 66: - Queries on state_key columns are now disambiguated (ie, the codebase can handle the `events` table having a `state_key` column). + +Changes in SCHEMA_VERSION = 67: + - state_events.prev_state is no longer written to. """ From f0562183e732b066860fcb4b8e24f8a1988aea1a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 10 Dec 2021 15:02:33 +0000 Subject: [PATCH 017/157] skip some dict munging in event persistence (#11560) Create a new dict helper method `simple_insert_many_values_txn`, which takes raw row values, rather than {key=>value} dicts. This saves us a bunch of dict munging, and makes it easier to use generators rather than creating intermediate lists and dicts. 
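To make the two input shapes concrete, a small sketch (the column names and rows here are illustrative, not taken from the patch): the dict-based helper repeats the column names in every row, while the values-based helper takes them once alongside bare row tuples, which may come from a generator.

    keys = ("event_id", "room_id", "type")
    rows = [
        ("$a", "!room:test", "m.room.message"),
        ("$b", "!room:test", "m.room.message"),
    ]

    # The old dict-per-row form carries the same information, but allocates a
    # dict for every row and re-derives the column list from the first entry:
    dict_rows = [dict(zip(keys, row)) for row in rows]
    assert [tuple(d[k] for k in keys) for d in dict_rows] == rows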
--- changelog.d/11560.misc | 1 + synapse/storage/database.py | 59 +++++++++++- synapse/storage/databases/main/events.py | 114 ++++++++++++----------- 3 files changed, 114 insertions(+), 60 deletions(-) create mode 100644 changelog.d/11560.misc diff --git a/changelog.d/11560.misc b/changelog.d/11560.misc new file mode 100644 index 000000000000..eb968167f58a --- /dev/null +++ b/changelog.d/11560.misc @@ -0,0 +1 @@ +Minor efficiency improvements in event persistence. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 0693d390064f..5552dd3c5c35 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -896,6 +896,9 @@ async def simple_insert_many( ) -> None: """Executes an INSERT query on the named table. + The input is given as a list of dicts, with one dict per row. + Generally simple_insert_many_values should be preferred for new code. + Args: table: string giving the table name values: dict of new column names and values for them @@ -909,6 +912,9 @@ def simple_insert_many_txn( ) -> None: """Executes an INSERT query on the named table. + The input is given as a list of dicts, with one dict per row. + Generally simple_insert_many_values_txn should be preferred for new code. + Args: txn: The transaction to use. table: string giving the table name @@ -933,23 +939,66 @@ def simple_insert_many_txn( if k != keys[0]: raise RuntimeError("All items must have the same keys") + return DatabasePool.simple_insert_many_values_txn(txn, table, keys[0], vals) + + async def simple_insert_many_values( + self, + table: str, + keys: Collection[str], + values: Iterable[Iterable[Any]], + desc: str, + ) -> None: + """Executes an INSERT query on the named table. + + The input is given as a list of rows, where each row is a list of values. + (Actually any iterable is fine.) + + Args: + table: string giving the table name + keys: list of column names + values: for each row, a list of values in the same order as `keys` + desc: description of the transaction, for logging and metrics + """ + await self.runInteraction( + desc, self.simple_insert_many_values_txn, table, keys, values + ) + + @staticmethod + def simple_insert_many_values_txn( + txn: LoggingTransaction, + table: str, + keys: Collection[str], + values: Iterable[Iterable[Any]], + ) -> None: + """Executes an INSERT query on the named table. + + The input is given as a list of rows, where each row is a list of values. + (Actually any iterable is fine.) + + Args: + txn: The transaction to use. + table: string giving the table name + keys: list of column names + values: for each row, a list of values in the same order as `keys` + """ + if isinstance(txn.database_engine, PostgresEngine): # We use `execute_values` as it can be a lot faster than `execute_batch`, # but it's only available on postgres. sql = "INSERT INTO %s (%s) VALUES ?" % ( table, - ", ".join(k for k in keys[0]), + ", ".join(k for k in keys), ) - txn.execute_values(sql, vals, fetch=False) + txn.execute_values(sql, values, fetch=False) else: sql = "INSERT INTO %s (%s) VALUES(%s)" % ( table, - ", ".join(k for k in keys[0]), - ", ".join("?" for _ in keys[0]), + ", ".join(k for k in keys), + ", ".join("?" 
for _ in keys), ) - txn.execute_batch(sql, vals) + txn.execute_batch(sql, values) async def simple_upsert( self, diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index eed453d8360b..5184e6bf85f8 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -19,6 +19,7 @@ from typing import ( TYPE_CHECKING, Any, + Collection, Dict, Generator, Iterable, @@ -1319,14 +1320,13 @@ def _update_outliers_txn(self, txn, events_and_contexts): return [ec for ec in events_and_contexts if ec[0] not in to_remove] - def _store_event_txn(self, txn, events_and_contexts): + def _store_event_txn( + self, + txn: LoggingTransaction, + events_and_contexts: Collection[Tuple[EventBase, EventContext]], + ) -> None: """Insert new events into the event, event_json, redaction and state_events tables. - - Args: - txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting """ if not events_and_contexts: @@ -1339,46 +1339,58 @@ def event_dict(event): d.pop("redacted_because", None) return d - self.db_pool.simple_insert_many_txn( + self.db_pool.simple_insert_many_values_txn( txn, table="event_json", - values=[ - { - "event_id": event.event_id, - "room_id": event.room_id, - "internal_metadata": json_encoder.encode( - event.internal_metadata.get_dict() - ), - "json": json_encoder.encode(event_dict(event)), - "format_version": event.format_version, - } + keys=("event_id", "room_id", "internal_metadata", "json", "format_version"), + values=( + ( + event.event_id, + event.room_id, + json_encoder.encode(event.internal_metadata.get_dict()), + json_encoder.encode(event_dict(event)), + event.format_version, + ) for event, _ in events_and_contexts - ], + ), ) - self.db_pool.simple_insert_many_txn( + self.db_pool.simple_insert_many_values_txn( txn, table="events", - values=[ - { - "instance_name": self._instance_name, - "stream_ordering": event.internal_metadata.stream_ordering, - "topological_ordering": event.depth, - "depth": event.depth, - "event_id": event.event_id, - "room_id": event.room_id, - "type": event.type, - "processed": True, - "outlier": event.internal_metadata.is_outlier(), - "origin_server_ts": int(event.origin_server_ts), - "received_ts": self._clock.time_msec(), - "sender": event.sender, - "contains_url": ( - "url" in event.content and isinstance(event.content["url"], str) - ), - } + keys=( + "instance_name", + "stream_ordering", + "topological_ordering", + "depth", + "event_id", + "room_id", + "type", + "processed", + "outlier", + "origin_server_ts", + "received_ts", + "sender", + "contains_url", + ), + values=( + ( + self._instance_name, + event.internal_metadata.stream_ordering, + event.depth, # topological_ordering + event.depth, # depth + event.event_id, + event.room_id, + event.type, + True, # processed + event.internal_metadata.is_outlier(), + int(event.origin_server_ts), + self._clock.time_msec(), + event.sender, + "url" in event.content and isinstance(event.content["url"], str), + ) for event, _ in events_and_contexts - ], + ), ) # If we're persisting an unredacted event we go and ensure @@ -1397,23 +1409,15 @@ def event_dict(event): ) txn.execute(sql + clause, [False] + args) - state_events_and_contexts = [ - ec for ec in events_and_contexts if ec[0].is_state() - ] - - state_values = [] - for event, _ in state_events_and_contexts: - vals = { - "event_id": event.event_id, - "room_id": event.room_id, - "type": event.type, - "state_key": 
event.state_key, - } - - state_values.append(vals) - - self.db_pool.simple_insert_many_txn( - txn, table="state_events", values=state_values + self.db_pool.simple_insert_many_values_txn( + txn, + table="state_events", + keys=("event_id", "room_id", "type", "state_key"), + values=( + (event.event_id, event.room_id, event.type, event.state_key) + for event, _ in events_and_contexts + if event.is_state() + ), ) def _store_rejected_events_txn(self, txn, events_and_contexts): From fd2dadb8152cbacf5395fe84e6392bde7ad45897 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 10 Dec 2021 19:19:48 +0000 Subject: [PATCH 018/157] Adjust _get_rooms_changed comments (#11550) C.f. https://github.com/matrix-org/synapse/pull/11494#pullrequestreview-827780886 --- changelog.d/11550.misc | 1 + synapse/handlers/sync.py | 51 +++++++++++++++++++++++----------------- 2 files changed, 30 insertions(+), 22 deletions(-) create mode 100644 changelog.d/11550.misc diff --git a/changelog.d/11550.misc b/changelog.d/11550.misc new file mode 100644 index 000000000000..d5577e0b6300 --- /dev/null +++ b/changelog.d/11550.misc @@ -0,0 +1 @@ +Fix an inaccurate and misleading comment in the `/sync` code. \ No newline at end of file diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 96f37e9f4204..bcd10cbb3051 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1662,20 +1662,20 @@ async def _get_rooms_changed( ) -> _RoomChanges: """Determine the changes in rooms to report to the user. - Ideally, we want to report all events whose stream ordering `s` lies in the - range `since_token < s <= now_token`, where the two tokens are read from the - sync_result_builder. + This function is a first pass at generating the rooms part of the sync response. + It determines which rooms have changed during the sync period, and categorises + them into four buckets: "knock", "invite", "join" and "leave". - If there are too many events in that range to report, things get complicated. - In this situation we return a truncated list of the most recent events, and - indicate in the response that there is a "gap" of omitted events. Additionally: + 1. Finds all membership changes for the user in the sync period (from + `since_token` up to `now_token`). + 2. Uses those to place the room in one of the four categories above. + 3. Builds a `_RoomChanges` struct to record this, and return that struct. - - we include a "state_delta", to describe the changes in state over the gap, - - we include all membership events applying to the user making the request, - even those in the gap. - - See the spec for the rationale: - https://spec.matrix.org/v1.1/client-server-api/#syncing + For rooms classified as "knock", "invite" or "leave", we just need to report + a single membership event in the eventual /sync response. For "join" we need + to fetch additional non-membership events, e.g. messages in the room. That is + more complicated, so instead we report an intermediary `RoomSyncResultBuilder` + struct, and leave the additional work to `_generate_room_entry`. The sync_result_builder is not modified by this function. 
""" @@ -1686,16 +1686,6 @@ async def _get_rooms_changed( assert since_token - # The spec - # https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3sync - # notes that membership events need special consideration: - # - # > When a sync is limited, the server MUST return membership events for events - # > in the gap (between since and the start of the returned timeline), regardless - # > as to whether or not they are redundant. - # - # We fetch such events here, but we only seem to use them for categorising rooms - # as newly joined, newly left, invited or knocked. # TODO: we've already called this function and ran this query in # _have_rooms_changed. We could keep the results in memory to avoid a # second query, at the cost of more complicated source code. @@ -2009,6 +1999,23 @@ async def _generate_room_entry( """Populates the `joined` and `archived` section of `sync_result_builder` based on the `room_builder`. + Ideally, we want to report all events whose stream ordering `s` lies in the + range `since_token < s <= now_token`, where the two tokens are read from the + sync_result_builder. + + If there are too many events in that range to report, things get complicated. + In this situation we return a truncated list of the most recent events, and + indicate in the response that there is a "gap" of omitted events. Lots of this + is handled in `_load_filtered_recents`, but some of is handled in this method. + + Additionally: + - we include a "state_delta", to describe the changes in state over the gap, + - we include all membership events applying to the user making the request, + even those in the gap. + + See the spec for the rationale: + https://spec.matrix.org/v1.1/client-server-api/#syncing + Args: sync_result_builder ignored_users: Set of users ignored by user. From 8391bd6ab59387845bae77130dc0ca437c37fb8e Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Fri, 10 Dec 2021 20:59:20 -0600 Subject: [PATCH 019/157] Test to ensure we share the same `state_group` across the whole historical batch (MSC2716) (#11487) Part of MSC2716: https://github.com/matrix-org/matrix-doc/pull/2716 We did some work on making sure the `state_groups` were shared in https://github.com/matrix-org/synapse/pull/10975 --- changelog.d/11487.misc | 1 + tests/rest/client/test_room_batch.py | 180 +++++++++++++++++++++++++++ 2 files changed, 181 insertions(+) create mode 100644 changelog.d/11487.misc create mode 100644 tests/rest/client/test_room_batch.py diff --git a/changelog.d/11487.misc b/changelog.d/11487.misc new file mode 100644 index 000000000000..376b9078be87 --- /dev/null +++ b/changelog.d/11487.misc @@ -0,0 +1 @@ +Add test to ensure we share the same `state_group` across the whole historical batch when using the [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) `/batch_send` endpoint. 
diff --git a/tests/rest/client/test_room_batch.py b/tests/rest/client/test_room_batch.py new file mode 100644 index 000000000000..721454c1875f --- /dev/null +++ b/tests/rest/client/test_room_batch.py @@ -0,0 +1,180 @@ +import logging +from typing import List, Tuple +from unittest.mock import Mock, patch + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.constants import EventContentFields, EventTypes +from synapse.appservice import ApplicationService +from synapse.rest import admin +from synapse.rest.client import login, register, room, room_batch +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock + +from tests import unittest + +logger = logging.getLogger(__name__) + + +def _create_join_state_events_for_batch_send_request( + virtual_user_ids: List[str], + insert_time: int, +) -> List[JsonDict]: + return [ + { + "type": EventTypes.Member, + "sender": virtual_user_id, + "origin_server_ts": insert_time, + "content": { + "membership": "join", + "displayname": "display-name-for-%s" % (virtual_user_id,), + }, + "state_key": virtual_user_id, + } + for virtual_user_id in virtual_user_ids + ] + + +def _create_message_events_for_batch_send_request( + virtual_user_id: str, insert_time: int, count: int +) -> List[JsonDict]: + return [ + { + "type": EventTypes.Message, + "sender": virtual_user_id, + "origin_server_ts": insert_time, + "content": { + "msgtype": "m.text", + "body": "Historical %d" % (i), + EventContentFields.MSC2716_HISTORICAL: True, + }, + } + for i in range(count) + ] + + +class RoomBatchTestCase(unittest.HomeserverTestCase): + """Test importing batches of historical messages.""" + + servlets = [ + admin.register_servlets_for_client_rest_resource, + room_batch.register_servlets, + room.register_servlets, + register.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + config = self.default_config() + + self.appservice = ApplicationService( + token="i_am_an_app_service", + hostname="test", + id="1234", + namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, + # Note: this user does not have to match the regex above + sender="@as_main:test", + ) + + mock_load_appservices = Mock(return_value=[self.appservice]) + with patch( + "synapse.storage.databases.main.appservice.load_appservices", + mock_load_appservices, + ): + hs = self.setup_test_homeserver(config=config) + return hs + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.clock = clock + self.storage = hs.get_storage() + + self.virtual_user_id = self.register_appservice_user( + "as_user_potato", self.appservice.token + ) + + def _create_test_room(self) -> Tuple[str, str, str, str]: + room_id = self.helper.create_room_as( + self.appservice.sender, tok=self.appservice.token + ) + + res_a = self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "A", + }, + tok=self.appservice.token, + ) + event_id_a = res_a["event_id"] + + res_b = self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "B", + }, + tok=self.appservice.token, + ) + event_id_b = res_b["event_id"] + + res_c = self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "C", + }, + tok=self.appservice.token, + ) + event_id_c = res_c["event_id"] + + return room_id, event_id_a, event_id_b, 
event_id_c + + @unittest.override_config({"experimental_features": {"msc2716_enabled": True}}) + def test_same_state_groups_for_whole_historical_batch(self): + """Make sure that when using the `/batch_send` endpoint to import a + bunch of historical messages, it re-uses the same `state_group` across + the whole batch. This is an easy optimization to make sure we're getting + right because the state for the whole batch is contained in + `state_events_at_start` and can be shared across everything. + """ + + time_before_room = int(self.clock.time_msec()) + room_id, event_id_a, _, _ = self._create_test_room() + + channel = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2716/rooms/%s/batch_send?prev_event_id=%s" + % (room_id, event_id_a), + content={ + "events": _create_message_events_for_batch_send_request( + self.virtual_user_id, time_before_room, 3 + ), + "state_events_at_start": _create_join_state_events_for_batch_send_request( + [self.virtual_user_id], time_before_room + ), + }, + access_token=self.appservice.token, + ) + self.assertEqual(channel.code, 200, channel.result) + + # Get the historical event IDs that we just imported + historical_event_ids = channel.json_body["event_ids"] + self.assertEqual(len(historical_event_ids), 3) + + # Fetch the state_groups + state_group_map = self.get_success( + self.storage.state.get_state_groups_ids(room_id, historical_event_ids) + ) + + # We expect all of the historical events to be using the same state_group + # so there should only be a single state_group here! + self.assertEqual( + len(state_group_map.keys()), + 1, + "Expected a single state_group to be returned by saw state_groups=%s" + % (state_group_map.keys(),), + ) From aa8708ebed74b03bdebd7e20ddf070c6fd620db1 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Fri, 10 Dec 2021 23:08:51 -0600 Subject: [PATCH 020/157] Allow events to be created with no `prev_events` (MSC2716) (#11243) The event still needs to have `auth_events` defined to be valid. Split out from https://github.com/matrix-org/synapse/pull/11114 --- changelog.d/11243.misc | 1 + synapse/handlers/message.py | 24 ++++++-- synapse/handlers/room_member.py | 3 +- tests/handlers/test_message.py | 103 ++++++++++++++++++++++++++++++++ 4 files changed, 124 insertions(+), 7 deletions(-) create mode 100644 changelog.d/11243.misc diff --git a/changelog.d/11243.misc b/changelog.d/11243.misc new file mode 100644 index 000000000000..5ef7fe16d4c9 --- /dev/null +++ b/changelog.d/11243.misc @@ -0,0 +1 @@ +Allow specific, experimental events to be created without `prev_events`. Used by [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716). diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 87f671708c4e..38409fef38d9 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -496,6 +496,7 @@ async def create_event( require_consent: bool = True, outlier: bool = False, historical: bool = False, + allow_no_prev_events: bool = False, depth: Optional[int] = None, ) -> Tuple[EventBase, EventContext]: """ @@ -607,6 +608,7 @@ async def create_event( prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids, depth=depth, + allow_no_prev_events=allow_no_prev_events, ) # In an ideal world we wouldn't need the second part of this condition. 
However, @@ -882,6 +884,7 @@ async def create_new_client_event( prev_event_ids: Optional[List[str]] = None, auth_event_ids: Optional[List[str]] = None, depth: Optional[int] = None, + allow_no_prev_events: bool = False, ) -> Tuple[EventBase, EventContext]: """Create a new event for a local client @@ -912,6 +915,7 @@ async def create_new_client_event( full_state_ids_at_event = None if auth_event_ids is not None: # If auth events are provided, prev events must be also. + # prev_event_ids could be an empty array though. assert prev_event_ids is not None # Copy the full auth state before it stripped down @@ -943,14 +947,22 @@ async def create_new_client_event( else: prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id) - # we now ought to have some prev_events (unless it's a create event). - # - # do a quick sanity check here, rather than waiting until we've created the + # Do a quick sanity check here, rather than waiting until we've created the # event and then try to auth it (which fails with a somewhat confusing "No # create event in auth events") - assert ( - builder.type == EventTypes.Create or len(prev_event_ids) > 0 - ), "Attempting to create an event with no prev_events" + if allow_no_prev_events: + # We allow events with no `prev_events` but it better have some `auth_events` + assert ( + builder.type == EventTypes.Create + # Allow an event to have empty list of prev_event_ids + # only if it has auth_event_ids. + or auth_event_ids + ), "Attempting to create a non-m.room.create event with no prev_events or auth_event_ids" + else: + # we now ought to have some prev_events (unless it's a create event). + assert ( + builder.type == EventTypes.Create or prev_event_ids + ), "Attempting to create a non-m.room.create event with no prev_events" event = await builder.build( prev_event_ids=prev_event_ids, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index a6dbff637f54..447e3ce5713b 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -658,7 +658,8 @@ async def update_membership_locked( if block_invite: raise SynapseError(403, "Invites have been disabled on this server") - if prev_event_ids: + # An empty prev_events list is allowed as long as the auth_event_ids are present + if prev_event_ids is not None: return await self._local_membership_update( requester=requester, target=target, diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index 8a8d369faca1..5816295d8b97 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -23,6 +23,7 @@ from synapse.util.stringutils import random_string from tests import unittest +from tests.test_utils.event_injection import create_event logger = logging.getLogger(__name__) @@ -51,6 +52,24 @@ def prepare(self, reactor, clock, hs): self.requester = create_requester(self.user_id, access_token_id=self.token_id) + def _create_and_persist_member_event(self) -> Tuple[EventBase, EventContext]: + # Create a member event we can use as an auth_event + memberEvent, memberEventContext = self.get_success( + create_event( + self.hs, + room_id=self.room_id, + type="m.room.member", + sender=self.requester.user.to_string(), + state_key=self.requester.user.to_string(), + content={"membership": "join"}, + ) + ) + self.get_success( + self.persist_event_storage.persist_event(memberEvent, memberEventContext) + ) + + return memberEvent, memberEventContext + def _create_duplicate_event(self, txn_id: str) -> Tuple[EventBase, EventContext]: """Create 
a new event with the given transaction ID. All events produced by this method will be considered duplicates. @@ -156,6 +175,90 @@ def test_duplicated_txn_id_one_call(self): self.assertEqual(len(events), 2) self.assertEqual(events[0].event_id, events[1].event_id) + def test_when_empty_prev_events_allowed_create_event_with_empty_prev_events(self): + """When we set allow_no_prev_events=True, should be able to create an + event without any prev_events (only auth_events). + """ + # Create a member event we can use as an auth_event + memberEvent, _ = self._create_and_persist_member_event() + + # Try to create the event with empty prev_events but with some auth_events + event, _ = self.get_success( + self.handler.create_event( + self.requester, + { + "type": EventTypes.Message, + "room_id": self.room_id, + "sender": self.requester.user.to_string(), + "content": {"msgtype": "m.text", "body": random_string(5)}, + }, + # Empty prev_events is the key thing we're testing here + prev_event_ids=[], + # But with some auth_events + auth_event_ids=[memberEvent.event_id], + # Allow no prev_events! + allow_no_prev_events=True, + ) + ) + self.assertIsNotNone(event) + + def test_when_empty_prev_events_not_allowed_reject_event_with_empty_prev_events( + self, + ): + """When we set allow_no_prev_events=False, shouldn't be able to create an + event without any prev_events even if it has auth_events. Expect an + exception to be raised. + """ + # Create a member event we can use as an auth_event + memberEvent, _ = self._create_and_persist_member_event() + + # Try to create the event with empty prev_events but with some auth_events + self.get_failure( + self.handler.create_event( + self.requester, + { + "type": EventTypes.Message, + "room_id": self.room_id, + "sender": self.requester.user.to_string(), + "content": {"msgtype": "m.text", "body": random_string(5)}, + }, + # Empty prev_events is the key thing we're testing here + prev_event_ids=[], + # But with some auth_events + auth_event_ids=[memberEvent.event_id], + # We expect the event creation to fail because empty prev_events are not + # allowed here! + allow_no_prev_events=False, + ), + AssertionError, + ) + + def test_when_empty_prev_events_allowed_reject_event_with_empty_prev_events_and_auth_events( + self, + ): + """When we set allow_no_prev_events=True, still shouldn't be able to create an + event without any prev_events or auth_events. Expect an exception to be + raised. + """ + # Try to create the event with empty prev_events and empty auth_events + self.get_failure( + self.handler.create_event( + self.requester, + { + "type": EventTypes.Message, + "room_id": self.room_id, + "sender": self.requester.user.to_string(), + "content": {"msgtype": "m.text", "body": random_string(5)}, + }, + prev_event_ids=[], + # The event should be rejected when there are no auth_events + auth_event_ids=[], + # Allow no prev_events! + allow_no_prev_events=True, + ), + AssertionError, + ) + class ServerAclValidationTestCase(unittest.HomeserverTestCase): servlets = [ From e5cdb9e2339e321e8a77a898d362d7fbc476303b Mon Sep 17 00:00:00 2001 From: reivilibre Date: Mon, 13 Dec 2021 15:39:43 +0000 Subject: [PATCH 021/157] Make `get_device` return None if the device doesn't exist rather than raising an exception.
(#11565) Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> --- changelog.d/11565.misc | 1 + synapse/handlers/auth.py | 4 +--- synapse/handlers/device.py | 10 ++++++---- synapse/rest/admin/devices.py | 2 ++ synapse/rest/client/devices.py | 6 ++++-- synapse/storage/databases/main/devices.py | 10 ++++++---- 6 files changed, 20 insertions(+), 13 deletions(-) create mode 100644 changelog.d/11565.misc diff --git a/changelog.d/11565.misc b/changelog.d/11565.misc new file mode 100644 index 000000000000..ddcafd32cbac --- /dev/null +++ b/changelog.d/11565.misc @@ -0,0 +1 @@ +Make `get_device` return `None` if the device doesn't exist rather than raising an exception. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 61607cf2bad7..84724b207c9d 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -997,9 +997,7 @@ async def create_access_token_for_user_id( # really don't want is active access_tokens without a record of the # device, so we double-check it here. if device_id is not None: - try: - await self.store.get_device(user_id, device_id) - except StoreError: + if await self.store.get_device(user_id, device_id) is None: await self.store.delete_access_token(access_token) raise StoreError(400, "Login raced against device deletion") diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 82ee11e921e6..766542523218 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -106,10 +106,10 @@ async def get_device(self, user_id: str, device_id: str) -> JsonDict: Raises: errors.NotFoundError: if the device was not found """ - try: - device = await self.store.get_device(user_id, device_id) - except errors.StoreError: - raise errors.NotFoundError + device = await self.store.get_device(user_id, device_id) + if device is None: + raise errors.NotFoundError() + ips = await self.store.get_last_client_ip_by_device(user_id, device_id) _update_device_from_client_ips(device, ips) @@ -602,6 +602,8 @@ async def rehydrate_device( access_token, device_id ) old_device = await self.store.get_device(user_id, old_device_id) + if old_device is None: + raise errors.NotFoundError() await self.store.update_device(user_id, device_id, old_device["display_name"]) # can't call self.delete_device because that will clobber the # access token so call the storage layer directly diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py index 062a33d28d15..d9905ff560cb 100644 --- a/synapse/rest/admin/devices.py +++ b/synapse/rest/admin/devices.py @@ -63,6 +63,8 @@ async def on_GET( device = await self.device_handler.get_device( target_user.to_string(), device_id ) + if device is None: + raise NotFoundError("No device found") return HTTPStatus.OK, device async def on_DELETE( diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py index 8566dc5cb594..ad6fd6492baa 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py @@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Tuple from synapse.api import errors +from synapse.api.errors import NotFoundError from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, @@ -24,10 +25,9 @@ parse_json_object_from_request, ) from synapse.http.site import SynapseRequest +from synapse.rest.client._base import client_patterns, interactive_auth_handler from synapse.types import JsonDict -from ._base import client_patterns, interactive_auth_handler - if TYPE_CHECKING: from synapse.server import HomeServer @@ -116,6 +116,8 
@@ async def on_GET( device = await self.device_handler.get_device( requester.user.to_string(), device_id ) + if device is None: + raise NotFoundError("No device found") return 200, device @interactive_auth_handler diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 838a2a6a3dd0..eff825dd2254 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -101,7 +101,9 @@ def count_devices_by_users_txn(txn, user_ids): "count_devices_by_users", count_devices_by_users_txn, user_ids ) - async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]: + async def get_device( + self, user_id: str, device_id: str + ) -> Optional[Dict[str, Any]]: """Retrieve a device. Only returns devices that are not marked as hidden. @@ -109,15 +111,15 @@ async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]: user_id: The ID of the user which owns the device device_id: The ID of the device to retrieve Returns: - A dict containing the device information - Raises: - StoreError: if the device is not found + A dict containing the device information, or `None` if the device does not + exist. """ return await self.db_pool.simple_select_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, retcols=("user_id", "device_id", "display_name"), desc="get_device", + allow_none=True, ) async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]: From 6da8591f2ef9597880ace89aaf434332dddaa711 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Mon, 13 Dec 2021 16:28:10 +0000 Subject: [PATCH 022/157] Add type hints to `synapse/storage/databases/main/account_data.py` (#11546) --- changelog.d/11546.misc | 1 + mypy.ini | 4 +- .../storage/databases/main/account_data.py | 93 ++++++++++++------- synapse/storage/databases/main/tags.py | 22 ++++- 4 files changed, 87 insertions(+), 33 deletions(-) create mode 100644 changelog.d/11546.misc diff --git a/changelog.d/11546.misc b/changelog.d/11546.misc new file mode 100644 index 000000000000..d451940bf216 --- /dev/null +++ b/changelog.d/11546.misc @@ -0,0 +1 @@ +Add missing type hints to storage classes. diff --git a/mypy.ini b/mypy.ini index 1caf807e8505..4b2ddafd2042 100644 --- a/mypy.ini +++ b/mypy.ini @@ -25,7 +25,6 @@ exclude = (?x) ^( |synapse/storage/databases/__init__.py |synapse/storage/databases/main/__init__.py - |synapse/storage/databases/main/account_data.py |synapse/storage/databases/main/cache.py |synapse/storage/databases/main/devices.py |synapse/storage/databases/main/e2e_room_keys.py @@ -181,6 +180,9 @@ disallow_untyped_defs = True [mypy-synapse.state.*] disallow_untyped_defs = True +[mypy-synapse.storage.databases.main.account_data] +disallow_untyped_defs = True + [mypy-synapse.storage.databases.main.client_ips] disallow_untyped_defs = True diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index f8bec266ac41..32a553fdd7bd 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -14,15 +14,25 @@ # limitations under the License. 
import logging -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, cast from synapse.api.constants import AccountDataTypes from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream -from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool +from synapse.storage._base import db_to_json +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.engines import PostgresEngine -from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator +from synapse.storage.util.id_generators import ( + AbstractStreamIdGenerator, + AbstractStreamIdTracker, + MultiWriterIdGenerator, + StreamIdGenerator, +) from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -34,13 +44,19 @@ logger = logging.getLogger(__name__) -class AccountDataWorkerStore(SQLBaseStore): - """This is an abstract base class where subclasses must implement - `get_max_account_data_stream_id` which can be called in the initializer. - """ +class AccountDataWorkerStore(CacheInvalidationWorkerStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): - self._instance_name = hs.get_instance_name() + # `_can_write_to_account_data` indicates whether the current worker is allowed + # to write account data. A value of `True` implies that `_account_data_id_gen` + # is an `AbstractStreamIdGenerator` and not just a tracker. + self._account_data_id_gen: AbstractStreamIdTracker if isinstance(database.engine, PostgresEngine): self._can_write_to_account_data = ( @@ -61,8 +77,6 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): writers=hs.config.worker.writers.account_data, ) else: - self._can_write_to_account_data = True - # We shouldn't be running in worker mode with SQLite, but it's useful # to support it for unit tests. # @@ -70,7 +84,8 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). - if hs.get_instance_name() in hs.config.worker.writers.account_data: + if self._instance_name in hs.config.worker.writers.account_data: + self._can_write_to_account_data = True self._account_data_id_gen = StreamIdGenerator( db_conn, "room_account_data", @@ -90,8 +105,6 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): "AccountDataAndTagsChangeCache", account_max ) - super().__init__(database, db_conn, hs) - def get_max_account_data_stream_id(self) -> int: """Get the current max stream ID for account data stream @@ -113,7 +126,9 @@ async def get_account_data_for_user( room_id string to per room account_data dicts.
""" - def get_account_data_for_user_txn(txn): + def get_account_data_for_user_txn( + txn: LoggingTransaction, + ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: rows = self.db_pool.simple_select_list_txn( txn, "account_data", @@ -132,7 +147,7 @@ def get_account_data_for_user_txn(txn): ["room_id", "account_data_type", "content"], ) - by_room = {} + by_room: Dict[str, Dict[str, JsonDict]] = {} for row in rows: room_data = by_room.setdefault(row["room_id"], {}) room_data[row["account_data_type"]] = db_to_json(row["content"]) @@ -177,7 +192,9 @@ async def get_account_data_for_room( A dict of the room account_data """ - def get_account_data_for_room_txn(txn): + def get_account_data_for_room_txn( + txn: LoggingTransaction, + ) -> Dict[str, JsonDict]: rows = self.db_pool.simple_select_list_txn( txn, "room_account_data", @@ -207,7 +224,9 @@ async def get_account_data_for_room_and_type( The room account_data for that type, or None if there isn't any set. """ - def get_account_data_for_room_and_type_txn(txn): + def get_account_data_for_room_and_type_txn( + txn: LoggingTransaction, + ) -> Optional[JsonDict]: content_json = self.db_pool.simple_select_one_onecol_txn( txn, table="room_account_data", @@ -243,14 +262,16 @@ async def get_updated_global_account_data( if last_id == current_id: return [] - def get_updated_global_account_data_txn(txn): + def get_updated_global_account_data_txn( + txn: LoggingTransaction, + ) -> List[Tuple[int, str, str]]: sql = ( "SELECT stream_id, user_id, account_data_type" " FROM account_data WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC LIMIT ?" ) txn.execute(sql, (last_id, current_id, limit)) - return txn.fetchall() + return cast(List[Tuple[int, str, str]], txn.fetchall()) return await self.db_pool.runInteraction( "get_updated_global_account_data", get_updated_global_account_data_txn @@ -273,14 +294,16 @@ async def get_updated_room_account_data( if last_id == current_id: return [] - def get_updated_room_account_data_txn(txn): + def get_updated_room_account_data_txn( + txn: LoggingTransaction, + ) -> List[Tuple[int, str, str, str]]: sql = ( "SELECT stream_id, user_id, room_id, account_data_type" " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC LIMIT ?" ) txn.execute(sql, (last_id, current_id, limit)) - return txn.fetchall() + return cast(List[Tuple[int, str, str, str]], txn.fetchall()) return await self.db_pool.runInteraction( "get_updated_room_account_data", get_updated_room_account_data_txn @@ -299,7 +322,9 @@ async def get_updated_account_data_for_user( mapping from room_id string to per room account_data dicts. """ - def get_updated_account_data_for_user_txn(txn): + def get_updated_account_data_for_user_txn( + txn: LoggingTransaction, + ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: sql = ( "SELECT account_data_type, content FROM account_data" " WHERE user_id = ? AND stream_id > ?" 
@@ -316,7 +341,7 @@ def get_updated_account_data_for_user_txn(txn): txn.execute(sql, (user_id, stream_id)) - account_data_by_room = {} + account_data_by_room: Dict[str, Dict[str, JsonDict]] = {} for row in txn: room_account_data = account_data_by_room.setdefault(row[0], {}) room_account_data[row[1]] = db_to_json(row[2]) @@ -353,12 +378,15 @@ async def ignored_by(self, user_id: str) -> Set[str]: ) ) - def process_replication_rows(self, stream_name, instance_name, token, rows): + def process_replication_rows( + self, + stream_name: str, + instance_name: str, + token: int, + rows: Iterable[Any], + ) -> None: if stream_name == TagAccountDataStream.NAME: self._account_data_id_gen.advance(instance_name, token) - for row in rows: - self.get_tags_for_user.invalidate((row.user_id,)) - self._account_data_stream_cache.entity_has_changed(row.user_id, token) elif stream_name == AccountDataStream.NAME: self._account_data_id_gen.advance(instance_name, token) for row in rows: @@ -372,7 +400,8 @@ def process_replication_rows(self, stream_name, instance_name, token, rows): (row.user_id, row.room_id, row.data_type) ) self._account_data_stream_cache.entity_has_changed(row.user_id, token) - return super().process_replication_rows(stream_name, instance_name, token, rows) + + super().process_replication_rows(stream_name, instance_name, token, rows) async def add_account_data_to_room( self, user_id: str, room_id: str, account_data_type: str, content: JsonDict @@ -389,6 +418,7 @@ async def add_account_data_to_room( The maximum stream ID. """ assert self._can_write_to_account_data + assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) content_json = json_encoder.encode(content) @@ -431,6 +461,7 @@ async def add_account_data_for_user( The maximum stream ID. """ assert self._can_write_to_account_data + assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) async with self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction( @@ -452,7 +483,7 @@ async def add_account_data_for_user( def _add_account_data_for_user( self, - txn, + txn: LoggingTransaction, next_id: int, user_id: str, account_data_type: str, diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index 8f510de53d43..c8e508a910fb 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -15,11 +15,13 @@ # limitations under the License. import logging -from typing import Dict, List, Tuple, cast +from typing import Any, Dict, Iterable, List, Tuple, cast +from synapse.replication.tcp.streams import TagAccountDataStream from synapse.storage._base import db_to_json from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main.account_data import AccountDataWorkerStore +from synapse.storage.util.id_generators import AbstractStreamIdGenerator from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -204,6 +206,7 @@ async def add_tag_to_room( The next account data ID. """ assert self._can_write_to_account_data + assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) content_json = json_encoder.encode(content) @@ -230,6 +233,7 @@ async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> in The next account data ID. 
""" assert self._can_write_to_account_data + assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) def remove_tag_txn(txn: LoggingTransaction, next_id: int) -> None: sql = ( @@ -258,6 +262,7 @@ def _update_revision_txn( next_id: The the revision to advance to. """ assert self._can_write_to_account_data + assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) txn.call_after( self._account_data_stream_cache.entity_has_changed, user_id, next_id @@ -287,6 +292,21 @@ def _update_revision_txn( # than the id that the client has. pass + def process_replication_rows( + self, + stream_name: str, + instance_name: str, + token: int, + rows: Iterable[Any], + ) -> None: + if stream_name == TagAccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) + for row in rows: + self.get_tags_for_user.invalidate((row.user_id,)) + self._account_data_stream_cache.entity_has_changed(row.user_id, token) + + super().process_replication_rows(stream_name, instance_name, token, rows) + class TagsStore(TagsWorkerStore): pass From 1abfb15f07d4f8119afcf908f9e1903e7feef371 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Mon, 13 Dec 2021 16:28:26 +0000 Subject: [PATCH 023/157] Add type hints to `synapse/storage/databases/main/end_to_end_keys.py` (#11551) --- changelog.d/11551.misc | 1 + mypy.ini | 4 +- synapse/storage/databases/main/__init__.py | 3 - .../storage/databases/main/end_to_end_keys.py | 211 ++++++++++++------ 4 files changed, 150 insertions(+), 69 deletions(-) create mode 100644 changelog.d/11551.misc diff --git a/changelog.d/11551.misc b/changelog.d/11551.misc new file mode 100644 index 000000000000..d451940bf216 --- /dev/null +++ b/changelog.d/11551.misc @@ -0,0 +1 @@ +Add missing type hints to storage classes. 
diff --git a/mypy.ini b/mypy.ini index 4b2ddafd2042..a7b1f4eb64cb 100644 --- a/mypy.ini +++ b/mypy.ini @@ -28,7 +28,6 @@ exclude = (?x) |synapse/storage/databases/main/cache.py |synapse/storage/databases/main/devices.py |synapse/storage/databases/main/e2e_room_keys.py - |synapse/storage/databases/main/end_to_end_keys.py |synapse/storage/databases/main/event_federation.py |synapse/storage/databases/main/event_push_actions.py |synapse/storage/databases/main/events_bg_updates.py @@ -189,6 +188,9 @@ disallow_untyped_defs = True [mypy-synapse.storage.databases.main.directory] disallow_untyped_defs = True +[mypy-synapse.storage.databases.main.end_to_end_keys] +disallow_untyped_defs = True + [mypy-synapse.storage.databases.main.events_worker] disallow_untyped_defs = True diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 9ff2d8d8c35a..065145c0d280 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -143,9 +143,6 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): ("device_lists_outbound_pokes", "stream_id"), ], ) - self._cross_signing_id_gen = StreamIdGenerator( - db_conn, "e2e_cross_signing_keys", "stream_id" - ) self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index b06c1dc45b2d..57b5ffbad32b 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -14,19 +14,32 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import abc -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + Iterable, + List, + Optional, + Tuple, + cast, +) import attr from canonicaljson import encode_canonical_json -from twisted.enterprise.adbapi import Connection - from synapse.api.constants import DeviceKeyAlgorithms from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool, make_in_list_sql_clause +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, + make_in_list_sql_clause, +) +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.engines import PostgresEngine -from synapse.storage.types import Cursor +from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList @@ -50,7 +63,12 @@ class DeviceKeyLookupResult: class EndToEndKeyBackgroundStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( @@ -62,8 +80,13 @@ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer" ) -class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): - def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): +class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorkerStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self._allow_device_name_lookup_over_federation = ( @@ -124,7 +147,7 @@ async def get_e2e_device_keys_for_cs_api( # Build the result structure, un-jsonify the results, and add the # "unsigned" section - rv = {} + rv: Dict[str, Dict[str, JsonDict]] = {} for user_id, device_keys in results.items(): rv[user_id] = {} for device_id, device_info in device_keys.items(): @@ -195,6 +218,10 @@ async def get_e2e_device_keys_and_signatures( # add each cross-signing signature to the correct device in the result dict. for (user_id, key_id, device_id, signature) in cross_sigs_result: target_device_result = result[user_id][device_id] + # We've only looked up cross-signatures for non-deleted devices with key + # data. 
+ assert target_device_result is not None + assert target_device_result.keys is not None target_device_signatures = target_device_result.keys.setdefault( "signatures", {} ) @@ -207,7 +234,11 @@ async def get_e2e_device_keys_and_signatures( return result def _get_e2e_device_keys_txn( - self, txn, query_list, include_all_devices=False, include_deleted_devices=False + self, + txn: LoggingTransaction, + query_list: Collection[Tuple[str, str]], + include_all_devices: bool = False, + include_deleted_devices: bool = False, ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: """Get information on devices from the database @@ -263,7 +294,7 @@ def _get_e2e_device_keys_txn( return result def _get_e2e_cross_signing_signatures_for_devices_txn( - self, txn: Cursor, device_query: Iterable[Tuple[str, str]] + self, txn: LoggingTransaction, device_query: Iterable[Tuple[str, str]] ) -> List[Tuple[str, str, str, str]]: """Get cross-signing signatures for a given list of devices @@ -289,7 +320,17 @@ def _get_e2e_cross_signing_signatures_for_devices_txn( ) txn.execute(signature_sql, signature_query_params) - return txn.fetchall() + return cast( + List[ + Tuple[ + str, + str, + str, + str, + ] + ], + txn.fetchall(), + ) async def get_e2e_one_time_keys( self, user_id: str, device_id: str, key_ids: List[str] @@ -335,7 +376,7 @@ async def add_e2e_one_time_keys( new_keys: keys to add - each a tuple of (algorithm, key_id, key json) """ - def _add_e2e_one_time_keys(txn): + def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None: set_tag("user_id", user_id) set_tag("device_id", device_id) set_tag("new_keys", new_keys) @@ -375,7 +416,7 @@ async def count_e2e_one_time_keys( A mapping from algorithm to number of keys for that algorithm. """ - def _count_e2e_one_time_keys(txn): + def _count_e2e_one_time_keys(txn: LoggingTransaction) -> Dict[str, int]: sql = ( "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json" " WHERE user_id = ? AND device_id = ?" @@ -421,7 +462,11 @@ async def set_e2e_fallback_keys( ) def _set_e2e_fallback_keys_txn( - self, txn: Connection, user_id: str, device_id: str, fallback_keys: JsonDict + self, + txn: LoggingTransaction, + user_id: str, + device_id: str, + fallback_keys: JsonDict, ) -> None: # fallback_keys will usually only have one item in it, so using a for # loop (as opposed to calling simple_upsert_many_txn) won't be too bad @@ -483,7 +528,7 @@ async def get_e2e_unused_fallback_key_types( async def get_e2e_cross_signing_key( self, user_id: str, key_type: str, from_user_id: Optional[str] = None - ) -> Optional[dict]: + ) -> Optional[JsonDict]: """Returns a user's cross-signing key. Args: @@ -504,7 +549,7 @@ async def get_e2e_cross_signing_key( return user_keys.get(key_type) @cached(num_args=1) - def _get_bare_e2e_cross_signing_keys(self, user_id): + def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Dict[str, JsonDict]: """Dummy function. Only used to make a cache for _get_bare_e2e_cross_signing_keys_bulk. """ @@ -517,7 +562,7 @@ def _get_bare_e2e_cross_signing_keys(self, user_id): ) async def _get_bare_e2e_cross_signing_keys_bulk( self, user_ids: Iterable[str] - ) -> Dict[str, Dict[str, dict]]: + ) -> Dict[str, Optional[Dict[str, JsonDict]]]: """Returns the cross-signing keys for a set of users. The output of this function should be passed to _get_e2e_cross_signing_signatures_txn if the signatures for the calling user need to be fetched. @@ -531,32 +576,35 @@ async def _get_bare_e2e_cross_signing_keys_bulk( their user ID will map to None. 
""" - return await self.db_pool.runInteraction( + result = await self.db_pool.runInteraction( "get_bare_e2e_cross_signing_keys_bulk", self._get_bare_e2e_cross_signing_keys_bulk_txn, user_ids, ) + # The `Optional` comes from the `@cachedList` decorator. + return cast(Dict[str, Optional[Dict[str, JsonDict]]], result) + def _get_bare_e2e_cross_signing_keys_bulk_txn( self, - txn: Connection, + txn: LoggingTransaction, user_ids: Iterable[str], - ) -> Dict[str, Dict[str, dict]]: + ) -> Dict[str, Dict[str, JsonDict]]: """Returns the cross-signing keys for a set of users. The output of this function should be passed to _get_e2e_cross_signing_signatures_txn if the signatures for the calling user need to be fetched. Args: - txn (twisted.enterprise.adbapi.Connection): db connection - user_ids (list[str]): the users whose keys are being requested + txn: db connection + user_ids: the users whose keys are being requested Returns: - dict[str, dict[str, dict]]: mapping from user ID to key type to key - data. If a user's cross-signing keys were not found, their user - ID will not be in the dict. + Mapping from user ID to key type to key data. + If a user's cross-signing keys were not found, their user ID will not be in + the dict. """ - result = {} + result: Dict[str, Dict[str, JsonDict]] = {} for user_chunk in batch_iter(user_ids, 100): clause, params = make_in_list_sql_clause( @@ -596,43 +644,48 @@ def _get_bare_e2e_cross_signing_keys_bulk_txn( user_id = row["user_id"] key_type = row["keytype"] key = db_to_json(row["keydata"]) - user_info = result.setdefault(user_id, {}) - user_info[key_type] = key + user_keys = result.setdefault(user_id, {}) + user_keys[key_type] = key return result def _get_e2e_cross_signing_signatures_txn( self, - txn: Connection, - keys: Dict[str, Dict[str, dict]], + txn: LoggingTransaction, + keys: Dict[str, Optional[Dict[str, JsonDict]]], from_user_id: str, - ) -> Dict[str, Dict[str, dict]]: + ) -> Dict[str, Optional[Dict[str, JsonDict]]]: """Returns the cross-signing signatures made by a user on a set of keys. Args: - txn (twisted.enterprise.adbapi.Connection): db connection - keys (dict[str, dict[str, dict]]): a map of user ID to key type to - key data. This dict will be modified to add signatures. - from_user_id (str): fetch the signatures made by this user + txn: db connection + keys: a map of user ID to key type to key data. + This dict will be modified to add signatures. + from_user_id: fetch the signatures made by this user Returns: - dict[str, dict[str, dict]]: mapping from user ID to key type to key - data. The return value will be the same as the keys argument, - with the modifications included. + Mapping from user ID to key type to key data. + The return value will be the same as the keys argument, with the + modifications included. """ # find out what cross-signing keys (a.k.a. devices) we need to get # signatures for. This is a map of (user_id, device_id) to key type # (device_id is the key's public part). 
- devices = {} + devices: Dict[Tuple[str, str], str] = {} - for user_id, user_info in keys.items(): - if user_info is None: + for user_id, user_keys in keys.items(): + if user_keys is None: continue - for key_type, key in user_info.items(): + for key_type, key in user_keys.items(): device_id = None for k in key["keys"].values(): device_id = k + # `key` ought to be a `CrossSigningKey`, whose .keys property is a + # dictionary with a single entry: + # "algorithm:base64_public_key": "base64_public_key" + # See https://spec.matrix.org/v1.1/client-server-api/#cross-signing + assert isinstance(device_id, str) devices[(user_id, device_id)] = key_type for batch in batch_iter(devices.keys(), size=100): @@ -656,15 +709,20 @@ def _get_e2e_cross_signing_signatures_txn( # and add the signatures to the appropriate keys for row in rows: - key_id = row["key_id"] - target_user_id = row["target_user_id"] - target_device_id = row["target_device_id"] + key_id: str = row["key_id"] + target_user_id: str = row["target_user_id"] + target_device_id: str = row["target_device_id"] key_type = devices[(target_user_id, target_device_id)] # We need to copy everything, because the result may have come # from the cache. dict.copy only does a shallow copy, so we # need to recursively copy the dicts that will be modified. - user_info = keys[target_user_id] = keys[target_user_id].copy() - target_user_key = user_info[key_type] = user_info[key_type].copy() + user_keys = keys[target_user_id] + # `user_keys` cannot be `None` because we only fetched signatures for + # users with keys + assert user_keys is not None + user_keys = keys[target_user_id] = user_keys.copy() + + target_user_key = user_keys[key_type] = user_keys[key_type].copy() if "signatures" in target_user_key: signatures = target_user_key["signatures"] = target_user_key[ "signatures" @@ -683,7 +741,7 @@ def _get_e2e_cross_signing_signatures_txn( async def get_e2e_cross_signing_keys_bulk( self, user_ids: List[str], from_user_id: Optional[str] = None - ) -> Dict[str, Optional[Dict[str, dict]]]: + ) -> Dict[str, Optional[Dict[str, JsonDict]]]: """Returns the cross-signing keys for a set of users. Args: @@ -741,7 +799,9 @@ async def get_all_user_signature_changes_for_remotes( if last_id == current_id: return [], current_id, False - def _get_all_user_signature_changes_for_remotes_txn(txn): + def _get_all_user_signature_changes_for_remotes_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: sql = """ SELECT stream_id, from_user_id AS user_id FROM user_signature_stream @@ -785,7 +845,7 @@ async def claim_e2e_one_time_keys( @trace def _claim_e2e_one_time_key_simple( - txn, user_id: str, device_id: str, algorithm: str + txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str ) -> Optional[Tuple[str, str]]: """Claim OTK for device for DBs that don't support RETURNING. @@ -825,7 +885,7 @@ def _claim_e2e_one_time_key_simple( @trace def _claim_e2e_one_time_key_returning( - txn, user_id: str, device_id: str, algorithm: str + txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str ) -> Optional[Tuple[str, str]]: """Claim OTK for device for DBs that support RETURNING. 
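For context on the two claim helpers: on engines with `DELETE ... RETURNING` (PostgreSQL, and SQLite 3.35+) the select-then-delete dance collapses into a single atomic statement. Roughly, with the table and column names from the hunk above (a sketch, not the exact Synapse query):

    import sqlite3
    from typing import Optional, Tuple

    def claim_one_time_key(
        txn: sqlite3.Cursor, user_id: str, device_id: str, algorithm: str
    ) -> Optional[Tuple[str, str]]:
        # Atomically pop one matching key; RETURNING hands back the deleted
        # row, so there is no separate SELECT and no race between claimants.
        txn.execute(
            "DELETE FROM e2e_one_time_keys_json"
            " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
            " AND key_id IN ("
            "  SELECT key_id FROM e2e_one_time_keys_json"
            "  WHERE user_id = ? AND device_id = ? AND algorithm = ?"
            "  LIMIT 1"
            " )"
            " RETURNING key_id, key_json",
            (user_id, device_id, algorithm) * 2,
        )
        row = txn.fetchone()
        if row is None:
            return None
        key_id, key_json = row
        return f"{algorithm}:{key_id}", key_json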
@@ -860,7 +920,7 @@ def _claim_e2e_one_time_key_returning( key_id, key_json = otk_row return f"{algorithm}:{key_id}", key_json - results = {} + results: Dict[str, Dict[str, Dict[str, str]]] = {} for user_id, device_id, algorithm in query_list: if self.database_engine.supports_returning: # If we support RETURNING clause we can use a single query that @@ -930,6 +990,18 @@ def _claim_e2e_one_time_key_returning( class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + self._cross_signing_id_gen = StreamIdGenerator( + db_conn, "e2e_cross_signing_keys", "stream_id" + ) + async def set_e2e_device_keys( self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict ) -> bool: @@ -937,7 +1009,7 @@ async def set_e2e_device_keys( or the keys were already in the database. """ - def _set_e2e_device_keys_txn(txn): + def _set_e2e_device_keys_txn(txn: LoggingTransaction) -> bool: set_tag("user_id", user_id) set_tag("device_id", device_id) set_tag("time_now", time_now) @@ -973,7 +1045,7 @@ def _set_e2e_device_keys_txn(txn): ) async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None: - def delete_e2e_keys_by_device_txn(txn): + def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None: log_kv( { "message": "Deleting keys for device", @@ -1012,17 +1084,24 @@ def delete_e2e_keys_by_device_txn(txn): "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn ) - def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id): + def _set_e2e_cross_signing_key_txn( + self, + txn: LoggingTransaction, + user_id: str, + key_type: str, + key: JsonDict, + stream_id: int, + ) -> None: """Set a user's cross-signing key. Args: - txn (twisted.enterprise.adbapi.Connection): db connection - user_id (str): the user to set the signing key for - key_type (str): the type of key that is being set: either 'master' + txn: db connection + user_id: the user to set the signing key for + key_type: the type of key that is being set: either 'master' for a master key, 'self_signing' for a self-signing key, or 'user_signing' for a user-signing key - key (dict): the key data - stream_id (int) + key: the key data + stream_id """ # the 'key' dict will look something like: # { @@ -1075,13 +1154,15 @@ def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id) txn, self._get_bare_e2e_cross_signing_keys, (user_id,) ) - async def set_e2e_cross_signing_key(self, user_id, key_type, key): + async def set_e2e_cross_signing_key( + self, user_id: str, key_type: str, key: JsonDict + ) -> None: """Set a user's cross-signing key. 
Args: - user_id (str): the user to set the user-signing key for - key_type (str): the type of cross-signing key to set - key (dict): the key data + user_id: the user to set the user-signing key for + key_type: the type of cross-signing key to set + key: the key data """ async with self._cross_signing_id_gen.get_next() as stream_id: From 5305a5e88144828419249fd9e4c5198d92276a44 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Mon, 13 Dec 2021 17:05:00 +0000 Subject: [PATCH 024/157] Type hint the constructors of the data store classes (#11555) --- changelog.d/11555.misc | 1 + synapse/replication/slave/storage/_base.py | 9 +++++-- .../replication/slave/storage/client_ips.py | 9 +++++-- synapse/replication/slave/storage/devices.py | 9 +++++-- synapse/replication/slave/storage/events.py | 9 +++++-- .../replication/slave/storage/filtering.py | 9 +++++-- synapse/replication/slave/storage/groups.py | 9 +++++-- synapse/storage/_base.py | 13 +++++---- synapse/storage/database.py | 2 +- synapse/storage/databases/main/__init__.py | 9 +++++-- synapse/storage/databases/main/appservice.py | 10 ++++--- synapse/storage/databases/main/cache.py | 9 +++++-- .../storage/databases/main/censor_events.py | 13 +++++++-- synapse/storage/databases/main/client_ips.py | 22 ++++++++++++--- synapse/storage/databases/main/deviceinbox.py | 7 ++++- synapse/storage/databases/main/devices.py | 22 ++++++++++++--- .../databases/main/event_federation.py | 20 +++++++++++--- .../databases/main/event_push_actions.py | 20 +++++++++++--- synapse/storage/databases/main/events.py | 9 ++++--- .../databases/main/events_bg_updates.py | 8 +++++- .../storage/databases/main/group_server.py | 10 ++++--- synapse/storage/databases/main/lock.py | 14 +++++++--- synapse/storage/databases/main/metrics.py | 9 +++++-- .../databases/main/monthly_active_users.py | 20 +++++++++++--- synapse/storage/databases/main/presence.py | 6 ++--- synapse/storage/databases/main/push_rule.py | 9 +++++-- synapse/storage/databases/main/receipts.py | 13 +++++++-- synapse/storage/databases/main/room.py | 27 ++++++++++++++++--- synapse/storage/databases/main/roommember.py | 23 +++++++++++++--- synapse/storage/databases/main/search.py | 20 +++++++++++--- synapse/storage/databases/main/state.py | 27 ++++++++++++++++--- synapse/storage/databases/main/stats.py | 9 +++++-- synapse/storage/databases/main/stream.py | 8 +++++- .../storage/databases/main/transactions.py | 13 +++++++-- .../storage/databases/main/user_directory.py | 11 +++++--- 35 files changed, 351 insertions(+), 87 deletions(-) create mode 100644 changelog.d/11555.misc diff --git a/changelog.d/11555.misc b/changelog.d/11555.misc new file mode 100644 index 000000000000..d451940bf216 --- /dev/null +++ b/changelog.d/11555.misc @@ -0,0 +1 @@ +Add missing type hints to storage classes. 
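This patch is one mechanical change applied across every data store and slaved store: the constructor signature is spelled out in full, so `db_conn` is statically known to be the logging wrapper rather than a bare `Connection` (note the matching `__enter__` fix in `database.py` below, which stops the wrapper type from decaying inside `with` blocks). The resulting shape, as a sketch assuming a Synapse checkout on the path (`ExampleWorkerStore` is hypothetical):

    from typing import TYPE_CHECKING

    from synapse.storage._base import SQLBaseStore
    from synapse.storage.database import DatabasePool, LoggingDatabaseConnection

    if TYPE_CHECKING:
        from synapse.server import HomeServer

    class ExampleWorkerStore(SQLBaseStore):
        # Hypothetical store showing the uniform, fully annotated constructor
        # that this patch rolls out across synapse/storage and
        # synapse/replication/slave/storage.
        def __init__(
            self,
            database: DatabasePool,
            db_conn: LoggingDatabaseConnection,
            hs: "HomeServer",
        ):
            super().__init__(database, db_conn, hs)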
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 7ecb446e7c78..7644146dbadb 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -15,7 +15,7 @@ import logging from typing import TYPE_CHECKING, Optional -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator @@ -27,7 +27,12 @@ class BaseSlavedStore(CacheInvalidationWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) if isinstance(self.database_engine, PostgresEngine): self._cache_id_gen: Optional[ diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index 61cd7e522800..bc888ce1a871 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -14,7 +14,7 @@ from typing import TYPE_CHECKING -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY from synapse.util.caches.lrucache import LruCache @@ -25,7 +25,12 @@ class SlavedClientIpStore(BaseSlavedStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.client_ip_last_seen: LruCache[tuple, int] = LruCache( diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 0a582960896d..a2aff75b7075 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -17,7 +17,7 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.devices import DeviceWorkerStore from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -27,7 +27,12 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.hs = hs diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 63ed50caa5eb..50e7379e8301 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -15,7 +15,7 @@ import logging from typing import TYPE_CHECKING -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, 
LoggingDatabaseConnection from synapse.storage.databases.main.event_federation import EventFederationWorkerStore from synapse.storage.databases.main.event_push_actions import ( EventPushActionsWorkerStore, @@ -58,7 +58,12 @@ class SlavedEventStore( RelationsWorkerStore, BaseSlavedStore, ): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) events_max = self._stream_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py index 90284c202d55..4d185e2b56c7 100644 --- a/synapse/replication/slave/storage/filtering.py +++ b/synapse/replication/slave/storage/filtering.py @@ -14,7 +14,7 @@ from typing import TYPE_CHECKING -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.filtering import FilteringStore from ._base import BaseSlavedStore @@ -24,7 +24,12 @@ class SlavedFilteringStore(BaseSlavedStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) # Filters are immutable so this cache doesn't need to be expired diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py index 497e16c69e6a..9d90e26375f0 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py @@ -17,7 +17,7 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import GroupServerStream -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.group_server import GroupServerWorkerStore from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -26,7 +26,12 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.hs = hs diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 3056e64ff570..7967011afdc0 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -17,10 +17,8 @@ from abc import ABCMeta from typing import TYPE_CHECKING, Any, Collection, Iterable, Optional, Union -from synapse.storage.database import LoggingTransaction # noqa: F401 -from synapse.storage.database import make_in_list_sql_clause # noqa: F401 -from synapse.storage.database import DatabasePool -from synapse.storage.types import Connection +from synapse.storage.database import make_in_list_sql_clause # noqa: F401; noqa: F401 +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.types import get_domain_from_id from synapse.util import json_decoder @@ -38,7 +36,12 @@ class SQLBaseStore(metaclass=ABCMeta): per data store (and not one per physical database). 
""" - def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): self.hs = hs self._clock = hs.get_clock() self.database_engine = database.engine diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 5552dd3c5c35..3b44e6469cd8 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -175,7 +175,7 @@ def commit(self) -> None: def rollback(self) -> None: self.conn.rollback() - def __enter__(self) -> "Connection": + def __enter__(self) -> "LoggingDatabaseConnection": self.conn.__enter__() return self diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 065145c0d280..716b25dd349a 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple from synapse.config.homeserver import HomeServerConfig -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.stats import UserSortOrder from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import ( @@ -129,7 +129,12 @@ class DataStore( LockStore, SessionStore, ): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): self.hs = hs self._clock = hs.get_clock() self.database_engine = database.engine diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 4a883dc16647..92c95a41d793 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -24,9 +24,8 @@ from synapse.config.appservice import load_appservices from synapse.events import EventBase from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.events_worker import EventsWorkerStore -from synapse.storage.types import Connection from synapse.types import JsonDict from synapse.util import json_encoder @@ -58,7 +57,12 @@ def _make_exclusive_regex( class ApplicationServiceWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): self.services_cache = load_appservices( hs.hostname, hs.config.appservice.app_service_config_files ) diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 36e8422fc63b..0024348067d5 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -25,7 +25,7 @@ EventsStreamEventRow, ) from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.engines import PostgresEngine from synapse.util.iterutils import batch_iter @@ -41,7 +41,12 @@ class CacheInvalidationWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + 
db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self._instance_name = hs.get_instance_name() diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py index 0f56e10220d0..fd3fc298b37a 100644 --- a/synapse/storage/databases/main/censor_events.py +++ b/synapse/storage/databases/main/censor_events.py @@ -18,7 +18,11 @@ from synapse.events.utils import prune_event_dict from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.util import json_encoder @@ -31,7 +35,12 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) if ( diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index 1dc7f0ebe346..8b0c614ecef7 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -26,7 +26,6 @@ make_tuple_comparison_clause, ) from synapse.storage.databases.main.monthly_active_users import MonthlyActiveUsersStore -from synapse.storage.types import Connection from synapse.types import JsonDict, UserID from synapse.util.caches.lrucache import LruCache @@ -65,7 +64,12 @@ class LastConnectionInfo(TypedDict): class ClientIpBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( @@ -394,7 +398,12 @@ def _devices_last_seen_update_txn(txn: LoggingTransaction) -> int: class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): - def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.user_ips_max_age = hs.config.server.user_ips_max_age @@ -532,7 +541,12 @@ def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]: class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore): - def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): # (user_id, access_token, ip,) -> last_seen self.client_ip_last_seen = LruCache[Tuple[str, str, str], int]( diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index ab8766c75b62..b410eefdc71f 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -601,7 +601,12 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): REMOVE_HIDDEN_DEVICES = "remove_hidden_devices_from_device_inbox" 
REMOVE_DEAD_DEVICES_FROM_INBOX = "remove_dead_devices_from_device_inbox" - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index eff825dd2254..393259998857 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -38,6 +38,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, + LoggingDatabaseConnection, LoggingTransaction, make_tuple_comparison_clause, ) @@ -61,7 +62,12 @@ class DeviceWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) if hs.config.worker.run_background_tasks: @@ -953,7 +959,12 @@ def _prune_txn(txn): class DeviceBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( @@ -1085,7 +1096,12 @@ def _txn(txn): class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) # Map of (user_id, device_id) -> bool. 
If there is an entry that implies diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 9580a4078538..2287f1cc6877 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -24,7 +24,11 @@ from synapse.events import EventBase, make_event_from_dict from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.signatures import SignatureWorkerStore from synapse.storage.engines import PostgresEngine @@ -62,7 +66,12 @@ def __init__(self, room_id: str): class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) if hs.config.worker.run_background_tasks: @@ -1514,7 +1523,12 @@ class EventFederationStore(EventFederationWorkerStore): EVENT_AUTH_STATE_ONLY = "event_auth_state_only" - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_update_handler( diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 3efdd0c920f6..eacff3e432d8 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -20,7 +20,11 @@ from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -82,7 +86,12 @@ def _deserialize_action(actions, is_highlight): class EventPushActionsWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) # These get correctly set by _find_stream_orderings_for_times_txn @@ -910,7 +919,12 @@ def _remove_old_push_actions_before_txn( class EventPushActionsStore(EventPushActionsWorkerStore): EPA_HIGHLIGHT_INDEX = "epa_highlight_index" - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 5184e6bf85f8..81e67ece5535 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -41,10 
+41,13 @@ from synapse.events.snapshot import EventContext # noqa: F401 from synapse.logging.utils import log_function from synapse.storage._base import db_to_json, make_in_list_sql_clause -from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.databases.main.events_worker import EventCacheEntry from synapse.storage.databases.main.search import SearchEntry -from synapse.storage.types import Connection from synapse.storage.util.id_generators import AbstractStreamIdGenerator from synapse.storage.util.sequence import SequenceGenerator from synapse.types import StateMap, get_domain_from_id @@ -95,7 +98,7 @@ def __init__( hs: "HomeServer", db: DatabasePool, main_data_store: "DataStore", - db_conn: Connection, + db_conn: LoggingDatabaseConnection, ): self.hs = hs self.db_pool = db diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index c88fd35e7f3a..9b36941fecb8 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -23,6 +23,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, + LoggingDatabaseConnection, LoggingTransaction, make_tuple_comparison_clause, ) @@ -83,7 +84,12 @@ class _CalculateChainCover: class EventsBackgroundUpdatesStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_update_handler( diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index bb621df0ddb6..3f6086050bb2 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -19,8 +19,7 @@ from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool -from synapse.storage.types import Connection +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.types import JsonDict from synapse.util import json_encoder @@ -40,7 +39,12 @@ class _RoomInGroup(TypedDict): class GroupServerWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): database.updates.register_background_index_update( update_name="local_group_updates_index", index_name="local_group_updates_stream_id_index", diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py index a540f7fb2681..bedacaf0d745 100644 --- a/synapse/storage/databases/main/lock.py +++ b/synapse/storage/databases/main/lock.py @@ -20,8 +20,11 @@ from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool, LoggingTransaction -from synapse.storage.types import Connection +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.util import Clock from synapse.util.stringutils import 
random_string @@ -54,7 +57,12 @@ class LockStore(SQLBaseStore): `last_renewed_ts` column with the current time. """ - def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self._reactor = hs.get_reactor() diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py index d901933ae4f2..3bb21958d1e9 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py @@ -19,7 +19,7 @@ from synapse.metrics import GaugeBucketCollector from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.event_push_actions import ( EventPushActionsWorkerStore, ) @@ -55,7 +55,12 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): stats and prometheus metrics. """ - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) # Read the extrems every 60 minutes diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index 3c98ef876f8c..65b7e307e146 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -16,7 +16,11 @@ from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool, make_in_list_sql_clause +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + make_in_list_sql_clause, +) from synapse.util.caches.descriptors import cached from synapse.util.threepids import canonicalise_email @@ -31,7 +35,12 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self._clock = hs.get_clock() self.hs = hs @@ -213,7 +222,12 @@ def _reap_users(txn, reserved_users): class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self._mau_stats_only = hs.config.server.mau_stats_only diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index cc0eebdb4606..02d534ae4523 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -17,7 +17,7 @@ from synapse.api.presence import PresenceState, UserPresenceState from synapse.replication.tcp.streams import PresenceStream from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.engines import PostgresEngine from 
synapse.storage.types import Connection from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator @@ -33,7 +33,7 @@ class PresenceBackgroundUpdateStore(SQLBaseStore): def __init__( self, database: DatabasePool, - db_conn: Connection, + db_conn: LoggingDatabaseConnection, hs: "HomeServer", ): super().__init__(database, db_conn, hs) @@ -52,7 +52,7 @@ class PresenceStore(PresenceBackgroundUpdateStore): def __init__( self, database: DatabasePool, - db_conn: Connection, + db_conn: LoggingDatabaseConnection, hs: "HomeServer", ): super().__init__(database, db_conn, hs) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 3b63267395c3..e01c94930aed 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -20,7 +20,7 @@ from synapse.push.baserules import list_with_base_rules from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.pusher import PusherWorkerStore @@ -81,7 +81,12 @@ class PushRulesWorkerStore( `get_max_push_rules_stream_id` which can be called in the initializer. """ - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) if hs.config.worker.worker_app is None: diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 9c5625c8bbb8..bf0b903af2fc 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -32,7 +32,11 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import ReceiptsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.types import JsonDict @@ -47,7 +51,12 @@ class ReceiptsWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): self._instance_name = hs.get_instance_name() if isinstance(database.engine, PostgresEngine): diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 7d694d852d53..28c4b65bbd4c 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -24,7 +24,11 @@ from synapse.api.room_versions import RoomVersion, RoomVersions from synapse.events import EventBase from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + 
LoggingTransaction, +) from synapse.storage.databases.main.search import SearchStore from synapse.storage.types import Cursor from synapse.types import JsonDict, ThirdPartyInstanceID @@ -72,7 +76,12 @@ class RoomSortOrder(Enum): class RoomWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.config = hs.config @@ -1050,7 +1059,12 @@ class _BackgroundUpdates: class RoomBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.config = hs.config @@ -1435,7 +1449,12 @@ def _background_populate_rooms_creator_column_txn(txn: LoggingTransaction): class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.config = hs.config diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 6b2a8d06a67c..cda80d651160 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -37,7 +37,7 @@ wrap_as_background_process, ) from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.engines import Sqlite3Engine from synapse.storage.roommember import ( @@ -64,7 +64,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) # Used by `_get_joined_hosts` to ensure only one thing mutates the cache @@ -985,7 +990,12 @@ def _is_local_host_in_room_ignoring_users_txn(txn): class RoomMemberBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_update_handler( _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile @@ -1135,7 +1145,12 @@ def _background_current_state_membership_txn(txn, last_processed_room): class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) async def forget(self, user_id: str, room_id: str) -> None: diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index 7fe233767f76..f87acfb86604 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -20,7 +20,11 @@ from synapse.api.errors import 
SynapseError from synapse.events import EventBase from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.engines import PostgresEngine, Sqlite3Engine @@ -105,7 +109,12 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist" EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) if not hs.config.server.enable_search: @@ -358,7 +367,12 @@ def reindex_search_txn(txn): class SearchStore(SearchBackgroundUpdateStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) async def search_msgs(self, room_ids, search_term, keys): diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index fa2c3b1feb91..4bc044fb1606 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -22,7 +22,11 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.events import EventBase from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.state import StateFilter @@ -56,7 +60,12 @@ def __len__(self): class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): """The parts of StateGroupStore that can be called from workers.""" - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) async def get_room_version(self, room_id: str) -> RoomVersion: @@ -349,7 +358,12 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events" - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.server_name = hs.hostname @@ -536,5 +550,10 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore): * `state_groups_state`: Maps state group to state events. 
""" - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 5d7b59d861c9..9020e0976c7c 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -24,7 +24,7 @@ from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.api.errors import StoreError -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.state_deltas import StateDeltasStore from synapse.types import JsonDict from synapse.util.caches.descriptors import cached @@ -96,7 +96,12 @@ class UserSortOrder(Enum): class StatsStore(StateDeltasStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.server_name = hs.hostname diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 57aab5525937..9488fd509463 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -49,6 +49,7 @@ from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, + LoggingDatabaseConnection, LoggingTransaction, make_in_list_sql_clause, ) @@ -339,7 +340,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): which can be called in the initializer. 
""" - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self._instance_name = hs.get_instance_name() diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 162282255232..54b41513ee13 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -22,7 +22,11 @@ from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import db_to_json -from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.types import JsonDict from synapse.util.caches.descriptors import cached @@ -71,7 +75,12 @@ class DestinationRetryTimings: class TransactionWorkerStore(CacheInvalidationWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) if hs.config.worker.run_background_tasks: diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index e98a45b6af60..0f9b8575d3a5 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -32,11 +32,14 @@ from synapse.server import HomeServer from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules -from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.databases.main.state import StateFilter from synapse.storage.databases.main.state_deltas import StateDeltasStore from synapse.storage.engines import PostgresEngine, Sqlite3Engine -from synapse.storage.types import Connection from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id from synapse.util.caches.descriptors import cached @@ -53,7 +56,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): def __init__( self, database: DatabasePool, - db_conn: Connection, + db_conn: LoggingDatabaseConnection, hs: "HomeServer", ): super().__init__(database, db_conn, hs) @@ -592,7 +595,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): def __init__( self, database: DatabasePool, - db_conn: Connection, + db_conn: LoggingDatabaseConnection, hs: "HomeServer", ) -> None: super().__init__(database, db_conn, hs) From eb39da6782b57c939450839097f32a14cba3ebfc Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 13 Dec 2021 12:55:07 -0500 Subject: [PATCH 025/157] Move HTML parsing to a separate file for URL previews. (#11566) * Splits the logic for parsing HTML from the resource handling code. * Fix a circular import in the oEmbed code (which uses the HTML parsing code). * Renames some of the HTML parsing methods to: * Make it clear which methods are "internal" to the module. * Clarify what the methods do. 
--- changelog.d/11566.misc | 1 + synapse/rest/media/v1/oembed.py | 5 +- synapse/rest/media/v1/preview_html.py | 397 ++++++++++++++++++ synapse/rest/media/v1/preview_url_resource.py | 383 +---------------- tests/rest/media/v1/test_url_preview.py | 1 + tests/test_preview.py | 46 +- 6 files changed, 432 insertions(+), 401 deletions(-) create mode 100644 changelog.d/11566.misc create mode 100644 synapse/rest/media/v1/preview_html.py diff --git a/changelog.d/11566.misc b/changelog.d/11566.misc new file mode 100644 index 000000000000..c48e73cd486c --- /dev/null +++ b/changelog.d/11566.misc @@ -0,0 +1 @@ +Split the HTML parsing code from the URL preview resource code. diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py index 2a59552c20a3..cce1527ed9fb 100644 --- a/synapse/rest/media/v1/oembed.py +++ b/synapse/rest/media/v1/oembed.py @@ -17,6 +17,7 @@ import attr +from synapse.rest.media.v1.preview_html import parse_html_description from synapse.types import JsonDict from synapse.util import json_decoder @@ -245,8 +246,6 @@ def calc_description_and_urls(open_graph_response: JsonDict, html_body: str) -> if video_urls: open_graph_response["og:video"] = video_urls[0] - from synapse.rest.media.v1.preview_url_resource import _calc_description - - description = _calc_description(tree) + description = parse_html_description(tree) if description: open_graph_response["og:description"] = description diff --git a/synapse/rest/media/v1/preview_html.py b/synapse/rest/media/v1/preview_html.py new file mode 100644 index 000000000000..30b067dd4271 --- /dev/null +++ b/synapse/rest/media/v1/preview_html.py @@ -0,0 +1,397 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import codecs +import itertools +import logging +import re +from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Set, Union +from urllib import parse as urlparse + +if TYPE_CHECKING: + from lxml import etree + +logger = logging.getLogger(__name__) + +_charset_match = re.compile( + br'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9_-]+)"?', flags=re.I +) +_xml_encoding_match = re.compile( + br'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9_-]+)"', flags=re.I +) +_content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I) + + +def _normalise_encoding(encoding: str) -> Optional[str]: + """Use the Python codec's name as the normalised entry.""" + try: + return codecs.lookup(encoding).name + except LookupError: + return None + + +def _get_html_media_encodings( + body: bytes, content_type: Optional[str] +) -> Iterable[str]: + """ + Get the potential encodings of the body, based on the (presumably) HTML body or the Content-Type header. + + The precedence used for finding a character encoding is: + + 1. <meta> tag with a charset declared. + 2. The XML document's character encoding attribute. + 3. The Content-Type header. + 4. Fallback to utf-8. + 5. Fallback to windows-1252. + + This roughly follows the algorithm used by BeautifulSoup's bs4.dammit.EncodingDetector.
+ + Args: + body: The HTML document, as bytes. + content_type: The Content-Type header. + + Returns: + The character encodings to try, in order of preference. + """ + # There's no point in returning an encoding more than once. + attempted_encodings: Set[str] = set() + + # Limit searches to the first 1kb, since it ought to be at the top. + body_start = body[:1024] + + # Check if it has an encoding set in a meta tag. + match = _charset_match.search(body_start) + if match: + encoding = _normalise_encoding(match.group(1).decode("ascii")) + if encoding: + attempted_encodings.add(encoding) + yield encoding + + # TODO Support + + # Check if it has an XML document with an encoding. + match = _xml_encoding_match.match(body_start) + if match: + encoding = _normalise_encoding(match.group(1).decode("ascii")) + if encoding and encoding not in attempted_encodings: + attempted_encodings.add(encoding) + yield encoding + + # Check the HTTP Content-Type header for a character set. + if content_type: + content_match = _content_type_match.match(content_type) + if content_match: + encoding = _normalise_encoding(content_match.group(1)) + if encoding and encoding not in attempted_encodings: + attempted_encodings.add(encoding) + yield encoding + + # Finally, fallback to UTF-8, then windows-1252. + for fallback in ("utf-8", "cp1252"): + if fallback not in attempted_encodings: + yield fallback + + +def decode_body( + body: bytes, uri: str, content_type: Optional[str] = None +) -> Optional["etree.Element"]: + """ + This uses lxml to parse the HTML document. + + Args: + body: The HTML document, as bytes. + uri: The URI used to download the body. + content_type: The Content-Type header. + + Returns: + The parsed HTML body, or None if an error occurred during processing. + """ + # If there's no body, nothing useful is going to be found. + if not body: + return None + + # The idea here is that multiple encodings are tried until one works. + # Unfortunately the result is never used and then LXML will decode the string + # again with the found encoding. + for encoding in _get_html_media_encodings(body, content_type): + try: + body.decode(encoding) + except Exception: + pass + else: + break + else: + logger.warning("Unable to decode HTML body for %s", uri) + return None + + from lxml import etree + + # Create an HTML parser. + parser = etree.HTMLParser(recover=True, encoding=encoding) + + # Attempt to parse the body. Returns None if the body was successfully + # parsed, but no tree was found. + return etree.fromstring(body, parser) + + +def parse_html_to_open_graph( + tree: "etree.Element", media_uri: str +) -> Dict[str, Optional[str]]: + """ + Parse the HTML document into an Open Graph response. + + This uses lxml to search the HTML document for Open Graph data (or + synthesizes it from the document). + + Args: + tree: The parsed HTML document. + media_uri: The URI used to download the body. + + Returns: + The Open Graph response as a dictionary.
+ """ + + # if we see any image URLs in the OG response, then spider them + # (although the client could choose to do this by asking for previews of those + # URLs to avoid DoSing the server) + + # "og:type" : "video", + # "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw", + # "og:site_name" : "YouTube", + # "og:video:type" : "application/x-shockwave-flash", + # "og:description" : "Fun stuff happening here", + # "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon", + # "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg", + # "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1", + # "og:video:width" : "1280" + # "og:video:height" : "720", + # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3", + + og: Dict[str, Optional[str]] = {} + for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"): + if "content" in tag.attrib: + # if we've got more than 50 tags, someone is taking the piss + if len(og) >= 50: + logger.warning("Skipping OG for page with too many 'og:' tags") + return {} + og[tag.attrib["property"]] = tag.attrib["content"] + + # TODO: grab article: meta tags too, e.g.: + + # "article:publisher" : "https://www.facebook.com/thethudonline" /> + # "article:author" content="https://www.facebook.com/thethudonline" /> + # "article:tag" content="baby" /> + # "article:section" content="Breaking News" /> + # "article:published_time" content="2016-03-31T19:58:24+00:00" /> + # "article:modified_time" content="2016-04-01T18:31:53+00:00" /> + + if "og:title" not in og: + # do some basic spidering of the HTML + title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]") + if title and title[0].text is not None: + og["og:title"] = title[0].text.strip() + else: + og["og:title"] = None + + if "og:image" not in og: + # TODO: extract a favicon failing all else + meta_image = tree.xpath( + "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content" + ) + if meta_image: + og["og:image"] = rebase_url(meta_image[0], media_uri) + else: + # TODO: consider inlined CSS styles as well as width & height attribs + images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]") + images = sorted( + images, + key=lambda i: ( + -1 * float(i.attrib["width"]) * float(i.attrib["height"]) + ), + ) + if not images: + images = tree.xpath("//img[@src]") + if images: + og["og:image"] = images[0].attrib["src"] + + if "og:description" not in og: + meta_description = tree.xpath( + "//*/meta" + "[translate(@name, 'DESCRIPTION', 'description')='description']" + "/@content" + ) + if meta_description: + og["og:description"] = meta_description[0] + else: + og["og:description"] = parse_html_description(tree) + elif og["og:description"]: + # This must be a non-empty string at this point. + assert isinstance(og["og:description"], str) + og["og:description"] = summarize_paragraphs([og["og:description"]]) + + # TODO: delete the url downloads to stop diskfilling, + # as we only ever cared about its OG + return og + + +def parse_html_description(tree: "etree.Element") -> Optional[str]: + """ + Calculate a text description based on an HTML document. + + Grabs any text nodes which are inside the tag, unless they are within + an HTML5 semantic markup tag (
,
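[Note: Taken together, the new module is driven in two steps: decode_body turns raw bytes into an lxml tree by trying candidate encodings in the precedence order documented above, and parse_html_to_open_graph turns that tree into the og:* dictionary, with parse_html_description supplying the fallback description. A minimal, hypothetical driver along those lines (build_preview is an invented name; the real caller in preview_url_resource.py is not shown in this excerpt):

from typing import Dict, Optional

from synapse.rest.media.v1.preview_html import decode_body, parse_html_to_open_graph


def build_preview(
    body: bytes, uri: str, content_type: Optional[str]
) -> Dict[str, Optional[str]]:
    # Returns None if the body is empty or cannot be decoded with any
    # candidate encoding.
    tree = decode_body(body, uri, content_type)
    if tree is None:
        return {}
    # Extracts og:* meta tags, synthesizing og:title, og:image and
    # og:description from the document when they are missing.
    return parse_html_to_open_graph(tree, uri)
]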