Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Make RateLimiter class check for ratelimit overrides (#9711)
Browse files Browse the repository at this point in the history
This should fix a class of bug where we forget to check if e.g. the appservice shouldn't be ratelimited.

We also check the `ratelimit_override` table to check if the user has ratelimiting disabled. That table is really only meant to override the event sender ratelimiting, so we don't use any values from it (as they might not make sense for different rate limits), but we do infer that if ratelimiting is disabled for the user we should disabled all ratelimits.

Fixes #9663
  • Loading branch information
erikjohnston authored Mar 30, 2021
1 parent 3a446c2 commit 963f430
Show file tree
Hide file tree
Showing 16 changed files with 241 additions and 154 deletions.
1 change: 1 addition & 0 deletions changelog.d/9711.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix recently added ratelimits to correctly honour the application service `rate_limited` flag.
100 changes: 55 additions & 45 deletions synapse/api/ratelimiting.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Hashable, Optional, Tuple

from synapse.api.errors import LimitExceededError
from synapse.storage.databases.main import DataStore
from synapse.types import Requester
from synapse.util import Clock

Expand All @@ -31,10 +32,13 @@ class Ratelimiter:
burst_count: How many actions that can be performed before being limited.
"""

def __init__(self, clock: Clock, rate_hz: float, burst_count: int):
def __init__(
self, store: DataStore, clock: Clock, rate_hz: float, burst_count: int
):
self.clock = clock
self.rate_hz = rate_hz
self.burst_count = burst_count
self.store = store

# A ordered dictionary keeping track of actions, when they were last
# performed and how often. Each entry is a mapping from a key of arbitrary type
Expand All @@ -46,55 +50,27 @@ def __init__(self, clock: Clock, rate_hz: float, burst_count: int):
OrderedDict()
) # type: OrderedDict[Hashable, Tuple[float, int, float]]

def can_requester_do_action(
self,
requester: Requester,
rate_hz: Optional[float] = None,
burst_count: Optional[int] = None,
update: bool = True,
_time_now_s: Optional[int] = None,
) -> Tuple[bool, float]:
"""Can the requester perform the action?
Args:
requester: The requester to key off when rate limiting. The user property
will be used.
rate_hz: The long term number of actions that can be performed in a second.
Overrides the value set during instantiation if set.
burst_count: How many actions that can be performed before being limited.
Overrides the value set during instantiation if set.
update: Whether to count this check as performing the action
_time_now_s: The current time. Optional, defaults to the current time according
to self.clock. Only used by tests.
Returns:
A tuple containing:
* A bool indicating if they can perform the action now
* The reactor timestamp for when the action can be performed next.
-1 if rate_hz is less than or equal to zero
"""
# Disable rate limiting of users belonging to any AS that is configured
# not to be rate limited in its registration file (rate_limited: true|false).
if requester.app_service and not requester.app_service.is_rate_limited():
return True, -1.0

return self.can_do_action(
requester.user.to_string(), rate_hz, burst_count, update, _time_now_s
)

def can_do_action(
async def can_do_action(
self,
key: Hashable,
requester: Optional[Requester],
key: Optional[Hashable] = None,
rate_hz: Optional[float] = None,
burst_count: Optional[int] = None,
update: bool = True,
_time_now_s: Optional[int] = None,
) -> Tuple[bool, float]:
"""Can the entity (e.g. user or IP address) perform the action?
Checks if the user has ratelimiting disabled in the database by looking
for null/zero values in the `ratelimit_override` table. (Non-zero
values aren't honoured, as they're specific to the event sending
ratelimiter, rather than all ratelimiters)
Args:
key: The key we should use when rate limiting. Can be a user ID
(when sending events), an IP address, etc.
requester: The requester that is doing the action, if any. Used to check
if the user has ratelimits disabled in the database.
key: An arbitrary key used to classify an action. Defaults to the
requester's user ID.
rate_hz: The long term number of actions that can be performed in a second.
Overrides the value set during instantiation if set.
burst_count: How many actions that can be performed before being limited.
Expand All @@ -109,6 +85,30 @@ def can_do_action(
* The reactor timestamp for when the action can be performed next.
-1 if rate_hz is less than or equal to zero
"""
if key is None:
if not requester:
raise ValueError("Must supply at least one of `requester` or `key`")

key = requester.user.to_string()

if requester:
# Disable rate limiting of users belonging to any AS that is configured
# not to be rate limited in its registration file (rate_limited: true|false).
if requester.app_service and not requester.app_service.is_rate_limited():
return True, -1.0

# Check if ratelimiting has been disabled for the user.
#
# Note that we don't use the returned rate/burst count, as the table
# is specifically for the event sending ratelimiter. Instead, we
# only use it to (somewhat cheekily) infer whether the user should
# be subject to any rate limiting or not.
override = await self.store.get_ratelimit_for_user(
requester.authenticated_entity
)
if override and not override.messages_per_second:
return True, -1.0

# Override default values if set
time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
rate_hz = rate_hz if rate_hz is not None else self.rate_hz
Expand Down Expand Up @@ -175,18 +175,27 @@ def _prune_message_counts(self, time_now_s: int):
else:
del self.actions[key]

def ratelimit(
async def ratelimit(
self,
key: Hashable,
requester: Optional[Requester],
key: Optional[Hashable] = None,
rate_hz: Optional[float] = None,
burst_count: Optional[int] = None,
update: bool = True,
_time_now_s: Optional[int] = None,
):
"""Checks if an action can be performed. If not, raises a LimitExceededError
Checks if the user has ratelimiting disabled in the database by looking
for null/zero values in the `ratelimit_override` table. (Non-zero
values aren't honoured, as they're specific to the event sending
ratelimiter, rather than all ratelimiters)
Args:
key: An arbitrary key used to classify an action
requester: The requester that is doing the action, if any. Used to check for
if the user has ratelimits disabled.
key: An arbitrary key used to classify an action. Defaults to the
requester's user ID.
rate_hz: The long term number of actions that can be performed in a second.
Overrides the value set during instantiation if set.
burst_count: How many actions that can be performed before being limited.
Expand All @@ -201,7 +210,8 @@ def ratelimit(
"""
time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()

allowed, time_allowed = self.can_do_action(
allowed, time_allowed = await self.can_do_action(
requester,
key,
rate_hz=rate_hz,
burst_count=burst_count,
Expand Down
5 changes: 4 additions & 1 deletion synapse/federation/federation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,6 +870,7 @@ def __init__(self, hs: "HomeServer"):

# A rate limiter for incoming room key requests per origin.
self._room_key_request_rate_limiter = Ratelimiter(
store=hs.get_datastore(),
clock=self.clock,
rate_hz=self.config.rc_key_requests.per_second,
burst_count=self.config.rc_key_requests.burst_count,
Expand Down Expand Up @@ -930,7 +931,9 @@ async def on_edu(self, edu_type: str, origin: str, content: dict):
# the limit, drop them.
if (
edu_type == EduTypes.RoomKeyRequest
and not self._room_key_request_rate_limiter.can_do_action(origin)
and not await self._room_key_request_rate_limiter.can_do_action(
None, origin
)
):
return

Expand Down
14 changes: 5 additions & 9 deletions synapse/handlers/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,15 @@ def __init__(self, hs: "HomeServer"):

# The rate_hz and burst_count are overridden on a per-user basis
self.request_ratelimiter = Ratelimiter(
clock=self.clock, rate_hz=0, burst_count=0
store=self.store, clock=self.clock, rate_hz=0, burst_count=0
)
self._rc_message = self.hs.config.rc_message

# Check whether ratelimiting room admin message redaction is enabled
# by the presence of rate limits in the config
if self.hs.config.rc_admin_redaction:
self.admin_redaction_ratelimiter = Ratelimiter(
store=self.store,
clock=self.clock,
rate_hz=self.hs.config.rc_admin_redaction.per_second,
burst_count=self.hs.config.rc_admin_redaction.burst_count,
Expand Down Expand Up @@ -91,11 +92,6 @@ async def ratelimit(self, requester, update=True, is_admin_redaction=False):
if app_service is not None:
return # do not ratelimit app service senders

# Disable rate limiting of users belonging to any AS that is configured
# not to be rate limited in its registration file (rate_limited: true|false).
if requester.app_service and not requester.app_service.is_rate_limited():
return

messages_per_second = self._rc_message.per_second
burst_count = self._rc_message.burst_count

Expand All @@ -113,11 +109,11 @@ async def ratelimit(self, requester, update=True, is_admin_redaction=False):
if is_admin_redaction and self.admin_redaction_ratelimiter:
# If we have separate config for admin redactions, use a separate
# ratelimiter as to not have user_ids clash
self.admin_redaction_ratelimiter.ratelimit(user_id, update=update)
await self.admin_redaction_ratelimiter.ratelimit(requester, update=update)
else:
# Override rate and burst count per-user
self.request_ratelimiter.ratelimit(
user_id,
await self.request_ratelimiter.ratelimit(
requester,
rate_hz=messages_per_second,
burst_count=burst_count,
update=update,
Expand Down
24 changes: 14 additions & 10 deletions synapse/handlers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ def __init__(self, hs: "HomeServer"):
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
# as per `rc_login.failed_attempts`.
self._failed_uia_attempts_ratelimiter = Ratelimiter(
store=self.store,
clock=self.clock,
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
Expand All @@ -248,6 +249,7 @@ def __init__(self, hs: "HomeServer"):

# Ratelimitier for failed /login attempts
self._failed_login_attempts_ratelimiter = Ratelimiter(
store=self.store,
clock=hs.get_clock(),
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
Expand Down Expand Up @@ -352,7 +354,7 @@ async def validate_user_via_ui_auth(
requester_user_id = requester.user.to_string()

# Check if we should be ratelimited due to too many previous failed attempts
self._failed_uia_attempts_ratelimiter.ratelimit(requester_user_id, update=False)
await self._failed_uia_attempts_ratelimiter.ratelimit(requester, update=False)

# build a list of supported flows
supported_ui_auth_types = await self._get_available_ui_auth_types(
Expand All @@ -373,7 +375,9 @@ def get_new_session_data() -> JsonDict:
)
except LoginError:
# Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
self._failed_uia_attempts_ratelimiter.can_do_action(requester_user_id)
await self._failed_uia_attempts_ratelimiter.can_do_action(
requester,
)
raise

# find the completed login type
Expand Down Expand Up @@ -982,8 +986,8 @@ async def validate_login(
# We also apply account rate limiting using the 3PID as a key, as
# otherwise using 3PID bypasses the ratelimiting based on user ID.
if ratelimit:
self._failed_login_attempts_ratelimiter.ratelimit(
(medium, address), update=False
await self._failed_login_attempts_ratelimiter.ratelimit(
None, (medium, address), update=False
)

# Check for login providers that support 3pid login types
Expand Down Expand Up @@ -1016,8 +1020,8 @@ async def validate_login(
# this code path, which is fine as then the per-user ratelimit
# will kick in below.
if ratelimit:
self._failed_login_attempts_ratelimiter.can_do_action(
(medium, address)
await self._failed_login_attempts_ratelimiter.can_do_action(
None, (medium, address)
)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)

Expand All @@ -1039,8 +1043,8 @@ async def validate_login(

# Check if we've hit the failed ratelimit (but don't update it)
if ratelimit:
self._failed_login_attempts_ratelimiter.ratelimit(
qualified_user_id.lower(), update=False
await self._failed_login_attempts_ratelimiter.ratelimit(
None, qualified_user_id.lower(), update=False
)

try:
Expand All @@ -1051,8 +1055,8 @@ async def validate_login(
# exception and masking the LoginError. The actual ratelimiting
# should have happened above.
if ratelimit:
self._failed_login_attempts_ratelimiter.can_do_action(
qualified_user_id.lower()
await self._failed_login_attempts_ratelimiter.can_do_action(
None, qualified_user_id.lower()
)
raise

Expand Down
5 changes: 3 additions & 2 deletions synapse/handlers/devicemessage.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def __init__(self, hs: "HomeServer"):
)

self._ratelimiter = Ratelimiter(
store=self.store,
clock=hs.get_clock(),
rate_hz=hs.config.rc_key_requests.per_second,
burst_count=hs.config.rc_key_requests.burst_count,
Expand Down Expand Up @@ -191,8 +192,8 @@ async def send_device_message(
if (
message_type == EduTypes.RoomKeyRequest
and user_id != sender_user_id
and self._ratelimiter.can_do_action(
(sender_user_id, requester.device_id)
and await self._ratelimiter.can_do_action(
requester, (sender_user_id, requester.device_id)
)
):
continue
Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1711,7 +1711,7 @@ async def on_invite_request(
member_handler = self.hs.get_room_member_handler()
# We don't rate limit based on room ID, as that should be done by
# sending server.
member_handler.ratelimit_invite(None, event.state_key)
await member_handler.ratelimit_invite(None, None, event.state_key)

# keep a record of the room version, if we don't yet know it.
# (this may get overwritten if we later get a different room version in a
Expand Down
12 changes: 9 additions & 3 deletions synapse/handlers/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,19 @@ def __init__(self, hs):

# Ratelimiters for `/requestToken` endpoints.
self._3pid_validation_ratelimiter_ip = Ratelimiter(
store=self.store,
clock=hs.get_clock(),
rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
)
self._3pid_validation_ratelimiter_address = Ratelimiter(
store=self.store,
clock=hs.get_clock(),
rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
)

def ratelimit_request_token_requests(
async def ratelimit_request_token_requests(
self,
request: SynapseRequest,
medium: str,
Expand All @@ -85,8 +87,12 @@ def ratelimit_request_token_requests(
address: The actual threepid ID, e.g. the phone number or email address
"""

self._3pid_validation_ratelimiter_ip.ratelimit((medium, request.getClientIP()))
self._3pid_validation_ratelimiter_address.ratelimit((medium, address))
await self._3pid_validation_ratelimiter_ip.ratelimit(
None, (medium, request.getClientIP())
)
await self._3pid_validation_ratelimiter_address.ratelimit(
None, (medium, address)
)

async def threepid_from_creds(
self, id_server: str, creds: Dict[str, str]
Expand Down
6 changes: 3 additions & 3 deletions synapse/handlers/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ async def register_user(
Raises:
SynapseError if there was a problem registering.
"""
self.check_registration_ratelimit(address)
await self.check_registration_ratelimit(address)

result = await self.spam_checker.check_registration_for_spam(
threepid,
Expand Down Expand Up @@ -583,7 +583,7 @@ def check_user_id_not_appservice_exclusive(
errcode=Codes.EXCLUSIVE,
)

def check_registration_ratelimit(self, address: Optional[str]) -> None:
async def check_registration_ratelimit(self, address: Optional[str]) -> None:
"""A simple helper method to check whether the registration rate limit has been hit
for a given IP address
Expand All @@ -597,7 +597,7 @@ def check_registration_ratelimit(self, address: Optional[str]) -> None:
if not address:
return

self.ratelimiter.ratelimit(address)
await self.ratelimiter.ratelimit(None, address)

async def register_with_store(
self,
Expand Down
Loading

0 comments on commit 963f430

Please sign in to comment.