Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions docs/modules/media_repository_callbacks.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,33 @@ If multiple modules implement this callback, they will be considered in order. I
returns `True`, Synapse falls through to the next one. The value of the first callback that
returns `False` will be used. If this happens, Synapse will not call any of the subsequent
implementations of this callback.

### `get_media_upload_limits_for_user`

_First introduced in Synapse v1.137.0_

```python
async def get_media_upload_limits_for_user(user_id: str, size: int) -> Optional[List[MediaUploadLimit]]
```

**<span style="color:red">
Caution: This callback is currently experimental . The method signature or behaviour
may change without notice.
</span>**

Called when processing a request to store content in the media repository.

The arguments passed to this callback are:

* `user_id`: The Matrix user ID of the user (e.g. `@alice:example.com`) making the request.

If the callback returns a list then it will be used as the limits to be applied to the request.

If an empty list is returned then no limits are applied.

If multiple modules implement this callback, they will be considered in order. If a
callback returns `None`, Synapse falls through to the next one. The value of the first
callback that does not return `None` will be used. If this happens, Synapse will not call
any of the subsequent implementations of this callback.

If no module returns a non-`None` value then the default [media upload limits config](https://element-hq.github.io/synapse/latest/usage/configuration/config_documentation.html#media_upload_limits) will be used.
23 changes: 18 additions & 5 deletions synapse/media/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,11 +179,13 @@ def __init__(self, hs: "HomeServer"):

# We get the media upload limits and sort them in descending order of
# time period, so that we can apply some optimizations.
self.media_upload_limits = hs.config.media.media_upload_limits
self.media_upload_limits.sort(
self.default_media_upload_limits = hs.config.media.media_upload_limits
self.default_media_upload_limits.sort(
key=lambda limit: limit.time_period_ms, reverse=True
)

self.media_repository_callbacks = hs.get_module_api_callbacks().media_repository

def _start_update_recently_accessed(self) -> Deferred:
return run_as_background_process(
"update_recently_accessed_media",
Expand Down Expand Up @@ -340,13 +342,24 @@ async def create_or_update_content(

# Check that the user has not exceeded any of the media upload limits.

# Use limits from module API if provided
media_upload_limits = (
await self.media_repository_callbacks.get_media_upload_limits_for_user(
auth_user.to_string()
)
)

# Otherwise use the default limits from config
if media_upload_limits is None:
# Note: the media upload limits are sorted so larger time periods are
# first.
media_upload_limits = self.default_media_upload_limits

# This is the total size of media uploaded by the user in the last
# `time_period_ms` milliseconds, or None if we haven't checked yet.
uploaded_media_size: Optional[int] = None

# Note: the media upload limits are sorted so larger time periods are
# first.
for limit in self.media_upload_limits:
for limit in media_upload_limits:
# We only need to check the amount of media uploaded by the user in
# this latest (smaller) time period if the amount of media uploaded
# in a previous (larger) time period is above the limit.
Expand Down
7 changes: 7 additions & 0 deletions synapse/module_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from synapse.api.errors import SynapseError
from synapse.api.presence import UserPresenceState
from synapse.config import ConfigError
from synapse.config.repository import MediaUploadLimit
from synapse.events import EventBase
from synapse.events.presence_router import (
GET_INTERESTED_USERS_CALLBACK,
Expand Down Expand Up @@ -94,6 +95,7 @@
)
from synapse.module_api.callbacks.media_repository_callbacks import (
GET_MEDIA_CONFIG_FOR_USER_CALLBACK,
GET_MEDIA_UPLOAD_LIMITS_FOR_USER_CALLBACK,
IS_USER_ALLOWED_TO_UPLOAD_MEDIA_OF_SIZE_CALLBACK,
)
from synapse.module_api.callbacks.ratelimit_callbacks import (
Expand Down Expand Up @@ -205,6 +207,7 @@
"RoomAlias",
"UserProfile",
"RatelimitOverride",
"MediaUploadLimit",
]

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -462,13 +465,17 @@ def register_media_repository_callbacks(
is_user_allowed_to_upload_media_of_size: Optional[
IS_USER_ALLOWED_TO_UPLOAD_MEDIA_OF_SIZE_CALLBACK
] = None,
get_media_upload_limits_for_user: Optional[
GET_MEDIA_UPLOAD_LIMITS_FOR_USER_CALLBACK
] = None,
) -> None:
"""Registers callbacks for media repository capabilities.
Added in Synapse v1.132.0.
"""
return self._callbacks.media_repository.register_callbacks(
get_media_config_for_user=get_media_config_for_user,
is_user_allowed_to_upload_media_of_size=is_user_allowed_to_upload_media_of_size,
get_media_upload_limits_for_user=get_media_upload_limits_for_user,
)

def register_third_party_rules_callbacks(
Expand Down
33 changes: 33 additions & 0 deletions synapse/module_api/callbacks/media_repository_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import logging
from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional

from synapse.config.repository import MediaUploadLimit
from synapse.types import JsonDict
from synapse.util.async_helpers import delay_cancellation
from synapse.util.metrics import Measure
Expand All @@ -28,6 +29,10 @@

IS_USER_ALLOWED_TO_UPLOAD_MEDIA_OF_SIZE_CALLBACK = Callable[[str, int], Awaitable[bool]]

GET_MEDIA_UPLOAD_LIMITS_FOR_USER_CALLBACK = Callable[
[str], Awaitable[Optional[List[MediaUploadLimit]]]
]


class MediaRepositoryModuleApiCallbacks:
def __init__(self, hs: "HomeServer") -> None:
Expand All @@ -39,13 +44,19 @@ def __init__(self, hs: "HomeServer") -> None:
self._is_user_allowed_to_upload_media_of_size_callbacks: List[
IS_USER_ALLOWED_TO_UPLOAD_MEDIA_OF_SIZE_CALLBACK
] = []
self._get_media_upload_limits_for_user_callbacks: List[
GET_MEDIA_UPLOAD_LIMITS_FOR_USER_CALLBACK
] = []

def register_callbacks(
self,
get_media_config_for_user: Optional[GET_MEDIA_CONFIG_FOR_USER_CALLBACK] = None,
is_user_allowed_to_upload_media_of_size: Optional[
IS_USER_ALLOWED_TO_UPLOAD_MEDIA_OF_SIZE_CALLBACK
] = None,
get_media_upload_limits_for_user: Optional[
GET_MEDIA_UPLOAD_LIMITS_FOR_USER_CALLBACK
] = None,
) -> None:
"""Register callbacks from module for each hook."""
if get_media_config_for_user is not None:
Expand All @@ -56,6 +67,11 @@ def register_callbacks(
is_user_allowed_to_upload_media_of_size
)

if get_media_upload_limits_for_user is not None:
self._get_media_upload_limits_for_user_callbacks.append(
get_media_upload_limits_for_user
)

async def get_media_config_for_user(self, user_id: str) -> Optional[JsonDict]:
for callback in self._get_media_config_for_user_callbacks:
with Measure(
Expand Down Expand Up @@ -83,3 +99,20 @@ async def is_user_allowed_to_upload_media_of_size(
return res

return True

async def get_media_upload_limits_for_user(
self, user_id: str
) -> Optional[List[MediaUploadLimit]]:
for callback in self._get_media_upload_limits_for_user_callbacks:
with Measure(
self.clock,
name=f"{callback.__module__}.{callback.__qualname__}",
server_name=self.server_name,
):
res: Optional[List[MediaUploadLimit]] = await delay_cancellation(
callback(user_id)
)
if res is not None: # to allow [] to be returned meaning no limit
return res

return None
132 changes: 132 additions & 0 deletions tests/rest/client/test_media.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,15 @@

from synapse.api.errors import HttpResponseException
from synapse.api.ratelimiting import Ratelimiter
from synapse.config._base import Config
from synapse.config.oembed import OEmbedEndpointConfig
from synapse.http.client import MultipartResponse
from synapse.http.types import QueryParams
from synapse.logging.context import make_deferred_yieldable
from synapse.media._base import FileInfo, ThumbnailInfo
from synapse.media.thumbnailer import ThumbnailProvider
from synapse.media.url_previewer import IMAGE_CACHE_EXPIRY_MS
from synapse.module_api import MediaUploadLimit
from synapse.rest import admin
from synapse.rest.client import login, media
from synapse.server import HomeServer
Expand Down Expand Up @@ -2967,3 +2969,133 @@ def test_over_weekly_limit(self) -> None:
# This will succeed as the weekly limit has reset
channel = self.upload_media(900)
self.assertEqual(channel.code, 200)


class MediaUploadLimitsModuleOverrides(unittest.HomeserverTestCase):
"""
This test case simulates a homeserver with media upload limits being overridden by the module API.
"""

servlets = [
media.register_servlets,
login.register_servlets,
admin.register_servlets,
]

def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()

self.storage_path = self.mktemp()
self.media_store_path = self.mktemp()
os.mkdir(self.storage_path)
os.mkdir(self.media_store_path)
config["media_store_path"] = self.media_store_path

provider_config = {
"module": "synapse.media.storage_provider.FileStorageProviderBackend",
"store_local": True,
"store_synchronous": False,
"store_remote": True,
"config": {"directory": self.storage_path},
}

config["media_storage_providers"] = [provider_config]

# default limits to use are the limits that we are testing
config["media_upload_limits"] = [
{"time_period": "1d", "max_size": "1K"},
{"time_period": "1w", "max_size": "3K"},
]

return self.setup_test_homeserver(config=config)

async def _get_media_upload_limits_for_user(
self,
user_id: str,
) -> Optional[List[MediaUploadLimit]]:
# user1 has custom limits
if user_id == self.user1:
return [
MediaUploadLimit(
time_period_ms=Config.parse_duration("1d"), max_bytes=5000
),
MediaUploadLimit(
time_period_ms=Config.parse_duration("1w"), max_bytes=15000
),
]
# user2 has no limits
if user_id == self.user2:
return []
# otherwise use default
return None

def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.repo = hs.get_media_repository()
self.client = hs.get_federation_http_client()
self.store = hs.get_datastores().main
self.user1 = self.register_user("user1", "pass")
self.tok1 = self.login("user1", "pass")
self.user2 = self.register_user("user2", "pass")
self.tok2 = self.login("user2", "pass")
self.register_user("user3", "pass")
self.tok3 = self.login("user3", "pass")
self.hs.get_module_api().register_media_repository_callbacks(
get_media_upload_limits_for_user=self._get_media_upload_limits_for_user
)

def create_resource_dict(self) -> Dict[str, Resource]:
resources = super().create_resource_dict()
resources["/_matrix/media"] = self.hs.get_media_repository_resource()
return resources

def upload_media(self, size: int, tok: str) -> FakeChannel:
"""Helper to upload media of a given size with a given token."""
return self.make_request(
"POST",
"/_matrix/media/v3/upload",
content=b"0" * size,
access_token=tok,
shorthand=False,
content_type=b"text/plain",
custom_headers=[("Content-Length", str(size))],
)

def test_upload_under_limit(self) -> None:
"""Test that uploading media under the limit works."""

# User 1 uploads 100 bytes
channel = self.upload_media(100, self.tok1)
self.assertEqual(channel.code, 200)

# User 2 (unlimited) uploads 100 bytes
channel = self.upload_media(100, self.tok2)
self.assertEqual(channel.code, 200)

# User 3 (default) uploads 100 bytes
channel = self.upload_media(100, self.tok3)
self.assertEqual(channel.code, 200)

def test_uses_custom_limit(self) -> None:
"""Test that uploading media over the module provided daily limit fails."""

# User 1 uploads 3000 bytes
channel = self.upload_media(3000, self.tok1)
self.assertEqual(channel.code, 200)
# User 1 attempts to upload 4000 bytes taking it over the limit
channel = self.upload_media(4000, self.tok1)
self.assertEqual(channel.code, 400)

def test_uses_unlimited(self) -> None:
"""Test that unlimited user is not limited when module returns []."""
# User 2 uploads 10000 bytes which is over the default limit
channel = self.upload_media(10000, self.tok2)
self.assertEqual(channel.code, 200)

def test_uses_defaults(self) -> None:
"""Test that the default limits are applied when module returned None."""
# User 3 uploads 500 bytes
channel = self.upload_media(500, self.tok3)
self.assertEqual(channel.code, 200)
# User 3 uploads 800 bytes which is over the limit
channel = self.upload_media(800, self.tok3)
self.assertEqual(channel.code, 400)