Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Make RateLimiter class check for ratelimit overrides #9711

Merged
merged 10 commits into from
Mar 30, 2021
1 change: 1 addition & 0 deletions changelog.d/9711.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix recently added ratelimits to correctly honour the application service `rate_limited` flag.
53 changes: 41 additions & 12 deletions synapse/api/ratelimiting.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Hashable, Optional, Tuple

from synapse.api.errors import LimitExceededError
from synapse.storage.databases.main import DataStore
from synapse.types import Requester
from synapse.util import Clock

Expand All @@ -31,10 +32,13 @@ class Ratelimiter:
burst_count: How many actions that can be performed before being limited.
"""

def __init__(self, clock: Clock, rate_hz: float, burst_count: int):
def __init__(
self, store: DataStore, clock: Clock, rate_hz: float, burst_count: int
):
self.clock = clock
self.rate_hz = rate_hz
self.burst_count = burst_count
self.store = store

# A ordered dictionary keeping track of actions, when they were last
# performed and how often. Each entry is a mapping from a key of arbitrary type
Expand All @@ -46,7 +50,7 @@ def __init__(self, clock: Clock, rate_hz: float, burst_count: int):
OrderedDict()
) # type: OrderedDict[Hashable, Tuple[float, int, float]]

def can_requester_do_action(
async def can_requester_do_action(
self,
requester: Requester,
rate_hz: Optional[float] = None,
Expand All @@ -73,17 +77,18 @@ def can_requester_do_action(
* The reactor timestamp for when the action can be performed next.
-1 if rate_hz is less than or equal to zero
"""
# Disable rate limiting of users belonging to any AS that is configured
# not to be rate limited in its registration file (rate_limited: true|false).
if requester.app_service and not requester.app_service.is_rate_limited():
return True, -1.0

