diff --git a/changelog.d/17558.misc b/changelog.d/17558.misc new file mode 100644 index 00000000000..cfa8089a810 --- /dev/null +++ b/changelog.d/17558.misc @@ -0,0 +1 @@ +Speed up responding to media requests. diff --git a/synapse/http/server.py b/synapse/http/server.py index 0d0c610b284..211795dc396 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -74,7 +74,6 @@ from synapse.config.homeserver import HomeServerConfig from synapse.logging.context import defer_to_thread, preserve_fn, run_in_background from synapse.logging.opentracing import active_span, start_active_span, trace_servlet -from synapse.types import ISynapseReactor from synapse.util import json_encoder from synapse.util.caches import intern_dict from synapse.util.cancellation import is_function_cancellable @@ -869,8 +868,7 @@ def encode(opentracing_span: "Optional[opentracing.Span]") -> bytes: with start_active_span("encode_json_response"): span = active_span() - reactor: ISynapseReactor = request.reactor # type: ignore - json_str = await defer_to_thread(reactor, encode, span) + json_str = await defer_to_thread(request.reactor, encode, span) _write_bytes_to_request(request, json_str) diff --git a/synapse/http/site.py b/synapse/http/site.py index af169ba51e6..8bf63edd362 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -658,7 +658,7 @@ def __init__( ) self.site_tag = site_tag - self.reactor = reactor + self.reactor: ISynapseReactor = reactor assert config.http_options is not None proxied = config.http_options.x_forwarded diff --git a/synapse/media/_base.py b/synapse/media/_base.py index 1b268ce4d42..21f334339bf 100644 --- a/synapse/media/_base.py +++ b/synapse/media/_base.py @@ -22,12 +22,14 @@ import logging import os +import threading import urllib from abc import ABC, abstractmethod from types import TracebackType from typing import ( TYPE_CHECKING, Awaitable, + BinaryIO, Dict, Generator, List, @@ -37,15 +39,19 @@ ) import attr +from zope.interface import implementer +from twisted.internet import interfaces +from twisted.internet.defer import Deferred from twisted.internet.interfaces import IConsumer -from twisted.protocols.basic import FileSender +from twisted.python.failure import Failure from twisted.web.server import Request from synapse.api.errors import Codes, cs_error from synapse.http.server import finish_request, respond_with_json from synapse.http.site import SynapseRequest -from synapse.logging.context import make_deferred_yieldable +from synapse.logging.context import defer_to_thread, make_deferred_yieldable +from synapse.types import ISynapseReactor from synapse.util import Clock from synapse.util.stringutils import is_ascii @@ -138,7 +144,7 @@ async def respond_with_file( add_file_headers(request, media_type, file_size, upload_name) with open(file_path, "rb") as f: - await make_deferred_yieldable(FileSender().beginFileTransfer(f, request)) + await ThreadedFileSender(request.reactor).beginFileTransfer(f, request) finish_request(request) else: @@ -601,3 +607,132 @@ def _parseparam(s: bytes) -> Generator[bytes, None, None]: f = s[:end] yield f.strip() s = s[end:] + + +@implementer(interfaces.IPushProducer) +class ThreadedFileSender: + """ + A producer that sends the contents of a file to a consumer, reading from the + file on a thread. + + This works by spawning a loop in a threadpool that repeatedly reads from the + file and sends it to the consumer. The main thread communicates with the + loop via two `threading.Event`, which controls when to start/pause reading + and when to terminate. + """ + + # How much data to read in one go. + CHUNK_SIZE = 2**14 + + # How long we wait for the consumer to be ready again before aborting the + # read. + TIMEOUT_SECONDS = 90.0 + + def __init__(self, reactor: ISynapseReactor) -> None: + self.reactor = reactor + + self.file: Optional[BinaryIO] = None + self.deferred: "Deferred[None]" = Deferred() + self.consumer: Optional[interfaces.IConsumer] = None + + # Signals if the thread should keep reading/sending data. Set means + # continue, clear means pause. + self.wakeup_event = threading.Event() + + # Signals if the thread should terminate, e.g. because the consumer has + # gone away. Both this and `wakeup_event` should be set to terminate the + # loop (otherwise the thread will block on `wakeup_event`). + self.stop_event = threading.Event() + + def beginFileTransfer( + self, file: BinaryIO, consumer: interfaces.IConsumer + ) -> "Deferred[None]": + """ + Begin transferring a file + """ + self.file = file + self.consumer = consumer + + self.consumer.registerProducer(self, True) + + # We set the wakeup signal as we should start producing immediately. + self.wakeup_event.set() + defer_to_thread(self.reactor, self._on_thread_read_loop) + + return make_deferred_yieldable(self.deferred) + + def resumeProducing(self) -> None: + """interfaces.IPushProducer""" + self.wakeup_event.set() + + def pauseProducing(self) -> None: + """interfaces.IPushProducer""" + self.wakeup_event.clear() + + def stopProducing(self) -> None: + """interfaces.IPushProducer""" + + # Terminate the thread loop. + self.wakeup_event.set() + self.stop_event.set() + + if not self.deferred.called: + self.deferred.errback(Exception("Consumer asked us to stop producing")) + + def _on_thread_read_loop(self) -> None: + """This is the loop that happens on a thread.""" + + try: + while not self.stop_event.is_set(): + # We wait for the producer to signal that the consumer wants + # more data (or we should abort) + if not self.wakeup_event.is_set(): + ret = self.wakeup_event.wait(self.TIMEOUT_SECONDS) + if not ret: + raise Exception("Timed out waiting to resume") + + # Check if we were woken up so that we abort the download + if self.stop_event.is_set(): + return + + # The file should always have been set before we get here. + assert self.file is not None + + chunk = self.file.read(self.CHUNK_SIZE) + if not chunk: + return + + self.reactor.callFromThread(self._write, chunk) + + except Exception: + self.reactor.callFromThread(self._error, Failure()) + finally: + self.reactor.callFromThread(self._finish) + + def _write(self, chunk: bytes) -> None: + """Called from the thread to write a chunk of data""" + if self.consumer: + self.consumer.write(chunk) + + def _error(self, failure: Failure) -> None: + """Called from the thread when there was a fatal error""" + if self.consumer: + self.consumer.unregisterProducer() + self.consumer = None + + if not self.deferred.called: + self.deferred.errback(failure) + + def _finish(self) -> None: + """Called from the thread when it finishes (either on success or + failure).""" + if self.file: + self.file.close() + self.file = None + + if self.consumer: + self.consumer.unregisterProducer() + self.consumer = None + + if not self.deferred.called: + self.deferred.callback(None) diff --git a/synapse/media/media_storage.py b/synapse/media/media_storage.py index 2a106bb0eba..e06273c92f4 100644 --- a/synapse/media/media_storage.py +++ b/synapse/media/media_storage.py @@ -49,19 +49,15 @@ from twisted.internet import interfaces from twisted.internet.defer import Deferred from twisted.internet.interfaces import IConsumer -from twisted.protocols.basic import FileSender from synapse.api.errors import NotFoundError -from synapse.logging.context import ( - defer_to_thread, - make_deferred_yieldable, - run_in_background, -) +from synapse.logging.context import defer_to_thread, run_in_background from synapse.logging.opentracing import start_active_span, trace, trace_with_opname +from synapse.media._base import ThreadedFileSender from synapse.util import Clock from synapse.util.file_consumer import BackgroundFileConsumer -from ..types import JsonDict +from ..types import ISynapseReactor, JsonDict from ._base import FileInfo, Responder from .filepath import MediaFilePaths @@ -213,7 +209,7 @@ async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]: local_path = os.path.join(self.local_media_directory, path) if os.path.exists(local_path): logger.debug("responding with local file %s", local_path) - return FileResponder(open(local_path, "rb")) + return FileResponder(self.reactor, open(local_path, "rb")) logger.debug("local file %s did not exist", local_path) for provider in self.storage_providers: @@ -336,12 +332,13 @@ class FileResponder(Responder): is closed when finished streaming. """ - def __init__(self, open_file: IO): + def __init__(self, reactor: ISynapseReactor, open_file: BinaryIO): + self.reactor = reactor self.open_file = open_file def write_to_consumer(self, consumer: IConsumer) -> Deferred: - return make_deferred_yieldable( - FileSender().beginFileTransfer(self.open_file, consumer) + return ThreadedFileSender(self.reactor).beginFileTransfer( + self.open_file, consumer ) def __exit__( diff --git a/synapse/media/storage_provider.py b/synapse/media/storage_provider.py index 06e5d27a53a..355df999d29 100644 --- a/synapse/media/storage_provider.py +++ b/synapse/media/storage_provider.py @@ -145,6 +145,7 @@ class FileStorageProviderBackend(StorageProvider): def __init__(self, hs: "HomeServer", config: str): self.hs = hs + self.reactor = hs.get_reactor() self.cache_directory = hs.config.media.media_store_path self.base_directory = config @@ -165,7 +166,7 @@ async def store_file(self, path: str, file_info: FileInfo) -> None: shutil_copyfile: Callable[[str, str], str] = shutil.copyfile with start_active_span("shutil_copyfile"): await defer_to_thread( - self.hs.get_reactor(), + self.reactor, shutil_copyfile, primary_fname, backup_fname, @@ -177,7 +178,7 @@ async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: backup_fname = os.path.join(self.base_directory, path) if os.path.isfile(backup_fname): - return FileResponder(open(backup_fname, "rb")) + return FileResponder(self.reactor, open(backup_fname, "rb")) return None diff --git a/synapse/media/thumbnailer.py b/synapse/media/thumbnailer.py index ef6aa8ccf54..3380315b274 100644 --- a/synapse/media/thumbnailer.py +++ b/synapse/media/thumbnailer.py @@ -259,6 +259,7 @@ def __init__( media_storage: MediaStorage, ): self.hs = hs + self.reactor = hs.get_reactor() self.media_repo = media_repo self.media_storage = media_storage self.store = hs.get_datastores().main @@ -373,7 +374,7 @@ async def select_or_generate_local_thumbnail( await respond_with_multipart_responder( self.hs.get_clock(), request, - FileResponder(open(file_path, "rb")), + FileResponder(self.reactor, open(file_path, "rb")), media_info, ) else: