Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Fix slow performance of /logout in some cases where refresh tokens are in use. The slowness existed since the initial implementation of refresh tokens. #12056

Merged
merged 5 commits into from
Feb 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12056.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix slow performance of `/logout` in some cases where refresh tokens are in use. The slowness existed since the initial implementation of refresh tokens.
18 changes: 16 additions & 2 deletions synapse/storage/databases/main/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1681,7 +1681,8 @@ def _lookup_refresh_token_txn(txn) -> Optional[RefreshTokenLookupResult]:
user_id=row[1],
device_id=row[2],
next_token_id=row[3],
has_next_refresh_token_been_refreshed=row[4],
# SQLite returns 0 or 1 for false/true, so convert to a bool.
has_next_refresh_token_been_refreshed=bool(row[4]),
# This column is nullable, ensure it's a boolean
has_next_access_token_been_used=(row[5] or False),
expiry_ts=row[6],
Expand All @@ -1697,12 +1698,15 @@ async def replace_refresh_token(self, token_id: int, next_token_id: int) -> None
Set the successor of a refresh token, removing the existing successor
if any.

This also deletes the predecessor refresh and access tokens,
since they cannot be valid anymore.

Args:
token_id: ID of the refresh token to update.
next_token_id: ID of its successor.
"""

def _replace_refresh_token_txn(txn) -> None:
def _replace_refresh_token_txn(txn: LoggingTransaction) -> None:
# First check if there was an existing refresh token
old_next_token_id = self.db_pool.simple_select_one_onecol_txn(
txn,
Expand All @@ -1728,6 +1732,16 @@ def _replace_refresh_token_txn(txn) -> None:
{"id": old_next_token_id},
)

# Delete the previous refresh token, since we only want to keep the
# last 2 refresh tokens in the database.
# (The predecessor of the latest refresh token is still useful in
# case the refresh was interrupted and the client re-uses the old
# one.)
# This cascades to delete the associated access token.
self.db_pool.simple_delete_txn(
txn, "refresh_tokens", {"next_token_id": token_id}
)

await self.db_pool.runInteraction(
"replace_refresh_token", _replace_refresh_token_txn
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/* Copyright 2022 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

-- next_token_id is a foreign key reference, so previously required a table scan
-- when a row in the referenced table was deleted.
-- As it was self-referential and cascaded deletes, this led to O(t*n) time to
-- delete a row, where t: number of rows in the table and n: number of rows in
-- the ancestral 'chain' of access tokens.
--
-- This index is partial since we only require it for rows which reference
-- another.
-- Performance was tested to be the same regardless of whether the index was
-- full or partial, but a partial index can be smaller.
CREATE INDEX refresh_tokens_next_token_id
ON refresh_tokens(next_token_id)
WHERE next_token_id IS NOT NULL;
93 changes: 91 additions & 2 deletions tests/rest/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from http import HTTPStatus
from typing import Optional, Union
from typing import Optional, Tuple, Union

from twisted.internet.defer import succeed

import synapse.rest.admin
from synapse.api.constants import LoginType
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.rest.client import account, auth, devices, login, register
from synapse.rest.client import account, auth, devices, login, logout, register
from synapse.rest.synapse.client import build_synapse_client_resource_tree
from synapse.storage.database import LoggingTransaction
from synapse.types import JsonDict, UserID

from tests import unittest
Expand Down Expand Up @@ -527,6 +528,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
auth.register_servlets,
account.register_servlets,
login.register_servlets,
logout.register_servlets,
synapse.rest.admin.register_servlets_for_client_rest_resource,
register.register_servlets,
]
Expand Down Expand Up @@ -984,3 +986,90 @@ def test_refresh_token_invalidation(self):
self.assertEqual(
fifth_refresh_response.code, HTTPStatus.OK, fifth_refresh_response.result
)

def test_many_token_refresh(self):
"""
If a refresh is performed many times during a session, there shouldn't be
extra 'cruft' built up over time.

This test was written specifically to troubleshoot a case where logout
was very slow if a lot of refreshes had been performed for the session.
"""

def _refresh(refresh_token: str) -> Tuple[str, str]:
"""
Performs one refresh, returning the next refresh token and access token.
"""
refresh_response = self.use_refresh_token(refresh_token)
self.assertEqual(
refresh_response.code, HTTPStatus.OK, refresh_response.result
)
return (
refresh_response.json_body["refresh_token"],
refresh_response.json_body["access_token"],
)

def _table_length(table_name: str) -> int:
"""
Helper to get the size of a table, in rows.
For testing only; trivially vulnerable to SQL injection.
"""

def _txn(txn: LoggingTransaction) -> int:
txn.execute(f"SELECT COUNT(1) FROM {table_name}")
row = txn.fetchone()
# Query is infallible
assert row is not None
return row[0]

return self.get_success(
self.hs.get_datastores().main.db_pool.runInteraction(
"_table_length", _txn
)
)

# Before we log in, there are no access tokens.
self.assertEqual(_table_length("access_tokens"), 0)
self.assertEqual(_table_length("refresh_tokens"), 0)

body = {
"type": "m.login.password",
"user": "test",
"password": self.user_pass,
"refresh_token": True,
}
login_response = self.make_request(
"POST",
"/_matrix/client/v3/login",
body,
)
self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result)

access_token = login_response.json_body["access_token"]
refresh_token = login_response.json_body["refresh_token"]

# Now that we have logged in, there should be one access token and one
# refresh token
self.assertEqual(_table_length("access_tokens"), 1)
self.assertEqual(_table_length("refresh_tokens"), 1)

for _ in range(5):
refresh_token, access_token = _refresh(refresh_token)

# After 5 sequential refreshes, there should only be the latest two
# refresh/access token pairs.
# (The last one is preserved because it's in use!
# The one before that is preserved because it can still be used to
# replace the last token pair, in case of e.g. a network interruption.)
self.assertEqual(_table_length("access_tokens"), 2)
self.assertEqual(_table_length("refresh_tokens"), 2)

logout_response = self.make_request(
"POST", "/_matrix/client/v3/logout", {}, access_token=access_token
)
self.assertEqual(logout_response.code, HTTPStatus.OK, logout_response.result)

# Now that we have logged in, there should be no access token
# and no refresh token
self.assertEqual(_table_length("access_tokens"), 0)
self.assertEqual(_table_length("refresh_tokens"), 0)