return self.can_do_action(
requester.user.to_string(), rate_hz, burst_count, update, _time_now_s
return await self.can_do_action(
requester,
requester.user.to_string(),
rate_hz,
burst_count,
update,
_time_now_s,
)

def can_do_action(
async def can_do_action(
self,
requester: Optional[Requester],
key: Hashable,
rate_hz: Optional[float] = None,
burst_count: Optional[int] = None,
Expand All @@ -93,6 +98,8 @@ def can_do_action(
"""Can the entity (e.g. user or IP address) perform the action?

Args:
requester: The requester that is doing the action, if any. Used to check for
ratelimit overrides.
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
key: The key we should use when rate limiting. Can be a user ID
(when sending events), an IP address, etc.
rate_hz: The long term number of actions that can be performed in a second.
Expand All @@ -109,6 +116,24 @@ def can_do_action(
* The reactor timestamp for when the action can be performed next.
-1 if rate_hz is less than or equal to zero
"""
if requester:
# Disable rate limiting of users belonging to any AS that is configured
# not to be rate limited in its registration file (rate_limited: true|false).
if requester.app_service and not requester.app_service.is_rate_limited():
return True, -1.0

# Check if ratelimiting has been disabled for the user.
#
# Note that we don't use the returned rate/burst count, as the table
# is specifically for the event sending ratelimiter. Instead, we
# only use it to (somewhat cheekily) infer whether the user should
# be subject to any rate limiting or not.
override = await self.store.get_ratelimit_for_user(
requester.authenticated_entity
)
if override and not override.messages_per_second:
return True, -1.0

# Override default values if set
time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
rate_hz = rate_hz if rate_hz is not None else self.rate_hz
Expand Down Expand Up @@ -175,8 +200,9 @@ def _prune_message_counts(self, time_now_s: int):
else:
del self.actions[key]

def ratelimit(
async def ratelimit(
self,
requester: Optional[Requester],
key: Hashable,
rate_hz: Optional[float] = None,
burst_count: Optional[int] = None,
Expand All @@ -186,6 +212,8 @@ def ratelimit(
"""Checks if an action can be performed. If not, raises a LimitExceededError

Args:
requester: The requester that is doing the action, if any. Used to check for
ratelimit overrides.
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
key: An arbitrary key used to classify an action
rate_hz: The long term number of actions that can be performed in a second.
Overrides the value set during instantiation if set.
Expand All @@ -201,7 +229,8 @@ def ratelimit(
"""
time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()

allowed, time_allowed = self.can_do_action(
allowed, time_allowed = await self.can_do_action(
requester,
key,
rate_hz=rate_hz,
burst_count=burst_count,
Expand Down
5 changes: 4 additions & 1 deletion synapse/federation/federation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,6 +870,7 @@ def __init__(self, hs: "HomeServer"):

# A rate limiter for incoming room key requests per origin.
self._room_key_request_rate_limiter = Ratelimiter(
store=hs.get_datastore(),
clock=self.clock,
rate_hz=self.config.rc_key_requests.per_second,
burst_count=self.config.rc_key_requests.burst_count,
Expand Down Expand Up @@ -930,7 +931,9 @@ async def on_edu(self, edu_type: str, origin: str, content: dict):
# the limit, drop them.
if (
edu_type == EduTypes.RoomKeyRequest
and not self._room_key_request_rate_limiter.can_do_action(origin)
and not await self._room_key_request_rate_limiter.can_do_action(
None, origin
)
):
return

Expand Down
15 changes: 7 additions & 8 deletions synapse/handlers/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,15 @@ def __init__(self, hs: "HomeServer"):

# The rate_hz and burst_count are overridden on a per-user basis
self.request_ratelimiter = Ratelimiter(
clock=self.clock, rate_hz=0, burst_count=0
store=self.store, clock=self.clock, rate_hz=0, burst_count=0
)
self._rc_message = self.hs.config.rc_message

# Check whether ratelimiting room admin message redaction is enabled
# by the presence of rate limits in the config
if self.hs.config.rc_admin_redaction:
self.admin_redaction_ratelimiter = Ratelimiter(
store=self.store,
clock=self.clock,
rate_hz=self.hs.config.rc_admin_redaction.per_second,
burst_count=self.hs.config.rc_admin_redaction.burst_count,
Expand Down Expand Up @@ -91,11 +92,6 @@ async def ratelimit(self, requester, update=True, is_admin_redaction=False):
if app_service is not None:
return # do not ratelimit app service senders

# Disable rate limiting of users belonging to any AS that is configured
# not to be rate limited in its registration file (rate_limited: true|false).
if requester.app_service and not requester.app_service.is_rate_limited():
return

messages_per_second = self._rc_message.per_second
burst_count = self._rc_message.burst_count

Expand All @@ -113,10 +109,13 @@ async def ratelimit(self, requester, update=True, is_admin_redaction=False):
if is_admin_redaction and self.admin_redaction_ratelimiter:
# If we have separate config for admin redactions, use a separate
# ratelimiter as to not have user_ids clash
self.admin_redaction_ratelimiter.ratelimit(user_id, update=update)
await self.admin_redaction_ratelimiter.ratelimit(
requester, user_id, update=update
)
else:
# Override rate and burst count per-user
self.request_ratelimiter.ratelimit(
await self.request_ratelimiter.ratelimit(
requester,
user_id,
rate_hz=messages_per_second,
burst_count=burst_count,
Expand Down
26 changes: 16 additions & 10 deletions synapse/handlers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ def __init__(self, hs: "HomeServer"):
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
# as per `rc_login.failed_attempts`.
self._failed_uia_attempts_ratelimiter = Ratelimiter(
store=self.store,
clock=self.clock,
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
Expand All @@ -248,6 +249,7 @@ def __init__(self, hs: "HomeServer"):

# Ratelimitier for failed /login attempts
self._failed_login_attempts_ratelimiter = Ratelimiter(
store=self.store,
clock=hs.get_clock(),
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
Expand Down Expand Up @@ -352,7 +354,9 @@ async def validate_user_via_ui_auth(
requester_user_id = requester.user.to_string()

# Check if we should be ratelimited due to too many previous failed attempts
self._failed_uia_attempts_ratelimiter.ratelimit(requester_user_id, update=False)
await self._failed_uia_attempts_ratelimiter.ratelimit(
requester, requester_user_id, update=False
)
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved

# build a list of supported flows
supported_ui_auth_types = await self._get_available_ui_auth_types(
Expand All @@ -373,7 +377,9 @@ def get_new_session_data() -> JsonDict:
)
except LoginError:
# Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
self._failed_uia_attempts_ratelimiter.can_do_action(requester_user_id)
await self._failed_uia_attempts_ratelimiter.can_do_action(
requester, requester_user_id
)
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
raise

# find the completed login type
Expand Down Expand Up @@ -982,8 +988,8 @@ async def validate_login(
# We also apply account rate limiting using the 3PID as a key, as
# otherwise using 3PID bypasses the ratelimiting based on user ID.
if ratelimit:
self._failed_login_attempts_ratelimiter.ratelimit(
(medium, address), update=False
await self._failed_login_attempts_ratelimiter.ratelimit(
None, (medium, address), update=False
)

# Check for login providers that support 3pid login types
Expand Down Expand Up @@ -1016,8 +1022,8 @@ async def validate_login(
# this code path, which is fine as then the per-user ratelimit
# will kick in below.
if ratelimit:
self._failed_login_attempts_ratelimiter.can_do_action(
(medium, address)
await self._failed_login_attempts_ratelimiter.can_do_action(
None, (medium, address)
)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)

Expand All @@ -1039,8 +1045,8 @@ async def validate_login(

# Check if we've hit the failed ratelimit (but don't update it)
if ratelimit:
self._failed_login_attempts_ratelimiter.ratelimit(
qualified_user_id.lower(), update=False
await self._failed_login_attempts_ratelimiter.ratelimit(
None, qualified_user_id.lower(), update=False
)

try:
Expand All @@ -1051,8 +1057,8 @@ async def validate_login(
# exception and masking the LoginError. The actual ratelimiting
# should have happened above.
if ratelimit:
self._failed_login_attempts_ratelimiter.can_do_action(
qualified_user_id.lower()
await self._failed_login_attempts_ratelimiter.can_do_action(
None, qualified_user_id.lower()
)
raise

Expand Down
5 changes: 3 additions & 2 deletions synapse/handlers/devicemessage.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def __init__(self, hs: "HomeServer"):
)

self._ratelimiter = Ratelimiter(
store=self.store,
clock=hs.get_clock(),
rate_hz=hs.config.rc_key_requests.per_second,
burst_count=hs.config.rc_key_requests.burst_count,
Expand Down Expand Up @@ -191,8 +192,8 @@ async def send_device_message(
if (
message_type == EduTypes.RoomKeyRequest
and user_id != sender_user_id
and self._ratelimiter.can_do_action(
(sender_user_id, requester.device_id)
and await self._ratelimiter.can_do_action(
requester, (sender_user_id, requester.device_id)
)
):
continue
Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1711,7 +1711,7 @@ async def on_invite_request(
member_handler = self.hs.get_room_member_handler()
# We don't rate limit based on room ID, as that should be done by
# sending server.
member_handler.ratelimit_invite(None, event.state_key)
await member_handler.ratelimit_invite(None, None, event.state_key)

# keep a record of the room version, if we don't yet know it.
# (this may get overwritten if we later get a different room version in a
Expand Down
12 changes: 9 additions & 3 deletions synapse/handlers/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,19 @@ def __init__(self, hs):

# Ratelimiters for `/requestToken` endpoints.
self._3pid_validation_ratelimiter_ip = Ratelimiter(
store=self.store,
clock=hs.get_clock(),
rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
)
self._3pid_validation_ratelimiter_address = Ratelimiter(
store=self.store,
clock=hs.get_clock(),
rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
)

def ratelimit_request_token_requests(
async def ratelimit_request_token_requests(
self,
request: SynapseRequest,
medium: str,
Expand All @@ -85,8 +87,12 @@ def ratelimit_request_token_requests(
address: The actual threepid ID, e.g. the phone number or email address
"""

self._3pid_validation_ratelimiter_ip.ratelimit((medium, request.getClientIP()))
self._3pid_validation_ratelimiter_address.ratelimit((medium, address))
await self._3pid_validation_ratelimiter_ip.ratelimit(
None, (medium, request.getClientIP())
)
await self._3pid_validation_ratelimiter_address.ratelimit(
None, (medium, address)
)

async def threepid_from_creds(
self, id_server: str, creds: Dict[str, str]
Expand Down
6 changes: 3 additions & 3 deletions synapse/handlers/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ async def register_user(
Raises:
SynapseError if there was a problem registering.
"""
self.check_registration_ratelimit(address)
await self.check_registration_ratelimit(address)

result = await self.spam_checker.check_registration_for_spam(
threepid,
Expand Down Expand Up @@ -583,7 +583,7 @@ def check_user_id_not_appservice_exclusive(
errcode=Codes.EXCLUSIVE,
)

def check_registration_ratelimit(self, address: Optional[str]) -> None:
async def check_registration_ratelimit(self, address: Optional[str]) -> None:
"""A simple helper method to check whether the registration rate limit has been hit
for a given IP address

Expand All @@ -597,7 +597,7 @@ def check_registration_ratelimit(self, address: Optional[str]) -> None:
if not address:
return

self.ratelimiter.ratelimit(address)
await self.ratelimiter.ratelimit(None, address)

async def register_with_store(
self,
Expand Down
Loading