diff --git a/docs/configuration.rst b/docs/configuration.rst index daef85d3b..30814499e 100644 --- a/docs/configuration.rst +++ b/docs/configuration.rst @@ -140,8 +140,8 @@ Define a secure connection to an Elastic Cloud Serverless project:: storage ~~~~~~~ -This section defines how client is configured to transfer corpus files. The main -advantages of this downloader implementation are: +This section defines how the client is configured to transfer corpus files. The main advantages of this downloader +implementation are: * It supports configuring multiple mirror URLs. The client will load balance between these URLs giving priority to those with the lower latency. In case of failures the client will download files from the original URL. @@ -166,11 +166,17 @@ Configuration options are: Here is an example of valid value for http(s) adapter:: [storage] - storage.adapters = esrally.storage._http:HTTPAdapter + storage.adapters = esrally.storage.http:HTTPAdapter,esrally.storage.aws:S3Adapter + At this point in time ``esrally.storage.http:HTTPAdapter`` and ``esrally.storage.aws:S3Adapter`` are the only + known ``Adapter`` implementations intended for public use and they are both enabled by default. So editing this option + is required only for special customizations (for example, removing one of them or adding a new + custom implementation for accessing a special infrastructure server). - At this point in time `esrally.storage._http:HTTPAdapter` is the only existing `Adapter` implementations intended - for public use and that is already the default one. So it is required to edit this option for special customizations. +* ``storage.cache_ttl`` indicates the default time (in seconds) for which replies to head requests are considered + valid. In case this value is zero, the caching mechanism is disabled. + +* ``storage.chunk_size`` is used to specify the default size (in bytes) of data chunks to be downloaded using adapters.
* ``storage.local_dir`` indicates the default directory where to store local files when no path has been specified. @@ -208,17 +214,17 @@ Configuration options are: The mirroring of the files on mirrors servers has to be provided by the infrastructure. The esrally client will look for the files on the destination mirror endpoints URLs or use the original source endpoint URL in case the files are not mirrored or they have a different size from the source one. The client will prefer endpoints with the lower - latency fetching the head of the file. + latency measured by fetching the headers of the file. -* ``storage.monitor_interval`` represents the time interval (in seconds) `TransferManager` should wait for consecutive +* ``storage.monitor_interval`` represents the time interval (in seconds) ``TransferManager`` should wait between consecutive monitor operations (log transfer and connections statistics, adjust the maximum number of connections, etc.). * ``storage.multipart_size`` When the file size measured in bytes is greater than this value the file is split in chunk - of this size plus the last one ()that could be smaller). Each part will be downloaded separately and in parallel using - a dedicated connection by a worker thread and eventually from a different mirror server to load balance the network - traffic between multiple servers. If the resulting number of parts is greater than ``storage.max_workers`` and - ``storage.max_connections`` options, then the transfer of those parts exceeding these limits will be performed as - soon as a worker thread gets available or a HTTP connection get released by another thread. + of this size plus the last one (which could be smaller). Each part will be downloaded separately and in parallel using a dedicated connection + by a worker thread and possibly from a different mirror server to load balance the network traffic between multiple + servers.
If the resulting number of parts is greater than ``storage.max_workers`` and ``storage.max_connections`` + options, then the transfer of those parts exceeding these limits will be performed as soon as a worker thread becomes + available or an HTTP connection gets released by another thread. * ``storage.random_seed`` a string used to initialize the client random number generator. This could be used to make problems easier to reproduce in continuous integration. In most of the cases it should be left empty. @@ -229,10 +235,13 @@ HTTP Adapter This adapter can be used only to download files from public HTTP or HTTPS servers. -* ``storage.http.chunk_size`` is used to specify the size of the buffer is being used for transferring chunk of files. +It must be listed in the ``storage.adapters`` option:: + + [storage] + storage.adapters = esrally.storage.http:HTTPAdapter * ``storage.http.max_retries`` is used to configure the maximum number of retries for making HTTP adapter requests. - it accept a numeric value to simply specify total number of retries. Examples:: + It accepts a numeric value to simply specify total number of retries. Examples:: [storage] storage.http.max_retries = 3 @@ -256,6 +265,11 @@ It requires `Boto3 Client`_ to be installed and it accepts only URLs with the fo s3:///[] +It must be listed in the ``storage.adapters`` option:: + + [storage] + storage.adapters = esrally.storage.aws:S3Adapter + In the case the boto3 client is not installed, and S3 buckets are publicly readable without authentication, you can use the HTTP adapter instead, for example by using the following URL format:: @@ -294,6 +308,20 @@ Configuration options: .. _Boto3 Client: https://boto3.amazonaws.com/v1/documentation/api/latest/index.html .. _S3 Service Documentation: https://docs.aws.amazon.com/AmazonCloudFront/latest/DeveloperGuide/DownloadDistS3AndCustomOrigins.html#concept_S3Origin + +track +~~~~~ + +This section specifies how track corpora files have to be fetched.
Available options are: + +* ``track.downloader.multipart_enabled`` if ``true``, it will enable the use the new multipart ``esrally.storage`` package for + downloading corpora files. For more configuration options please have a look to the `storage`_ configuration section. + +.. warning:: + + Transfers manager implementation is experimental and under active development. Enable it only if you actually need it. + + tracks ~~~~~~ diff --git a/esrally/config.py b/esrally/config.py index 9f6c35cbd..8d33b96a2 100644 --- a/esrally/config.py +++ b/esrally/config.py @@ -16,15 +16,21 @@ # under the License. import configparser +import contextvars import logging import os.path import shutil +import typing from enum import Enum from string import Template +from typing_extensions import Self + from esrally import PROGRAM_NAME, exceptions, paths, types from esrally.utils import io +LOG = logging.getLogger(__name__) + class Scope(Enum): # Valid for all benchmarks, typically read from the configuration file @@ -40,6 +46,7 @@ class Scope(Enum): class ConfigFile: + def __init__(self, config_name=None, **kwargs): self.config_name = config_name @@ -124,22 +131,66 @@ def auto_load_local_config(base_config, additional_sections=None, config_file_cl return cfg -class Config: - EARLIEST_SUPPORTED_VERSION = 17 +CONFIG = contextvars.ContextVar[typing.Optional[types.Config]](f"{__name__}.config", default=None) + + +def get_config() -> types.Config: + cfg = CONFIG.get() + if cfg is None: + raise exceptions.ConfigError("Config not initialized.") + return cfg + + +def init_config(cfg: types.Config, *, force=False) -> types.Config: + if not force and CONFIG.get(): + raise exceptions.ConfigError(f"Config already set: {cfg}") + cfg = Config.from_config(cfg) + CONFIG.set(cfg) + return cfg - CURRENT_CONFIG_VERSION = 17 +def clear_config() -> None: + CONFIG.set(None) + + +class Config(types.Config): """ Config is the main entry point to retrieve and set benchmark properties. 
It provides multiple scopes to allow overriding of values on different levels (e.g. a command line flag can override the same configuration property in the config file). These levels are transparently resolved when a property is retrieved and the value on the most specific level is returned. """ - def __init__(self, config_name=None, config_file_class=ConfigFile, **kwargs): + EARLIEST_SUPPORTED_VERSION = 17 + + CURRENT_CONFIG_VERSION = 17 + + @classmethod + def from_config(cls, cfg: types.Config | None = None) -> Self: + if cfg is None: + cfg = get_config() + if isinstance(cfg, cls): + return cfg + if isinstance(cfg, types.Config): + return cls(opts_from=cfg) + raise TypeError(f"unexpected cfg: got type {type(cfg).__name__}, expected types.Config") + + def __init__(self, config_name: str | None = None, config_file_class=ConfigFile, copy_from: types.Config | None = None, **kwargs): self.name = config_name self.config_file = config_file_class(config_name, **kwargs) self._opts = {} - self._clear_config() + if copy_from is not None: + self.update(copy_from) + self._override_config() + + def update(self, cfg: types.Config): + if isinstance(cfg, Config): + self.name = cfg.name + self.config_file = cfg.config_file + self._opts.update(cfg._opts) # pylint: disable=protected-access + return + for section in cfg.all_sections(): + for name, value in cfg.all_opts(section).items(): + self.add(Scope.application, section, name, value) def add(self, scope, section: types.Section, key: types.Key, value): """ @@ -185,7 +236,10 @@ def opts(self, section: types.Section, key: types.Key, default_value=None, manda else: raise exceptions.ConfigError(f"No value for mandatory configuration: section='{section}', key='{key}'") - def all_opts(self, section: types.Section): + def all_sections(self) -> list[types.Section]: + return list(typing.get_args(types.Section)) + + def all_opts(self, section: types.Section) -> dict[str, typing.Any]: """ Finds all options in a section and returns them in a 
dict. @@ -233,21 +287,24 @@ def load_config(self, auto_upgrade=False): def _do_load_config(self): config = self.config_file.load() # It's possible that we just reload the configuration - self._clear_config() + self._opts = {} + self._override_config() self._fill_from_config_file(config) - def _clear_config(self): + def _override_config(self): # This map contains default options that we don't want to sprinkle all over the source code but we don't want users to change # them either - self._opts = { - (Scope.application, "source", "distribution.dir"): "distributions", - (Scope.application, "benchmarks", "track.repository.dir"): "tracks", - (Scope.application, "benchmarks", "track.default.repository"): "default", - (Scope.application, "provisioning", "node.name.prefix"): "rally-node", - (Scope.application, "provisioning", "node.http.port"): 39200, - (Scope.application, "mechanic", "team.repository.dir"): "teams", - (Scope.application, "mechanic", "team.default.repository"): "default", - } + self._opts.update( + { + (Scope.application, "source", "distribution.dir"): "distributions", + (Scope.application, "benchmarks", "track.repository.dir"): "tracks", + (Scope.application, "benchmarks", "track.default.repository"): "default", + (Scope.application, "provisioning", "node.name.prefix"): "rally-node", + (Scope.application, "provisioning", "node.http.port"): 39200, + (Scope.application, "mechanic", "team.repository.dir"): "teams", + (Scope.application, "mechanic", "team.default.repository"): "default", + } + ) def _fill_from_config_file(self, config): for section in config.sections(): @@ -279,6 +336,15 @@ def _k(self, scope, section: types.Section, key: types.Key): else: return scope, section, key + def __eq__(self, other): + if not isinstance(other, Config): + return False + if other.name != self.name: + return False + if other._opts != self._opts: + return False + return True + def migrate(config_file, current_version, target_version, out=print, i=input): logger = 
logging.getLogger(__name__) diff --git a/esrally/storage/__init__.py b/esrally/storage/__init__.py index 9c275a74f..3e3ce8fdd 100644 --- a/esrally/storage/__init__.py +++ b/esrally/storage/__init__.py @@ -14,13 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from esrally.storage._adapter import AdapterRegistry, Head -from esrally.storage._client import Client -from esrally.storage._executor import Executor, ThreadPoolExecutor -from esrally.storage._http import HTTPAdapter +from esrally.storage._adapter import Adapter, AdapterRegistry, Head +from esrally.storage._config import StorageConfig from esrally.storage._manager import ( TransferManager, + get_transfer_manager, init_transfer_manager, - quit_transfer_manager, - transfer_manager, + shutdown_transfer_manager, ) +from esrally.storage._transfer import Transfer diff --git a/esrally/storage/_adapter.py b/esrally/storage/_adapter.py index d96f1ed81..99a9f20b9 100644 --- a/esrally/storage/_adapter.py +++ b/esrally/storage/_adapter.py @@ -16,19 +16,21 @@ # under the License. 
from __future__ import annotations +import copy +import dataclasses import datetime import importlib import logging import threading from abc import ABC, abstractmethod -from collections.abc import Container, Iterable -from dataclasses import dataclass -from typing import Any, Protocol, runtime_checkable +from collections.abc import Container, Iterator +from typing import Any from typing_extensions import Self +from esrally import types +from esrally.storage._config import StorageConfig from esrally.storage._range import NO_RANGE, RangeSet -from esrally.types import Config LOG = logging.getLogger(__name__) @@ -37,17 +39,10 @@ class ServiceUnavailableError(Exception): """It is raised when an adapter refuses providing service for example because of too many requests""" -@runtime_checkable -class Writable(Protocol): - - def write(self, data: bytes) -> None: - pass - - _HEAD_CHECK_IGNORE = frozenset(["url"]) -@dataclass +@dataclasses.dataclass class Head: url: str | None = None content_length: int | None = None @@ -77,12 +72,13 @@ class Adapter(ABC): """Base class for storage class client implementation""" @classmethod + @abstractmethod def match_url(cls, url: str) -> bool: - """It returns a canonical URL in case this adapter accepts the URL, None otherwise.""" - raise NotImplementedError + """It returns whenever this adapter is able to accept this URL.""" @classmethod - def from_config(cls, cfg: Config, **kwargs: Any) -> Self: + @abstractmethod + def from_config(cls, cfg: types.Config) -> Self: """Default `Adapter` objects factory method used to create adapters from `esrally` client. Default implementation will ignore `cfg` parameter. It can be overridden from `Adapter` implementations that @@ -91,7 +87,6 @@ def from_config(cls, cfg: Config, **kwargs: Any) -> Self: :param cfg: the configuration object from which to get configuration values. :return: an adapter object. 
""" - return cls(**kwargs) @abstractmethod def head(self, url: str) -> Head: @@ -100,62 +95,50 @@ def head(self, url: str) -> Head: :raises ServiceUnavailableError: in case on temporary service failure. """ - def get(self, url: str, stream: Writable, head: Head | None = None) -> Head: + @abstractmethod + def get(self, url: str, *, check_head: Head | None = None) -> tuple[Head, Iterator[bytes]]: """It downloads a remote bucket object to a local file path. :param url: it represents the URL of the remote file object. - :param stream: it represents the local file stream where to write data to. - :param head: it allows to specify optional parameters: + :param check_head: it allows to specify optional parameters: - range: portion of the file to transfer (it must be empty or a continuous range). - - document_length: the number of bytes to transfer. + - content_length: the number of bytes to transfer. - crc32c the CRC32C checksum of the file. - date: the date the file has been modified. :raises ServiceUnavailableError: in case on temporary service failure. + :returns: a tuple containing the Head of the remote file, and the iterator of bytes received from the service. 
""" - raise NotImplementedError(f"{type(self).__name__} adapter does not implement get method.") - - -ADAPTER_CLASS_NAMES = [ - "esrally.storage._aws:S3Adapter", - "esrally.storage._http:HTTPAdapter", -] class AdapterRegistry: """AdapterClassRegistry allows to register classes of adapters to be selected according to the target URL.""" - def __init__(self, cfg: Config) -> None: + @classmethod + def from_config(cls, cfg: types.Config) -> Self: + return cls(StorageConfig.from_config(cfg)) + + def __init__(self, cfg: StorageConfig) -> None: self._classes: list[type[Adapter]] = [] self._adapters: dict[type[Adapter], Adapter] = {} self._lock = threading.Lock() self._cfg = cfg - - @classmethod - def from_config(cls, cfg: Config) -> Self: - registry = cls(cfg) - adapter_names: Iterable[str] = cfg.opts( - section="storage", key="storage.adapters", default_value=ADAPTER_CLASS_NAMES, mandatory=False - ) - if isinstance(adapter_names, str): - # It parses adapter names when it has been defined as a single string. 
- adapter_names = adapter_names.replace(" ", "").split(",") - for adapter_name in adapter_names: - module_name, class_name = adapter_name.split(":") + for name in self._cfg.adapters: try: - module = importlib.import_module(module_name) - except ModuleNotFoundError: - LOG.exception("unable to import module '%s'.", module_name) + self.register_class(name) + except ImportError: + LOG.exception("failed registering adapter '%s'", name) continue + + def register_class(self, cls: type[Adapter] | str, position: int | None = None) -> type[Adapter]: + if isinstance(cls, str): + module_name, class_name = cls.split(":", 1) + module = importlib.import_module(module_name) try: - obj = getattr(module, class_name) + cls = getattr(module, class_name) except AttributeError: - raise ValueError("Invalid Adapter class name: '{class_name}'.") - if not isinstance(obj, type) or not issubclass(obj, Adapter): - raise TypeError(f"'{obj}' is not a valid subclass of Adapter") - registry.register_class(obj) - return registry - - def register_class(self, cls: type[Adapter], position: int | None = None) -> type[Adapter]: + raise ValueError(f"invalid adapter class name: '{class_name}'") + if not isinstance(cls, type) or not issubclass(cls, Adapter): + raise TypeError(f"'{cls}' is not a subclass of Adapter") with self._lock: if position is None: self._classes.append(cls) @@ -176,3 +159,37 @@ def get(self, url: str) -> Adapter: self._adapters[cls] = adapter = cls.from_config(self._cfg) return adapter + + +@dataclasses.dataclass +class DummyAdapter(Adapter): + heads: dict[str, Head] = dataclasses.field(default_factory=dict) + data: dict[str, bytes] = dataclasses.field(default_factory=dict) + + @classmethod + def match_url(cls, url: str) -> bool: + return True + + @classmethod + def from_config(cls, cfg: types.Config | None = None) -> Self: + return cls() + + def head(self, url: str) -> Head: + try: + return copy.copy(self.heads[url]) + except KeyError: + raise FileNotFoundError from None + + def 
get(self, url: str, *, check_head: Head | None = None) -> tuple[Head, Iterator[bytes]]: + ranges: RangeSet = NO_RANGE + if check_head is not None: + ranges = check_head.ranges + if len(ranges) > 1: + raise NotImplementedError("len(head.ranges) > 1") + + data = self.data[url] + if ranges: + data = data[ranges.start : ranges.end] + + head = Head(url, content_length=ranges.size, ranges=ranges, document_length=len(data)) + return head, iter((data,)) diff --git a/esrally/storage/_client.py b/esrally/storage/_client.py index bf6efa500..8e5f4a259 100644 --- a/esrally/storage/_client.py +++ b/esrally/storage/_client.py @@ -22,108 +22,142 @@ import urllib.parse from collections import defaultdict, deque from collections.abc import Iterator +from dataclasses import dataclass from random import Random from typing import NamedTuple from typing_extensions import Self from esrally import types -from esrally.storage._adapter import ( - AdapterRegistry, - Head, - ServiceUnavailableError, - Writable, -) +from esrally.storage._adapter import AdapterRegistry, Head, ServiceUnavailableError +from esrally.storage._config import StorageConfig from esrally.storage._mirror import MirrorList from esrally.utils import pretty from esrally.utils.threads import WaitGroup, WaitGroupLimitError LOG = logging.getLogger(__name__) -MIRRORS_FILES = "~/.rally/storage-mirrors.json" -MAX_CONNECTIONS = 4 -RANDOM = Random(time.monotonic_ns()) + +class CachedHeadError(Exception): + """CachedHeadError is intended to wrap another exception. + + It is being raised when an attempt is made to retrieve a head from the cache to handle the case of recent past + failure fetching the same head. 
+ """ + + +@dataclass +class CachedHead: + timestamp: float + head: Head | None = None + error: Exception | None = None + + def __init__(self, timestamp: float, /, head: Head | None = None, error: Exception | None = None): + self.timestamp = timestamp + if head is not None: + if error is not None: + raise ValueError("cannot specify both head and error") + self.head = head + elif error is not None: + # It creates a CachedHeadError to wraps error with the stack trace sets to here. So it will be nice to see + # the chain of exceptions to understand what gone one when debugging an exception borrowed from the cache. + # The wrapping of the original error is required to allow triggering a special handling when the error is + # recovered from an old cached request, or is got from an actual request. + try: + raise CachedHeadError(f"Cached error: {error}") from error + except CachedHeadError as ex: + self.error = ex + else: + raise ValueError("must specify either head or error") + + def get(self, *, cache_ttl: float | None = None) -> Head: + if cache_ttl is not None: + if time.monotonic() > self.timestamp + cache_ttl: + raise TimeoutError("cached head has expired") + if self.error is not None: + raise self.error + assert self.head is not None + return self.head class Client: """It handles client instances allocation allowing reusing pre-allocated instances from the same thread.""" @classmethod - def from_config( - cls, cfg: types.Config, adapters: AdapterRegistry | None = None, mirrors: MirrorList | None = None, random: Random | None = None - ) -> Self: - if adapters is None: - adapters = AdapterRegistry.from_config(cfg) - if mirrors is None: - mirrors = MirrorList.from_config(cfg) - if random is None: - random_seed = cfg.opts(section="storage", key="storage.random_seed", default_value=None, mandatory=False) - if random_seed is None: - random = RANDOM - else: - random = Random(random_seed) - max_connections = int(cfg.opts(section="storage", key="storage.max_connections", 
default_value=MAX_CONNECTIONS, mandatory=False)) - return cls(adapters=adapters, mirrors=mirrors, random=random, max_connections=max_connections) + def from_config(cls, cfg: types.Config | None = None) -> Self: + cfg = StorageConfig.from_config(cfg) + return cls( + adapters=AdapterRegistry.from_config(cfg), + mirrors=MirrorList.from_config(cfg), + random=Random(cfg.random_seed), + max_connections=cfg.max_connections, + cache_ttl=cfg.cache_ttl, + ) def __init__( self, adapters: AdapterRegistry, mirrors: MirrorList, - random: Random, - max_connections: int = MAX_CONNECTIONS, + random: Random | None = None, + max_connections: int = StorageConfig.DEFAULT_MAX_CONNECTIONS, + cache_ttl: float = StorageConfig.DEFAULT_CACHE_TTL, ): self._adapters: AdapterRegistry = adapters - self._cached_heads: dict[str, tuple[Head | Exception, float]] = {} + self._cached_heads: dict[str, CachedHead] = {} self._connections: dict[str, WaitGroup] = defaultdict(lambda: WaitGroup(max_count=max_connections)) self._lock = threading.Lock() self._mirrors: MirrorList = mirrors - self._random: Random = random + self._random: Random = random or Random(StorageConfig.DEFAULT_RANDOM_SEED) self._stats: dict[str, deque[ServerStats]] = defaultdict(lambda: deque(maxlen=100)) + self._cache_ttl: float = cache_ttl @property def adapters(self): return self._adapters - def head(self, url: str, ttl: float | None = None) -> Head: + def head(self, url: str, *, cache_ttl: float | None = None) -> Head: """It gets remote file headers.""" - - start_time = time.monotonic() - if ttl is not None: + if cache_ttl is None: + cache_ttl = self._cache_ttl + if cache_ttl > 0.0: # when time-to-leave is given, it looks up for pre-cached head first try: - value, last_time = self._cached_heads[url] - except KeyError: + return self._cached_heads[url].get(cache_ttl=cache_ttl) + except (KeyError, TimeoutError): + # no cached head, or it has expired. 
pass - else: - if start_time <= last_time + ttl: - # cached value or error is enough recent to be used. - return _head_or_raise(value) - adapter = self._adapters.get(url) + start_time = time.monotonic() try: - value = adapter.head(url) + adapter = self._adapters.get(url) + head = adapter.head(url) + error = None except Exception as ex: - LOG.error("Failed to fetch remote head for file: %s, %s", url, ex) - value = ex - end_time = time.monotonic() + head = None + error = ex + finally: + end_time = time.monotonic() with self._lock: # It records per-server time statistics. self._stats[_server_key(url)].append(ServerStats(url, start_time, end_time)) # The cached value could be an exception, or a head. In this way it will not retry previously failed # urls until the TTL expires. - self._cached_heads[url] = value, start_time + self._cached_heads[url] = CachedHead(start_time, head=head, error=error) - return _head_or_raise(value) + if error is not None: + raise error + assert head is not None + return head - def resolve(self, url: str, check: Head | None, ttl: float = 60.0) -> Iterator[Head]: + def resolve(self, url: str, *, check_head: Head | None = None, cache_ttl: float | None = None) -> Iterator[Head]: """It looks up mirror list for given URL and yield mirror heads. :param url: the remote file URL at its mirrored source location. - :param check: extra parameters to mach remote heads. + :param check_head: extra parameters to mach remote heads. - document_length: if not none it will filter out mirrors which file has an unexpected document lengths. - crc32c: if not none it will filter out mirrors which file has an unexpected crc32c checksum. - accept_ranges: if True it will filter out mirrors that are not supporting ranges. - :param ttl: the time to live value (in seconds) to use for cached heads retrieval. + :param cache_ttl: the time to live value (in seconds) to use for cached heads retrieval. 
:return: iterator over mirror URLs """ @@ -142,61 +176,63 @@ def resolve(self, url: str, check: Head | None, ttl: float = 60.0) -> Iterator[H urls.sort(key=lambda u: weights[u]) LOG.debug("resolve '%s': mirror urls: %s", url, urls) - if url not in urls: - # It ensures source URL is in the list so that it will be used as fall back when any mirror works. - urls.append(url) - - if len(urls) > 1: - LOG.debug("resolved mirror URLs for URL '%s': %s", url, urls) + if url not in urls: + # It ensures source URL is in the list so that it will be used as fall back when any mirror works. + urls.append(url) for u in urls: try: - got = self.head(u, ttl=ttl) + got = self.head(u, cache_ttl=cache_ttl) + if check_head is not None: + check_head.check(got) + except CachedHeadError: + # The error was previously cached, therefore it has been already logged. + pass except Exception as ex: - # The exception is already logged by head method before caching it. - LOG.error("Failed to fetch remote head for file: %s, %s", u, ex) - continue - if check is not None: - try: - check.check(got) - except ValueError as ex: - LOG.debug("unexpected mirrored file (url='%s'): %s", url, ex) - continue - yield got - - def get(self, url: str, stream: Writable, head: Head | None = None) -> Head: + if u == url: + LOG.warning("Failed to get head from original URL: '%s', %s", u, ex) + else: + LOG.warning("Failed to get head from mirror URL: '%s', %s", u, ex) + else: + yield got + + def get(self, url: str, *, check_head: Head | None = None) -> tuple[Head, Iterator[bytes]]: """It downloads a remote bucket object to a local file path. :param url: the URL of the remote file. - :param stream: the destination file stream where to write data to. - :param head: extra params for getting the file: + :param check_head: extra params for getting the file: - document_length: the document length of the file to transfer. - crc32c: the crc32c checksum of the file to transfer. - ranges: the portion of the file to transfer. 
:raises ServiceUnavailableError: in case on temporary service failure. """ - for got in self.resolve(url, check=head): - if got.url is None: - LOG.error("resolved mirror URL is None: %s", url) - continue + if check_head is None: + resolve_head = None + elif check_head.ranges: + resolve_head = Head( + accept_ranges=True, content_length=check_head.document_length, date=check_head.date, crc32c=check_head.crc32c + ) + else: + resolve_head = Head(content_length=check_head.content_length, date=check_head.date, crc32c=check_head.crc32c) + for got in self.resolve(url, check_head=resolve_head): + assert got.url is not None + adapter = self._adapters.get(got.url) connections = self._server_connections(got.url) try: connections.add(1) except WaitGroupLimitError: - LOG.debug("connection limit exceeded: url='%s'", url) + LOG.debug("connection limit exceeded for url '%s'", url) continue - adapter = self._adapters.get(got.url) try: - return adapter.get(url, stream, head=head) + return adapter.get(got.url, check_head=check_head) except ServiceUnavailableError as ex: - LOG.debug("service unavailable error received: url='%s' %s", url, ex) + LOG.warning("service unavailable error received: url='%s' %s", url, ex) with self._lock: # It corrects the maximum number of connections for this server. connections.max_count = max(1, connections.count) finally: connections.done() - - raise ServiceUnavailableError(f"no service available for getting URL '{url}'") + raise ServiceUnavailableError(f"no connections available for getting URL '{url}'") def _server_connections(self, url: str) -> WaitGroup: with self._lock: diff --git a/esrally/storage/_config.py b/esrally/storage/_config.py new file mode 100644 index 000000000..9b405c034 --- /dev/null +++ b/esrally/storage/_config.py @@ -0,0 +1,150 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. 
Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging +from collections.abc import Iterable +from typing import Any + +from esrally import config +from esrally.utils import convert + +LOG = logging.getLogger(__name__) + + +class StorageConfig(config.Config): + + DEFAULT_ADAPTERS = ( + "esrally.storage.aws:S3Adapter", + "esrally.storage.http:HTTPAdapter", + ) + + @property + def adapters(self) -> tuple[str, ...]: + return convert.to_strings(self.opts("storage", "storage.adapters", self.DEFAULT_ADAPTERS, False)) + + @adapters.setter + def adapters(self, value: Iterable[str] | None) -> None: + self.add(config.Scope.applicationOverride, "storage", "storage.adapters", convert.to_strings(value)) + + DEFAULT_AWS_PROFILE = None + + @property + def aws_profile(self) -> str | None: + return self.opts("storage", "storage.aws.profile", self.DEFAULT_AWS_PROFILE, False) + + @aws_profile.setter + def aws_profile(self, value: str | None) -> None: + self.add(config.Scope.applicationOverride, "storage", "storage.aws.profile", value) + + DEFAULT_CHUNK_SIZE = 64 * 1024 + + @property + def chunk_size(self) -> int: + return int(self.opts("storage", "storage.chunk_size", self.DEFAULT_CHUNK_SIZE, False)) + + @chunk_size.setter + def chunk_size(self, value: int) -> None: + self.add(config.Scope.applicationOverride, "storage", "storage.chunk_size", value) + + DEFAULT_LOCAL_DIR = "~/.rally/storage" + + @property + def local_dir(self) -> str: + return 
self.opts("storage", "storage.local_dir", self.DEFAULT_LOCAL_DIR, False) + + @local_dir.setter + def local_dir(self, value: str) -> None: + self.add(config.Scope.applicationOverride, "storage", "storage.local_dir", value) + + DEFAULT_MAX_CONNECTIONS = 4 + + @property + def max_connections(self) -> int: + return int(self.opts("storage", "storage.max_connections", self.DEFAULT_MAX_CONNECTIONS, False)) + + @max_connections.setter + def max_connections(self, value: int) -> None: + self.add(config.Scope.applicationOverride, "storage", "storage.max_connections", value) + + DEFAULT_MAX_RETRIES = 3 + + @property + def max_retries(self) -> int | str: + return self.opts("storage", "storage.http.max_retries", self.DEFAULT_MAX_RETRIES, False) + + @max_retries.setter + def max_retries(self, value: int | str) -> None: + self.add(config.Scope.applicationOverride, "storage", "storage.http.max_retries", value) + + DEFAULT_MAX_WORKERS = 64 + + @property + def max_workers(self) -> int: + return int(self.opts("storage", "storage.max_workers", self.DEFAULT_MAX_WORKERS, False)) + + @max_workers.setter + def max_workers(self, value: int) -> None: + self.add(config.Scope.applicationOverride, "storage", "storage.max_workers", value) + + DEFAULT_MIRROR_FILES = ("~/.rally/storage-mirrors.json",) + + @property + def mirror_files(self) -> tuple[str, ...]: + return convert.to_strings(self.opts("storage", "storage.mirror_files", self.DEFAULT_MIRROR_FILES, False)) + + @mirror_files.setter + def mirror_files(self, value: Iterable[str] | None) -> None: + self.add(config.Scope.applicationOverride, "storage", "storage.mirror_files", convert.to_strings(value)) + + DEFAULT_MONITOR_INTERVAL = 4.0 + + @property + def monitor_interval(self) -> float: + return self.opts("storage", "storage.monitor_interval", self.DEFAULT_MONITOR_INTERVAL, False) + + @monitor_interval.setter + def monitor_interval(self, value: float) -> None: + self.add(config.Scope.applicationOverride, "storage", 
"storage.monitor_interval", value) + + DEFAULT_MULTIPART_SIZE = 8 * 1024 * 1024 + + @property + def multipart_size(self) -> int: + return int(self.opts("storage", "storage.multipart_size", self.DEFAULT_MULTIPART_SIZE, False)) + + @multipart_size.setter + def multipart_size(self, value: int) -> None: + self.add(config.Scope.applicationOverride, "storage", "storage.multipart_size", value) + + DEFAULT_RANDOM_SEED = None + + @property + def random_seed(self) -> Any: + return self.opts("storage", "storage.random_seed", self.DEFAULT_RANDOM_SEED, False) + + @random_seed.setter + def random_seed(self, value: Any) -> None: + self.add(config.Scope.applicationOverride, "storage", "storage.random_seed", value) + + DEFAULT_CACHE_TTL = 60.0 + + @property + def cache_ttl(self) -> float: + return self.opts("storage", "storage.cache_ttl", self.DEFAULT_CACHE_TTL, False) + + @cache_ttl.setter + def cache_ttl(self, value: float) -> None: + self.add(config.Scope.applicationOverride, "storage", "storage.cache_ttl", value) diff --git a/esrally/storage/_executor.py b/esrally/storage/_executor.py index 62af53ed3..47de8dc4d 100644 --- a/esrally/storage/_executor.py +++ b/esrally/storage/_executor.py @@ -19,9 +19,8 @@ from typing_extensions import Self -from esrally.types import Config - -MAX_WORKERS = 32 +from esrally import types +from esrally.storage._config import StorageConfig @runtime_checkable @@ -49,6 +48,35 @@ def submit(self, fn, /, *args, **kwargs): class ThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor, Executor): @classmethod - def from_config(cls, cfg: Config) -> Self: - max_workers = int(cfg.opts("storage", "storage.max_workers", MAX_WORKERS, mandatory=False)) - return cls(max_workers=max_workers, thread_name_prefix="esrally.storage.executor") + def from_config(cls, cfg: types.Config | None = None) -> Self: + cfg = StorageConfig.from_config(cfg) + return cls(max_workers=cfg.max_workers, thread_name_prefix=__name__) + + +def executor_from_config(cfg: types.Config | None 
= None) -> ThreadPoolExecutor: + return ThreadPoolExecutor.from_config(cfg) + + +class DummyExecutor(Executor): + + def __init__(self): + self.tasks: list[tuple] | None = [] + + def submit(self, fn, /, *args, **kwargs): + """Submits a callable to be executed with the given arguments. + + Schedules the callable to be executed as fn(*args, **kwargs). + """ + if self.tasks is None: + raise RuntimeError("Executor already closed") + self.tasks.append((fn, args, kwargs)) + + def execute_tasks(self): + if self.tasks is None: + raise RuntimeError("Executor already closed") + tasks, self.tasks = self.tasks, [] + for fn, args, kwargs in tasks: + fn(*args, **kwargs) + + def shutdown(self): + self.tasks = None diff --git a/esrally/storage/_manager.py b/esrally/storage/_manager.py index 1cd65a8ba..6191c3157 100644 --- a/esrally/storage/_manager.py +++ b/esrally/storage/_manager.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import atexit +import contextvars import logging import os import threading @@ -22,120 +23,95 @@ from typing_extensions import Self from esrally import types -from esrally.storage._client import MAX_CONNECTIONS, Client -from esrally.storage._executor import MAX_WORKERS, Executor, ThreadPoolExecutor -from esrally.storage._transfer import MULTIPART_SIZE, Transfer +from esrally.storage._client import Client +from esrally.storage._config import StorageConfig +from esrally.storage._executor import Executor, executor_from_config +from esrally.storage._transfer import Transfer from esrally.utils.threads import ContinuousTimer LOG = logging.getLogger(__name__) -LOCAL_DIR = "~/.rally/storage" -MONITOR_INTERVAL = 2.0 # Seconds -THREAD_NAME_PREFIX = "esrally.storage.transfer-worker" - class TransferManager: """It creates and perform file transfer operations in background.""" @classmethod - def from_config(cls, cfg: types.Config, client: Client | None = None, executor: Executor | None = None) -> Self: + def 
from_config(cls, cfg: types.Config | None = None) -> Self: """It creates a TransferManager with initialization values taken from given configuration.""" - local_dir = cfg.opts(section="storage", key="storage.local_dir", default_value=LOCAL_DIR, mandatory=False) - monitor_interval = cfg.opts(section="storage", key="storage.monitor_interval", default_value=MONITOR_INTERVAL, mandatory=False) - max_connections = cfg.opts(section="storage", key="storage.max_connections", default_value=MAX_CONNECTIONS, mandatory=False) - max_workers = cfg.opts(section="storage", key="storage.max_workers", default_value=MAX_WORKERS, mandatory=False) - multipart_size = cfg.opts(section="storage", key="storage.multipart_size", default_value=MULTIPART_SIZE, mandatory=False) - if client is None: - client = Client.from_config(cfg) - if executor is None: - executor = ThreadPoolExecutor.from_config(cfg) + cfg = StorageConfig.from_config(cfg) return cls( - client=client, - executor=executor, - local_dir=local_dir, - monitor_interval=float(monitor_interval), - multipart_size=multipart_size, - max_connections=int(max_connections), - max_workers=int(max_workers), + cfg=cfg, + client=Client.from_config(cfg), + executor=executor_from_config(cfg), ) def __init__( self, + cfg: StorageConfig, client: Client, executor: Executor, - local_dir: str = LOCAL_DIR, - monitor_interval: float = MONITOR_INTERVAL, - max_connections: int = MAX_CONNECTIONS, - max_workers: int = MAX_WORKERS, - multipart_size: int = MULTIPART_SIZE, ): """It manages files transfers. It executes file transfers in background. It periodically logs the status of transfers that are in progress. - - :param local_dir: default directory used to download files to, or upload files from. - :param monitor_interval: time interval (in seconds) separating background calls to the _monitor method. + :param cfg: Configuration object. :param client: _client.Client instance used to allocate/reuse storage adapters. 
- :param multipart_size: length of every part when working with multipart. - :param max_workers: max number of connections per remote server when working with multipart. + :param executor: Executor instance used to execute transfer operations. """ + self.cfg = cfg self._client = client self._executor = executor self._lock = threading.Lock() + self._transfers: dict[str, Transfer] = {} - local_dir = os.path.expanduser(local_dir) - if not os.path.isdir(local_dir): - os.makedirs(local_dir) - - self._local_dir = local_dir - if monitor_interval <= 0: - raise ValueError(f"invalid monitor interval: {monitor_interval} <= 0") + if cfg.max_workers < 1: + raise ValueError(f"invalid max_workers: {cfg.max_workers} < 1") - self._transfers: list[Transfer] = [] + if cfg.max_connections < 1: + raise ValueError(f"invalid max_connections: {cfg.max_connections} < 1") - if max_workers < 1: - raise ValueError(f"invalid max_workers: {max_workers} < 1") - self._max_workers = max_workers + if cfg.multipart_size < 1024 * 1024: + raise ValueError(f"invalid multipart_size: {cfg.multipart_size} < {1024 * 1024}") - if max_connections < 1: - raise ValueError(f"invalid max_connections: {max_connections} < 1") - self._max_connections = max_connections - - if multipart_size < 1024 * 1024: - raise ValueError(f"invalid multipart_size: {multipart_size} < {1024 * 1024}") - self._multipart_size = multipart_size - - self._monitor_timer = ContinuousTimer(interval=monitor_interval, function=self.monitor, name="esrally.storage.transfer-monitor") + if cfg.monitor_interval <= 0: + raise ValueError(f"invalid monitor interval: {cfg.monitor_interval} <= 0") + self._monitor_timer = ContinuousTimer(interval=cfg.monitor_interval, function=self.monitor, name="esrally.storage.transfer-monitor") self._monitor_timer.start() def shutdown(self): with self._lock: transfers = self._transfers - self._transfers = [] - for tr in transfers: - tr.close() + self._transfers = {} + LOG.info("Shutting down transfer manager...") + 
for tr in transfers.values(): + try: + if tr.finished: + continue + LOG.warning("Cancelling transfer: %s...", tr) + tr.close() + except Exception as ex: + LOG.error("error closing transfer: %s, %s", tr.url, ex) self._monitor_timer.cancel() + LOG.info("Transfer manager shut down.") def get(self, url: str, path: os.PathLike | str | None = None, document_length: int | None = None) -> Transfer: - """It starts a new transfer of a file from local path to a remote url. + """It starts a new transfer of a file from a remote url to a local path. :param url: remote file address. :param path: local file address. :param document_length: the expected file size in bytes. :return: started transfer object. """ - return self._transfer(url=url, path=path, document_length=document_length) - - def _transfer( - self, - url: str, - path: os.PathLike | str | None = None, - document_length: int | None = None, - ) -> Transfer: if path is None: - path = os.path.join(self._local_dir, url) + path = os.path.join(self.cfg.local_dir, url) # This also ensures the path is a string path = os.path.normpath(os.path.expanduser(path)) + os.makedirs(os.path.dirname(path), exist_ok=True) + + tr = self._transfers.get(path) + if tr is not None: + tr.start() + return tr head = self._client.head(url) if document_length is not None and head.content_length != document_length: @@ -146,7 +122,7 @@ def _transfer( path=path, document_length=head.content_length, executor=self._executor, - multipart_size=self._multipart_size, + multipart_size=self.cfg.multipart_size, crc32c=head.crc32c, ) @@ -154,7 +130,7 @@ def _transfer( # requesting for the first worker threads. In this way it will avoid requesting more worker threads than # the allowed per-transfer connections. 
with self._lock: - self._transfers.append(tr) + self._transfers[tr.path] = tr self._update_transfers() tr.start() return tr @@ -166,23 +142,23 @@ def monitor(self): @property def max_connections(self) -> int: with self._lock: - max_connections = self._max_connections + max_connections = self.cfg.max_connections number_of_transfers = len(self._transfers) if number_of_transfers > 0: - max_connections = min(max_connections, (self._max_workers // number_of_transfers) + 1) + max_connections = min(max_connections, (self.cfg.max_workers // number_of_transfers) + 1) return max_connections def _update_transfers(self) -> None: """It executes periodic update operations on every unfinished transfer.""" with self._lock: # It first removes finished transfers. - self._transfers = transfers = [tr for tr in self._transfers if not tr.finished] + self._transfers = transfers = {tr.path: tr for tr in self._transfers.values() if not tr.finished} if not transfers: return # It updates max_connections value for each transfer max_connections = self.max_connections - for tr in transfers: + for tr in transfers.values(): # It updates the limit of the number of connections for every transfer because it varies in function of # the number of transfers in progress. tr.max_connections = max_connections @@ -194,36 +170,37 @@ def _update_transfers(self) -> None: tr.start() # It logs updated statistics for every transfer. 
- LOG.info("Transfers in progress:\n %s", "\n ".join(tr.info() for tr in transfers)) + LOG.info("Transfers in progress:\n %s", "\n ".join(tr.info() for tr in transfers.values())) -_LOCK = threading.Lock() -_MANAGER: TransferManager | None = None +_MANAGER = contextvars.ContextVar[TransferManager | None](f"{__name__}.transfer_manager", default=None) -def init_transfer_manager(cfg: types.Config, client: Client | None = None, executor: Executor | None = None) -> bool: - global _MANAGER - with _LOCK: - if _MANAGER is not None: - LOG.debug("Transfer manager already initialized") - return False - _MANAGER = TransferManager.from_config(cfg, client=client, executor=executor) - atexit.register(_MANAGER.shutdown) - return True +def init_transfer_manager(*, cfg: types.Config | None = None, shutdown_at_exit: bool = True) -> TransferManager: + manager = _MANAGER.get() + if manager is not None: + return manager + # Initialize transfer manager. + cfg = StorageConfig.from_config(cfg) + manager = TransferManager.from_config(cfg) + _MANAGER.set(manager) -def quit_transfer_manager() -> bool: - global _MANAGER - with _LOCK: - if _MANAGER is None: - LOG.debug("Transfer manager not initialized.") - return False - _MANAGER.shutdown() - _MANAGER = None - return True + if shutdown_at_exit: + atexit.register(manager.shutdown) + return manager -def transfer_manager() -> TransferManager: - if _MANAGER is None: +def get_transfer_manager() -> TransferManager: + manager = _MANAGER.get() + if manager is None: raise RuntimeError("Transfer manager not initialized.") - return _MANAGER + return manager + + +def shutdown_transfer_manager() -> None: + manager = _MANAGER.get() + if manager is None: + return + _MANAGER.set(None) + return manager.shutdown() diff --git a/esrally/storage/_mirror.py b/esrally/storage/_mirror.py index 2dcd822de..c32b71423 100644 --- a/esrally/storage/_mirror.py +++ b/esrally/storage/_mirror.py @@ -15,37 +15,42 @@ # specific language governing permissions and limitations # 
under the License. import json +import logging import os from collections import defaultdict from collections.abc import Iterable, Mapping from typing_extensions import Self -from esrally.types import Config +from esrally import types +from esrally.storage._config import StorageConfig +from esrally.utils import convert -MIRRORS_FILES = "" +LOG = logging.getLogger(__name__) class MirrorList: @classmethod - def from_config(cls, cfg: Config) -> Self: - mirror_files = [] - for filename in cfg.opts(section="storage", key="storage.mirrors_files", default_value=MIRRORS_FILES, mandatory=False).split(","): - filename = filename.strip() - if filename: - mirror_files.append(filename) - return cls(mirror_files=mirror_files) + def from_config(cls, cfg: types.Config | StorageConfig | None = None) -> Self: + return cls(cfg=StorageConfig.from_config(cfg)) def __init__( self, + cfg: StorageConfig | None = None, + mirror_files: Iterable[str] | None = None, urls: Mapping[str, Iterable[str]] | None = None, - mirror_files: Iterable[str] = tuple(), ): + self._cfg = cfg self._urls: dict[str, set] = defaultdict(set) - self._mirror_files = tuple(mirror_files or []) - for path in self._mirror_files: - self._update(_load_file(path)) + mirror_files = set(convert.to_strings(mirror_files)) + if cfg is not None: + mirror_files.update(cfg.mirror_files) + for path in mirror_files: + try: + self._update(_load_file(path)) + except FileNotFoundError as ex: + LOG.warning("failed loading mirror file '%s': %s", path, ex) if urls is not None: self._update(urls) diff --git a/esrally/storage/_transfer.py b/esrally/storage/_transfer.py index 9a376da46..46ad37b60 100644 --- a/esrally/storage/_transfer.py +++ b/esrally/storage/_transfer.py @@ -27,7 +27,8 @@ from typing import BinaryIO from esrally.storage._adapter import Head, ServiceUnavailableError -from esrally.storage._client import MAX_CONNECTIONS, Client +from esrally.storage._client import Client +from esrally.storage._config import StorageConfig from 
esrally.storage._executor import Executor from esrally.storage._range import ( MAX_LENGTH, @@ -51,9 +52,6 @@ class TransferStatus(enum.Enum): FAILED = 5 -MULTIPART_SIZE = 8 * 1024 * 1024 - - class Transfer: """Transfers class implements multipart file transfers by submitting tasks to an Executor. @@ -83,7 +81,7 @@ def __init__( todo: RangeSet = NO_RANGE, done: RangeSet = NO_RANGE, multipart_size: int | None = None, - max_connections: int = MAX_CONNECTIONS, + max_connections: int = StorageConfig.DEFAULT_MAX_CONNECTIONS, resume: bool = True, crc32c: str | None = None, ): @@ -107,7 +105,7 @@ def __init__( multipart_size = MAX_LENGTH else: # By default, it will enable multipart with a hardcoded size. - multipart_size = MULTIPART_SIZE + multipart_size = StorageConfig.DEFAULT_MULTIPART_SIZE if not todo: # By default, it will transfer the whole file. todo = Range(0, MAX_LENGTH) @@ -282,15 +280,24 @@ def _run(self): self.start() assert isinstance(fd, FileWriter) # It downloads the part of the file from a remote location. - head = self.client.get( - self.url, fd, head=Head(ranges=fd.ranges, document_length=self._document_length, crc32c=self._crc32c) - ) - if head.document_length is not None: + if fd.ranges: + check_head = Head( + ranges=fd.ranges, content_length=fd.ranges.size, document_length=self._document_length, crc32c=self._crc32c + ) + else: + check_head = Head(content_length=self._document_length, crc32c=self._crc32c) + head, content = self.client.get(self.url, check_head=check_head) + # When ranges are not given, then content_length replaces document_length. + document_length = head.document_length or head.content_length + if document_length: # It checks the size of the file it downloaded the data from. - self.document_length = head.document_length + self.document_length = document_length if head.crc32c is not None: # It checks the crc32c check sum of the file it downloaded the data from. 
self.crc32c = head.crc32c + for chunk in content: + if chunk: + fd.write(chunk) except StreamClosedError as ex: LOG.info("transfer cancelled: %s: %s", self.url, ex) cancelled = True @@ -346,7 +353,7 @@ def close(self): # cancelled. fd.close() except Exception as ex: - LOG.warning("error closing file descriptor %r: %s", fd.path, ex) + LOG.error("error closing file descriptor %r: %s", fd.path, ex) @contextmanager def _open(self) -> Iterator[FileDescriptor | None]: diff --git a/esrally/storage/_aws.py b/esrally/storage/aws.py similarity index 80% rename from esrally/storage/_aws.py rename to esrally/storage/aws.py index 4a256b26c..65b0bf80c 100644 --- a/esrally/storage/_aws.py +++ b/esrally/storage/aws.py @@ -19,76 +19,68 @@ import logging import os import urllib.parse -from collections.abc import Mapping +from collections.abc import Iterator, Mapping from typing import Any, NamedTuple, Protocol, runtime_checkable import boto3 from botocore.response import StreamingBody from typing_extensions import Self -from esrally.storage._adapter import Adapter, Head, Writable -from esrally.storage._http import ( - CHUNK_SIZE, +from esrally import types +from esrally.storage._adapter import Adapter, Head +from esrally.storage._config import StorageConfig +from esrally.storage.http import ( head_to_headers, parse_accept_ranges, parse_content_range, parse_hashes_from_headers, ) -from esrally.types import Config LOG = logging.getLogger(__name__) -AWS_PROFILE: str | None = None - class S3Adapter(Adapter): """Adapter class for s3:// scheme protocol""" @classmethod - def from_config(cls, cfg: Config, chunk_size: int | None = None, aws_profile: str | None = None, **kwargs: Any) -> Self: - if aws_profile is None: - aws_profile = cfg.opts("storage", "storage.aws.profile", default_value=AWS_PROFILE, mandatory=False) - if chunk_size is None: - chunk_size = int(cfg.opts("storage", "storage.http.chunk_size", CHUNK_SIZE, False)) - return super().from_config(cfg, aws_profile=aws_profile, 
chunk_size=chunk_size, **kwargs) + def match_url(cls, url: str) -> bool: + return url.startswith("s3://") + + @classmethod + def from_config(cls, cfg: types.Config | None = None) -> Self: + cfg = StorageConfig.from_config(cfg) + return cls(aws_profile=cfg.aws_profile, chunk_size=cfg.chunk_size) def __init__( self, - aws_profile: str | None = AWS_PROFILE, - chunk_size: int = CHUNK_SIZE, + aws_profile: str | None = StorageConfig.DEFAULT_AWS_PROFILE, + chunk_size: int = StorageConfig.DEFAULT_CHUNK_SIZE, s3_client: S3Client | None = None, ) -> None: + if chunk_size < 0: + raise ValueError("Chunk size must be positive") self.chunk_size = chunk_size self.aws_profile = aws_profile self._s3_client = s3_client - @classmethod - def match_url(cls, url: str) -> bool: - return url.startswith("s3://") - def head(self, url: str) -> Head: address = S3Address.from_url(url) res = self._s3.head_object(Bucket=address.bucket, Key=address.key) return head_from_response(url, res) - def get(self, url: str, stream: Writable, head: Head | None = None) -> Head: + def get(self, url: str, *, check_head: Head | None = None) -> tuple[Head, Iterator[bytes]]: headers: dict[str, Any] = {} - head_to_headers(head, headers) + head_to_headers(check_head, headers) address = S3Address.from_url(url) res = self._s3.get_object(Bucket=address.bucket, Key=address.key, **headers) - ret = head_from_response(url, res) - if head is not None: - head.check(ret) + head = head_from_response(url, res) + if check_head is not None: + check_head.check(head) body: StreamingBody | None = res.get("Body") if body is None: raise RuntimeError("S3 client returned no body.") - for chunk in body.iter_chunks(self.chunk_size): - if chunk: - stream.write(chunk) - return ret - - _s3_client = None + return head, body.iter_chunks(self.chunk_size) @property def _s3(self) -> S3Client: diff --git a/esrally/storage/get.py b/esrally/storage/get.py new file mode 100644 index 000000000..bd1561a76 --- /dev/null +++ b/esrally/storage/get.py @@ 
-0,0 +1,67 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import logging +import sys +import time +import urllib + +from esrally.storage import StorageConfig, Transfer, init_transfer_manager + +LOG = logging.getLogger("esrally.storage") + + +def main(): + logging.basicConfig(level=logging.DEBUG, format="%(asctime)s %(levelname)s %(name)s %(message)s") + urls = sys.argv[1:] + if not urls: + sys.stderr.write(f"usage: python3 '{sys.argv[0]}' [urls]\n") + raise sys.exit(1) + + cfg = StorageConfig() + cfg.load_config() + manager = init_transfer_manager(cfg=cfg) + + all: dict[str, Transfer | None] = {normalise_url(url): None for url in urls} + try: + unfinished = set(all) + while unfinished: + LOG.info("Downloading files: \n - %s", "\n - ".join(f"{u}: {(s := all.get(u)) and s.progress or '?'}%" for u in unfinished)) + for url in list(unfinished): + transfer: Transfer = manager.get(url) + assert isinstance(transfer, Transfer) + all[url] = transfer + if transfer.finished: + unfinished.remove(url) + LOG.info("Download terminated: %s", transfer.url) + continue + if unfinished: + time.sleep(2.0) + finally: + LOG.info("Downloaded files: \n - %s", "\n - ".join(f"{u}: {s and s.progress or '?'}%" for u, s in 
all.items())) + + +def normalise_url(url: str) -> str: + u = urllib.parse.urlparse(url, "https") + if not u.netloc: + u = urllib.parse.urlparse(f"{u.scheme}://rally-tracks.elastic.co/{u.path}") + return u.geturl() + + +if __name__ == "__main__": + main() diff --git a/esrally/storage/_http.py b/esrally/storage/http.py similarity index 76% rename from esrally/storage/_http.py rename to esrally/storage/http.py index aeb402224..92d2122e4 100644 --- a/esrally/storage/_http.py +++ b/esrally/storage/http.py @@ -16,7 +16,7 @@ # under the License. import json import logging -from collections.abc import Mapping, MutableMapping +from collections.abc import Iterator, Mapping, MutableMapping from datetime import datetime from typing import Any @@ -26,34 +26,22 @@ from requests.structures import CaseInsensitiveDict from typing_extensions import Self -from esrally.storage._adapter import Adapter, Head, ServiceUnavailableError, Writable +from esrally import types +from esrally.storage._adapter import Adapter, Head, ServiceUnavailableError +from esrally.storage._config import StorageConfig from esrally.storage._range import NO_RANGE, RangeSet, rangeset -from esrally.types import Config LOG = logging.getLogger(__name__) -# Size of the buffers used for file transfer content. -CHUNK_SIZE = 1 * 1024 * 1024 - -# It limits the maximum number of connection retries. 
-MAX_RETRIES = 10 - - class Session(requests.Session): @classmethod - def from_config(cls, cfg: Config) -> Self: - max_retries: urllib3.Retry | int | None = 0 - max_retries_text = cfg.opts("storage", "storage.http.max_retries", MAX_RETRIES, mandatory=False) - if max_retries_text: - try: - max_retries = int(max_retries_text) - except ValueError: - max_retries = urllib3.Retry(**json.loads(max_retries_text)) - return cls(max_retries=max_retries) - - def __init__(self, max_retries: urllib3.Retry | int | None = 0) -> None: + def from_config(cls, cfg: types.Config | None = None) -> Self: + cfg = StorageConfig.from_config(cfg) + return cls(max_retries=parse_max_retries(cfg.max_retries)) + + def __init__(self, max_retries: urllib3.Retry | int | None = None) -> None: super().__init__() if max_retries: adapter = requests.adapters.HTTPAdapter(max_retries=max_retries) @@ -69,18 +57,15 @@ def match_url(cls, url: str) -> bool: return url.startswith("http://") or url.startswith("https://") @classmethod - def from_config(cls, cfg: Config, session: requests.Session | None = None, chunk_size: int | None = None, **kwargs: Any) -> Self: - if session is None: - session = Session.from_config(cfg) - if chunk_size is None: - chunk_size = int(cfg.opts("storage", "storage.http.chunk_size", CHUNK_SIZE, False)) - return super().from_config(cfg, session=session, chunk_size=chunk_size, **kwargs) + def from_config(cls, cfg: types.Config | None = None) -> Self: + cfg = StorageConfig.from_config(cfg) + return cls(session=Session.from_config(cfg), chunk_size=cfg.chunk_size) - def __init__(self, session: requests.Session | None = None, chunk_size: int = CHUNK_SIZE): + def __init__(self, session: requests.Session | None = None, chunk_size: int = StorageConfig.DEFAULT_CHUNK_SIZE): if session is None: session = requests.Session() - self.chunk_size = chunk_size self.session = session + self.chunk_size = chunk_size def head(self, url: str) -> Head: with self.session.head(url, allow_redirects=True) as 
res: @@ -89,22 +74,23 @@ def head(self, url: str) -> Head: res.raise_for_status() return head_from_headers(url, res.headers) - def get(self, url: str, stream: Writable, head: Head | None = None) -> Head: + def get(self, url: str, *, check_head: Head | None = None) -> tuple[Head, Iterator[bytes]]: headers: MutableMapping[str, str] = CaseInsensitiveDict() - head_to_headers(head, headers) - with self.session.get(url, stream=True, allow_redirects=True, headers=headers) as res: - if res.status_code == 503: - raise ServiceUnavailableError() - res.raise_for_status() + head_to_headers(check_head, headers) + res = self.session.get(url, stream=True, allow_redirects=True, headers=headers) + if res.status_code == 503: + raise ServiceUnavailableError() + res.raise_for_status() + + head = head_from_headers(url, res.headers) + if check_head is not None: + check_head.check(head) - got = head_from_headers(url, res.headers) - if head is not None: - head.check(got) + def iter_content() -> Iterator[bytes]: + with res: + yield from res.iter_content(chunk_size=self.chunk_size) - for chunk in res.iter_content(self.chunk_size): - if chunk: - stream.write(chunk) - return got + return head, iter_content() _ACCEPT_RANGES_HEADER = "Accept-Ranges" @@ -128,9 +114,10 @@ def head_from_headers(url: str, headers: Mapping[str, Any]) -> Head: def head_to_headers(head: Head | None, headers: MutableMapping[str, str]) -> None: - if head is not None: - date_to_headers(head.date, headers) - ranges_to_headers(head.ranges, headers) + if head is None: + return + date_to_headers(head.date, headers) + ranges_to_headers(head.ranges, headers) def date_to_headers(date: datetime | None, headers: MutableMapping[str, str]) -> None: @@ -226,3 +213,19 @@ def parse_hashes_from_headers(headers: Mapping[str, Any]) -> dict[str, str]: continue hashes[name] = value return hashes + + +def parse_max_retries(max_retries: str | int | urllib3.Retry) -> urllib3.Retry | int: + if isinstance(max_retries, str): + max_retries = 
max_retries.strip() + if not max_retries: + max_retries = StorageConfig.DEFAULT_MAX_RETRIES + try: + max_retries = int(max_retries) + except ValueError: + assert isinstance(max_retries, str) + max_retries = urllib3.Retry(**json.loads(max_retries)) + if isinstance(max_retries, int): + if max_retries <= 0: + raise ValueError("max_retries must be a positive integer") + return max_retries diff --git a/esrally/storage/testing.py b/esrally/storage/testing.py deleted file mode 100644 index 7121917ca..000000000 --- a/esrally/storage/testing.py +++ /dev/null @@ -1,85 +0,0 @@ -# Licensed to Elasticsearch B.V. under one or more contributor -# license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright -# ownership. Elasticsearch B.V. licenses this file to you under -# the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-import copy -from collections.abc import Iterable - -from esrally.storage._adapter import Adapter, Head, Writable -from esrally.storage._executor import Executor -from esrally.storage._range import NO_RANGE, RangeSet - - -class DummyAdapter(Adapter): - - HEADS: Iterable[Head] = tuple() - DATA: dict[str, bytes] = {} - - @classmethod - def match_url(cls, url: str) -> bool: - return True - - def __init__(self, heads: Iterable[Head] | None = None, data: dict[str, bytes] | None = None) -> None: - if heads is None: - heads = self.HEADS - if data is None: - data = self.DATA - self.heads: dict[str, Head] = {h.url: h for h in heads if h.url is not None} - self.data: dict[str, bytes] = copy.deepcopy(data) - - def head(self, url: str) -> Head: - try: - return copy.copy(self.heads[url]) - except KeyError: - raise FileNotFoundError from None - - def get(self, url: str, stream: Writable, head: Head | None = None) -> Head: - ranges: RangeSet = NO_RANGE - if head is not None: - ranges = head.ranges - if len(ranges) > 1: - raise NotImplementedError("len(head.ranges) > 1") - data = self.data[url] - if ranges: - stream.write(data[ranges.start : ranges.end]) - return Head(url, content_length=ranges.size, ranges=ranges, document_length=len(data)) - - stream.write(data) - return Head(url, content_length=len(data)) - - -class DummyExecutor(Executor): - - def __init__(self): - self.tasks: list[tuple] | None = [] - - def submit(self, fn, /, *args, **kwargs): - """Submits a callable to be executed with the given arguments. - - Schedules the callable to be executed as fn(*args, **kwargs). 
- """ - if self.tasks is None: - raise RuntimeError("Executor already closed") - self.tasks.append((fn, args, kwargs)) - - def execute_tasks(self): - if self.tasks is None: - raise RuntimeError("Executor already closed") - tasks, self.tasks = self.tasks, [] - for fn, args, kwargs in tasks: - fn(*args, **kwargs) - - def shutdown(self): - self.tasks = None diff --git a/esrally/track/loader.py b/esrally/track/loader.py index b0e74ecad..760cee59e 100644 --- a/esrally/track/loader.py +++ b/esrally/track/loader.py @@ -24,8 +24,7 @@ import sys import tempfile import urllib.error -from collections.abc import Generator -from typing import Callable, Optional +from collections.abc import Callable, Generator import jinja2 import jinja2.exceptions @@ -33,7 +32,16 @@ import tabulate from jinja2 import meta -from esrally import PROGRAM_NAME, config, exceptions, paths, time, types, version +from esrally import ( + PROGRAM_NAME, + config, + exceptions, + paths, + storage, + time, + types, + version, +) from esrally.track import params, track from esrally.track.track import Parallel from esrally.utils import ( @@ -49,6 +57,8 @@ serverless, ) +LOG = logging.getLogger(__name__) + class TrackSyntaxError(exceptions.InvalidSyntax): """ @@ -91,13 +101,16 @@ def on_prepare_track(self, track: track.Track, data_root_dir: str) -> Generator[ class TrackProcessorRegistry: + + @classmethod + def from_config(cls, cfg: config.Config): + return cls(cfg) + def __init__(self, cfg: types.Config): self.required_processors = [TaskFilterTrackProcessor(cfg), ServerlessFilterTrackProcessor(cfg), TestModeTrackProcessor(cfg)] self.track_processors = [] - self.offline = cfg.opts("system", "offline.mode") - self.test_mode = cfg.opts("track", "test.mode.enabled", mandatory=False, default_value=False) - self.base_config = cfg - self.custom_configuration = False + self.base_config: types.Config = cfg + self.custom_configuration: bool = False def register_track_processor(self, processor): if not 
self.custom_configuration: @@ -109,7 +122,7 @@ def register_track_processor(self, processor): if hasattr(processor, "cfg"): processor.cfg = self.base_config if hasattr(processor, "downloader"): - processor.downloader = Downloader(self.offline, self.test_mode) + processor.downloader = Downloader.from_config(cfg=self.base_config) if hasattr(processor, "decompressor"): processor.decompressor = Decompressor() self.track_processors.append(processor) @@ -247,12 +260,12 @@ def _load_single_track(cfg: types.Config, track_repository, track_name, install_ processor.on_after_load_track(current_track) return current_track except FileNotFoundError as e: - logging.getLogger(__name__).exception("Cannot load track [%s]", track_name) + LOG.exception("Cannot load track [%s]", track_name) raise exceptions.SystemSetupError( f"Cannot load track [{track_name}]. List the available tracks with [{PROGRAM_NAME} list tracks]." ) from e - except BaseException: - logging.getLogger(__name__).exception("Cannot load track [%s]", track_name) + except Exception: + LOG.exception("Cannot load track [%s]", track_name) raise @@ -278,7 +291,7 @@ def load_track_plugins( """ repo = track_repo(cfg, fetch=force_update, update=force_update) track_plugin_path = repo.track_dir(track_name) - logging.getLogger(__name__).debug("Invoking plugin_reader with name [%s] resolved to path [%s]", track_name, track_plugin_path) + LOG.debug("Invoking plugin_reader with name [%s] resolved to path [%s]", track_name, track_plugin_path) plugin_reader = TrackPluginReader(track_plugin_path, register_runner, register_scheduler, register_track_processor) if plugin_reader.can_load(): @@ -455,7 +468,7 @@ class DefaultTrackPreparator(TrackProcessor): def __init__(self): super().__init__() # just declare here, will be injected later - self.cfg: Optional[types.Config] = None + self.cfg: types.Config | None = None self.downloader = None self.decompressor = None self.track = None @@ -465,9 +478,7 @@ def prepare_docs(cfg: types.Config, 
track, corpus, preparator): for document_set in corpus.documents: if document_set.is_bulk: data_root = data_dir(cfg, track.name, corpus.name) - logging.getLogger(__name__).info( - "Resolved data root directory for document corpus [%s] in track [%s] to [%s].", corpus.name, track.name, data_root - ) + LOG.info("Resolved data root directory for document corpus [%s] in track [%s] to [%s].", corpus.name, track.name, data_root) if len(data_root) == 1: preparator.prepare_document_set(document_set, data_root[0]) # attempt to prepare everything in the current directory and fallback to the corpus directory @@ -482,8 +493,6 @@ def on_prepare_track(self, track, data_root_dir) -> Generator[tuple[Callable, di class Decompressor: - def __init__(self): - self.logger = logging.getLogger(__name__) def decompress(self, archive_path, documents_path, uncompressed_size): if uncompressed_size: @@ -494,7 +503,7 @@ def decompress(self, archive_path, documents_path, uncompressed_size): else: msg = f"Decompressing track data from [{archive_path}] to [{documents_path}] ... 
" - console.info(msg, end="", flush=True, logger=self.logger) + console.info(msg, end="", flush=True, logger=LOG) io.decompress(archive_path, io.dirname(archive_path)) console.println("[OK]") if not os.path.isfile(documents_path): @@ -511,51 +520,76 @@ def decompress(self, archive_path, documents_path, uncompressed_size): class Downloader: - def __init__(self, offline, test_mode): + + @classmethod + def from_config(cls, cfg: types.Config): + offline: bool = convert.to_bool(cfg.opts("system", "offline.mode", mandatory=False, default_value=False)) + test_mode: bool = convert.to_bool(cfg.opts("track", "test.mode.enabled", mandatory=False, default_value=False)) + use_transfer_manager: bool = convert.to_bool( + cfg.opts("track", "track.downloader.multipart_enabled", mandatory=False, default_value=False) + ) + storage_config: storage.StorageConfig | None = None + if use_transfer_manager: + storage_config = storage.StorageConfig.from_config(cfg) + return cls(offline=offline, test_mode=test_mode, storage_config=storage_config) + + def __init__(self, offline: bool, test_mode: bool, storage_config: storage.StorageConfig | None = None): self.offline = offline self.test_mode = test_mode - self.logger = logging.getLogger(__name__) + self.storage_config = storage_config - def download(self, base_url, target_path, size_in_bytes): + def download(self, base_url: str, target_path: str, size_in_bytes: int | None = None) -> None: file_name = os.path.basename(target_path) - if not base_url: raise exceptions.DataError("Cannot download data because no base URL is provided.") if self.offline: raise exceptions.SystemSetupError(f"Cannot find [{target_path}]. Please disable offline mode and retry.") - if base_url.endswith("/"): - separator = "" + # It joins manually as `urllib.parse.urljoin` does not work with S3 or GS URL schemes. 
+ data_url = f"{base_url.rstrip('/')}/{file_name}" + if self.storage_config is not None: + manager = storage.init_transfer_manager(cfg=self.storage_config) + LOG.info("Downloading data from [%s] to [%s] using transfer manager...", data_url, target_path) + try: + manager.get(data_url, target_path, size_in_bytes).wait() + LOG.info("Downloaded data from [%s] to [%s] using transfer manager.", data_url, target_path) + return + except FileNotFoundError as ex: + if self.test_mode: + raise exceptions.DataError( + "This track does not support test mode. Ask the track author to add it or disable test mode and retry." + ) from None + raise exceptions.DataError(f"Cannot download data from '{data_url}' using transfer manager.") from ex + except Exception as ex: + raise exceptions.DataError(f"Cannot download data from '{data_url}' using transfer manager.") from ex else: - separator = "/" - # join manually as `urllib.parse.urljoin` does not work with S3 or GS URL schemes. - data_url = f"{base_url}{separator}{file_name}" - try: io.ensure_dir(os.path.dirname(target_path)) if size_in_bytes: - self.logger.info("Downloading data from [%s] (%s) to [%s].", data_url, pretty.size(size_in_bytes), target_path) + LOG.info("Downloading data from [%s] (%s) to [%s].", data_url, pretty.size(size_in_bytes), target_path) else: - self.logger.info("Downloading data from [%s] to [%s].", data_url, target_path) + LOG.info("Downloading data from [%s] to [%s].", data_url, target_path) # we want to have a bit more accurate download progress as these files are typically very large progress = net.Progress("[INFO] Downloading track data", accuracy=1) - net.download(data_url, target_path, size_in_bytes, progress_indicator=progress) + try: + net.download(data_url, target_path, size_in_bytes, progress_indicator=progress) + except urllib.error.HTTPError as e: + if e.code == 404 and self.test_mode: + raise exceptions.DataError( + "This track does not support test mode. 
Ask the track author to add it or disable test mode and retry." + ) from None + + msg = f"Could not download [{data_url}] to [{target_path}]" + if e.reason: + msg += f" (HTTP status: {e.code}, reason: {e.reason})" + else: + msg += f" (HTTP status: {e.code})" + raise exceptions.DataError(msg) from e + except urllib.error.URLError as e: + raise exceptions.DataError(f"Could not download [{data_url}] to [{target_path}].") from e + progress.finish() - self.logger.info("Downloaded data from [%s] to [%s].", data_url, target_path) - except urllib.error.HTTPError as e: - if e.code == 404 and self.test_mode: - raise exceptions.DataError( - "This track does not support test mode. Ask the track author to add it or disable test mode and retry." - ) from None - - msg = f"Could not download [{data_url}] to [{target_path}]" - if e.reason: - msg += f" (HTTP status: {e.code}, reason: {e.reason})" - else: - msg += f" (HTTP status: {e.code})" - raise exceptions.DataError(msg) from e - except urllib.error.URLError as e: - raise exceptions.DataError(f"Could not download [{data_url}] to [{target_path}].") from e + LOG.info("Downloaded data from [%s] to [%s].", data_url, target_path) if not os.path.isfile(target_path): raise exceptions.SystemSetupError( @@ -706,14 +740,13 @@ def __init__(self, base_path, template_file_name, source=io.FileSource, fileglob self.source = source self.fileglobber = fileglobber self.assembled_source = None - self.logger = logging.getLogger(__name__) def load_template_from_file(self): loader = jinja2.FileSystemLoader(self.base_path) try: base_track = loader.get_source(jinja2.Environment(), self.template_file_name) except jinja2.TemplateNotFound: - self.logger.exception("Could not load track from [%s].", self.template_file_name) + LOG.exception("Could not load track from [%s].", self.template_file_name) raise TrackSyntaxError(f"Could not load track from '{self.template_file_name}'") self.assembled_source = self.replace_includes(self.base_path, base_track[0]) @@ 
-848,7 +881,6 @@ def relative_glob(start, f): class TaskFilterTrackProcessor(TrackProcessor): def __init__(self, cfg: types.Config): - self.logger = logging.getLogger(__name__) include_tasks = cfg.opts("track", "include.tasks", mandatory=False) exclude_tasks = cfg.opts("track", "exclude.tasks", mandatory=False) @@ -904,10 +936,10 @@ def on_after_load_track(self, track): if self._filter_out_match(leaf_task): leafs_to_remove.append(leaf_task) for leaf_task in leafs_to_remove: - self.logger.info("Removing sub-task [%s] from challenge [%s] due to task filter.", leaf_task, challenge) + LOG.info("Removing sub-task [%s] from challenge [%s] due to task filter.", leaf_task, challenge) task.remove_task(leaf_task) for task in tasks_to_remove: - self.logger.info("Removing task [%s] from challenge [%s] due to task filter.", task, challenge) + LOG.info("Removing task [%s] from challenge [%s] due to task filter.", task, challenge) challenge.remove_task(task) return track @@ -915,7 +947,6 @@ def on_after_load_track(self, track): class ServerlessFilterTrackProcessor(TrackProcessor): def __init__(self, cfg: types.Config): - self.logger = logging.getLogger(__name__) self.serverless_mode = convert.to_bool(cfg.opts("driver", "serverless.mode", mandatory=False, default_value=False)) self.serverless_operator = convert.to_bool(cfg.opts("driver", "serverless.operator", mandatory=False, default_value=False)) @@ -924,7 +955,7 @@ def _is_filtered_task(self, operation): return not operation.run_on_serverless if operation.type == "raw-request": - self.logger.info("Treating raw-request operation for operation [%s] as public.", operation.name) + LOG.info("Treating raw-request operation for operation [%s] as public.", operation.name) try: op = track.OperationType.from_hyphenated_string(operation.type) @@ -934,7 +965,7 @@ def _is_filtered_task(self, operation): else: return op.serverless_status < serverless.Status.Public except KeyError: - self.logger.info("Treating user-provided operation type 
[%s] for operation [%s] as public.", operation.type, operation.name) + LOG.info("Treating user-provided operation type [%s] for operation [%s] as public.", operation.type, operation.name) return False def on_after_load_track(self, track): @@ -962,23 +993,21 @@ def on_after_load_track(self, track): class TestModeTrackProcessor(TrackProcessor): def __init__(self, cfg: types.Config): self.test_mode_enabled = cfg.opts("track", "test.mode.enabled", mandatory=False, default_value=False) - self.logger = logging.getLogger(__name__) def on_after_load_track(self, track): if not self.test_mode_enabled: return track - self.logger.info("Preparing track [%s] for test mode.", str(track)) + LOG.info("Preparing track [%s] for test mode.", str(track)) for corpus in track.corpora: for document_set in corpus.documents: # TODO #341: Should we allow this for snapshots too? if document_set.is_bulk: if document_set.number_of_documents > 1000: - if self.logger.isEnabledFor(logging.DEBUG): - self.logger.debug( - "Reducing corpus size to 1000 documents in corpus [%s], uncompressed source file [%s]", - corpus.name, - document_set.document_file, - ) + LOG.debug( + "Reducing corpus size to 1000 documents in corpus [%s], uncompressed source file [%s]", + corpus.name, + document_set.document_file, + ) document_set.number_of_documents = 1000 @@ -1000,13 +1029,12 @@ def on_after_load_track(self, track): document_set.compressed_size_in_bytes = None document_set.uncompressed_size_in_bytes = None else: - if self.logger.isEnabledFor(logging.DEBUG): - self.logger.debug( - "Maintaining existing size of %d documents in corpus [%s], uncompressed source file [%s]", - document_set.number_of_documents, - corpus.name, - document_set.document_file, - ) + LOG.debug( + "Maintaining existing size of %d documents in corpus [%s], uncompressed source file [%s]", + document_set.number_of_documents, + corpus.name, + document_set.document_file, + ) for challenge in track.challenges: for task in challenge.schedule: @@ 
-1016,26 +1044,18 @@ def on_after_load_track(self, track): # at least one iteration for each client. if leaf_task.warmup_iterations is not None and leaf_task.warmup_iterations > leaf_task.clients: count = leaf_task.clients - if self.logger.isEnabledFor(logging.DEBUG): - self.logger.debug("Resetting warmup iterations to %d for [%s]", count, str(leaf_task)) + LOG.debug("Resetting warmup iterations to %d for [%s]", count, leaf_task) leaf_task.warmup_iterations = count if leaf_task.iterations is not None and leaf_task.iterations > leaf_task.clients: count = leaf_task.clients - if self.logger.isEnabledFor(logging.DEBUG): - self.logger.debug("Resetting measurement iterations to %d for [%s]", count, str(leaf_task)) + LOG.debug("Resetting measurement iterations to %d for [%s]", count, leaf_task) leaf_task.iterations = count if leaf_task.warmup_time_period is not None and leaf_task.warmup_time_period > 0: leaf_task.warmup_time_period = 0 - if self.logger.isEnabledFor(logging.DEBUG): - self.logger.debug( - "Resetting warmup time period for [%s] to [%d] seconds.", str(leaf_task), leaf_task.warmup_time_period - ) + LOG.debug("Resetting warmup time period for [%s] to [%d] seconds.", leaf_task, leaf_task.warmup_time_period) if leaf_task.time_period is not None and leaf_task.time_period > 10: leaf_task.time_period = 10 - if self.logger.isEnabledFor(logging.DEBUG): - self.logger.debug( - "Resetting measurement time period for [%s] to [%d] seconds.", str(leaf_task), leaf_task.time_period - ) + LOG.debug("Resetting measurement time period for [%s] to [%d] seconds.", leaf_task, leaf_task.time_period) # Keep throttled to expose any errors but increase the target throughput for short execution times. 
if leaf_task.target_throughput: @@ -1093,7 +1113,6 @@ def __init__(self, cfg: types.Config): build_flavor=self.build_flavor, serverless_operator=self.serverless_operator, ) - self.logger = logging.getLogger(__name__) def read(self, track_name, track_spec_file, mapping_dir): """ @@ -1105,7 +1124,7 @@ def read(self, track_name, track_spec_file, mapping_dir): :return: A corresponding track instance if the track file is valid. """ - self.logger.info("Reading track specification file [%s].", track_spec_file) + LOG.info("Reading track specification file [%s].", track_spec_file) # render the track to a temporary file instead of dumping it into the logs. It is easier to check for error messages # involving lines numbers and it also does not bloat Rally's log file so much. with tempfile.NamedTemporaryFile(delete=False, suffix=".json") as tmp: @@ -1119,18 +1138,18 @@ def read(self, track_name, track_spec_file, mapping_dir): ) with open(tmp.name, "w", encoding="utf-8") as f: f.write(rendered) - self.logger.info("Final rendered track for '%s' has been written to '%s'.", track_spec_file, tmp.name) + LOG.info("Final rendered track for '%s' has been written to '%s'.", track_spec_file, tmp.name) track_spec = json.loads(rendered) except jinja2.exceptions.TemplateSyntaxError as te: - self.logger.exception("Could not load [%s] due to Jinja Syntax Exception.", track_spec_file) + LOG.exception("Could not load [%s] due to Jinja Syntax Exception.", track_spec_file) msg = f"Could not load '{track_spec_file}' due to Jinja Syntax Exception. " msg += f"The track file ({tmp.name}) likely hasn't been written." 
raise TrackSyntaxError(msg, te) except jinja2.exceptions.TemplateNotFound: - self.logger.exception("Could not load [%s]", track_spec_file) + LOG.exception("Could not load [%s]", track_spec_file) raise exceptions.SystemSetupError(f"Track {track_name} does not exist") except json.JSONDecodeError as e: - self.logger.exception("Could not load [%s].", track_spec_file) + LOG.exception("Could not load [%s].", track_spec_file) msg = f"Could not load '{track_spec_file}': {str(e)}." if e.doc and e.lineno > 0 and e.colno > 0: line_idx = e.lineno - 1 @@ -1144,7 +1163,7 @@ def read(self, track_name, track_spec_file, mapping_dir): msg += f"The complete track has been written to '{tmp.name}' for diagnosis." raise TrackSyntaxError(msg) except Exception as e: - self.logger.exception("Could not load [%s].", track_spec_file) + LOG.exception("Could not load [%s].", track_spec_file) msg = f"Could not load '{track_spec_file}'. The complete track has been written to '{tmp.name}' for diagnosis." raise TrackSyntaxError(msg, e) @@ -1182,7 +1201,7 @@ def read(self, track_name, track_spec_file, mapping_dir): params_list = ",".join(opts.double_quoted_list_of(sorted(internal_user_defined_track_params))) err_msg = f"Some of your track parameter(s) {params_list} are defined by Rally and cannot be modified.\n" - self.logger.critical(err_msg) + LOG.critical(err_msg) # also dump the message on the console console.println(err_msg) raise exceptions.TrackConfigError(f"Reserved track parameters {sorted(internal_user_defined_track_params)}.") @@ -1210,7 +1229,7 @@ def read(self, track_name, track_spec_file, mapping_dir): ) ) - self.logger.critical(err_msg) + LOG.critical(err_msg) # also dump the message on the console console.println(err_msg) raise exceptions.TrackConfigError(f"Unused track parameters {sorted(unused_user_defined_track_params)}.") @@ -1240,9 +1259,9 @@ def load(self): # every module needs to have a register() method for module in root_modules: module.register(self) - except BaseException: 
- msg = "Could not register track plugin at [%s]" % self.loader.root_path - logging.getLogger(__name__).exception(msg) + except Exception: + msg = f"Could not register track plugin at [{self.loader.root_path}]" + LOG.exception(msg) raise exceptions.SystemSetupError(msg) def register_param_source(self, name, param_source): @@ -1290,7 +1309,6 @@ def __init__( self.complete_track_params = complete_track_params self.selected_challenge = selected_challenge self.source = source - self.logger = logging.getLogger(__name__) def __call__(self, track_name, track_specification, mapping_dir, spec_file=None): self.name = track_name @@ -1410,7 +1428,7 @@ def _create_index_template(self, tpl_spec, mapping_dir): return track.IndexTemplate(name, index_pattern, template_content, delete_matching_indices) def _load_template(self, contents, description): - self.logger.info("Loading template [%s].", description) + LOG.info("Loading template [%s].", description) register_all_params_in_track(contents, self.complete_track_params) try: rendered = render_template( @@ -1420,7 +1438,7 @@ def _load_template(self, contents, description): ) return json.loads(rendered) except Exception as e: - self.logger.exception("Could not load file template for %s.", description) + LOG.exception("Could not load file template for %s.", description) raise TrackSyntaxError("Could not load file template for '%s'" % description, str(e)) def _create_corpora(self, corpora_specs, indices, data_streams): @@ -1807,9 +1825,9 @@ def parse_operation(self, op_spec, error_ctx="operations"): op = track.OperationType.from_hyphenated_string(op_type_name) if "include-in-reporting" not in params: params["include-in-reporting"] = not op.admin_op - self.logger.debug("Using built-in operation type [%s] for operation [%s].", op_type_name, op_name) + LOG.debug("Using built-in operation type [%s] for operation [%s].", op_type_name, op_name) except KeyError: - self.logger.info("Using user-provided operation type [%s] for operation 
[%s].", op_type_name, op_name) + LOG.info("Using user-provided operation type [%s] for operation [%s].", op_type_name, op_name) try: return track.Operation(name=op_name, meta_data=meta_data, operation_type=op_type_name, params=params, param_source=param_source) diff --git a/esrally/types.py b/esrally/types.py index 88e15334f..87d06a53f 100644 --- a/esrally/types.py +++ b/esrally/types.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Literal, Protocol, TypeVar +from typing import Any, Literal, Protocol, runtime_checkable Section = Literal[ "actor", @@ -158,12 +158,13 @@ "src.root.dir", "storage.adapters", "storage.aws.profile", - "storage.http.chunk_size", + "storage.cache_ttl", + "storage.chunk_size", "storage.http.max_retries", "storage.local_dir", "storage.max_connections", "storage.max_workers", - "storage.mirrors_files", + "storage.mirror_files", "storage.monitor_interval", "storage.multipart_size", "storage.random_seed", @@ -173,22 +174,28 @@ "team.repository.dir", "test.mode.enabled", "time.start", + "track.downloader.multipart_enabled", "track.name", "track.path", "track.repository.dir", "user.tags", "values", ] -_Config = TypeVar("_Config", bound="Config") +@runtime_checkable class Config(Protocol): + + name: str | None = None + def add(self, scope, section: Section, key: Key, value: Any) -> None: ... - def add_all(self, source: _Config, section: Section) -> None: ... + def add_all(self, source: "Config", section: Section) -> None: ... def opts(self, section: Section, key: Key, default_value=None, mandatory: bool = True) -> Any: ... + def all_sections(self) -> list[Section]: ... + def all_opts(self, section: Section) -> dict: ... def exists(self, section: Section, key: Key) -> bool: ... diff --git a/esrally/utils/convert.py b/esrally/utils/convert.py index 9d2a8a40f..a2a94fa64 100644 --- a/esrally/utils/convert.py +++ b/esrally/utils/convert.py @@ -16,7 +16,7 @@ # under the License. 
import enum import math -from collections.abc import Callable +from collections.abc import Callable, Iterable from typing import TypeVar @@ -230,3 +230,12 @@ def to_bool(value: str | bool) -> bool: elif value in ["False", "false", "No", "no", "f", "n", "0", False]: return False raise ValueError(f"Cannot convert [{value}] to bool.") + + +def to_strings(values: Iterable[str] | None) -> tuple[str, ...]: + if not values: + return tuple() + if isinstance(values, str): + # It parses adapter names when it has been defined as a single string. + return tuple(str(values).replace(" ", "").split(",")) + return tuple(values) diff --git a/esrally/utils/threads.py b/esrally/utils/threads.py index 997ed0586..174fbd94a 100644 --- a/esrally/utils/threads.py +++ b/esrally/utils/threads.py @@ -37,7 +37,7 @@ def __init__(self, interval: float, function: Callable[[], Any], name: str | Non super().__init__(name=name, daemon=daemon) if interval <= 0: raise ValueError("interval must be strictly positive") - self._interval = interval + self.interval = interval self._function = function self._finished = threading.Event() @@ -47,10 +47,10 @@ def cancel(self): def run(self): """It executes the function every interval seconds until the timer is cancelled.""" - self._finished.wait(self._interval) + self._finished.wait(self.interval) while not self._finished.is_set(): self._function() - self._finished.wait(self._interval) + self._finished.wait(self.interval) def wait(self, timeout: float | None) -> bool: return self._finished.wait(timeout=timeout) diff --git a/tests/driver/driver_test.py b/tests/driver/driver_test.py index 70cb3787c..b35147ed6 100644 --- a/tests/driver/driver_test.py +++ b/tests/driver/driver_test.py @@ -1676,7 +1676,7 @@ async def perform_request(*args, **kwargs): test_track = track.Track(name="unittest", description="unittest track", indices=None, challenges=None) # in one second (0.5 warmup + 0.5 measurement) we should get 1000 [ops/s] / 4 [clients] = 250 samples - for 
target_throughput, bounds in {10: [2, 4], 100: [24, 26], 1000: [235, 255]}.items(): + for target_throughput, bounds in {10: [2, 4], 100: [24, 26], 1000: [220, 280]}.items(): task = track.Task( "time-based", track.Operation( diff --git a/tests/storage/adapter_test.py b/tests/storage/adapter_test.py index 38b930df5..084f85210 100644 --- a/tests/storage/adapter_test.py +++ b/tests/storage/adapter_test.py @@ -18,22 +18,21 @@ from abc import ABC from dataclasses import dataclass -from typing import Any from unittest import mock import pytest from typing_extensions import Self -import esrally.config +from esrally import types +from esrally.storage import StorageConfig from esrally.storage._adapter import Adapter, AdapterRegistry -from esrally.types import Config from esrally.utils.cases import cases class MockAdapter(Adapter, ABC): @classmethod - def from_config(cls, cfg: Config, **kwargs: Any) -> Self: + def from_config(cls, cfg: types.Config) -> Self: return mock.create_autospec(cls, spec_set=True, instance=True) @@ -66,20 +65,20 @@ def match_url(cls, url: str) -> bool: @pytest.fixture() -def cfg() -> Config: - cfg = esrally.config.Config() - cfg.add( - esrally.config.Scope.application, - "storage", - "storage.adapters", - f"{__name__}:ExampleAdapterWithPath,{__name__}:ExampleAdapter,{__name__}:HTTPSAdapter,{__name__}:HTTPAdapter", +def cfg() -> types.Config: + cfg = StorageConfig() + cfg.adapters = ( + f"{__name__}:ExampleAdapterWithPath", + f"{__name__}:ExampleAdapter", + f"{__name__}:HTTPSAdapter", + f"{__name__}:HTTPAdapter", ) return cfg # Initialize default registry @pytest.fixture() -def registry(cfg: Config) -> AdapterRegistry: +def registry(cfg: types.Config) -> AdapterRegistry: return AdapterRegistry.from_config(cfg) diff --git a/tests/storage/aws_test.py b/tests/storage/aws_test.py index e517e1ad5..2575379d1 100644 --- a/tests/storage/aws_test.py +++ b/tests/storage/aws_test.py @@ -16,20 +16,18 @@ # under the License. 
from __future__ import annotations +import dataclasses from collections.abc import Iterable -from dataclasses import dataclass from typing import Any -from unittest.mock import call, create_autospec +from unittest import mock import boto3 import pytest -from esrally.config import Config, Scope -from esrally.storage._adapter import Head, Writable -from esrally.storage._aws import S3Adapter, S3Client, head_from_response -from esrally.storage._http import CHUNK_SIZE +from esrally.storage._adapter import Head +from esrally.storage._config import StorageConfig from esrally.storage._range import rangeset -from esrally.types import Key +from esrally.storage.aws import S3Adapter, S3Client, head_from_response from esrally.utils.cases import cases SOME_BUCKET = "some-example" @@ -41,7 +39,7 @@ CONTENT_RANGE_HEADERS = {"ContentRange": "bytes 3-20/128", "ContentLength": 18} -@dataclass() +@dataclasses.dataclass class HeadCase: response: dict[str, Any] want: Head @@ -53,7 +51,7 @@ class HeadCase: @pytest.fixture def s3_client() -> S3Client: cls = type(boto3.Session().client("s3")) - return create_autospec(cls, instance=True) + return mock.create_autospec(cls, instance=True) @cases( @@ -87,24 +85,24 @@ def iter_chunks(self, chunk_size: int) -> Iterable[bytes]: SOME_DATA_HEADERS = {"ContentLength": len(SOME_DATA), "Body": DummyBody(SOME_DATA)} -@dataclass() +@dataclasses.dataclass class GetCase: response: dict[str, Any] - want: Head + want_head: Head content_length: int | None = None ranges: str = "" url: str = SOME_URL want_bucket: str = SOME_BUCKET want_key: str = SOME_KEY want_range: str = "" - want_write_data: Iterable[bytes] = tuple() + want_data: list[bytes] = dataclasses.field(default_factory=list) @cases( empty=GetCase({}, Head(SOME_URL)), accept_ranges=GetCase(ACCEPT_RANGES_HEADERS, Head(SOME_URL, accept_ranges=True)), content_length=GetCase(CONTENT_LENGTH_HEADERS, Head(SOME_URL, content_length=512)), - read_data=GetCase(SOME_DATA_HEADERS, Head(SOME_URL, 
content_length=len(SOME_DATA)), want_write_data=[SOME_DATA]), + read_data=GetCase(SOME_DATA_HEADERS, Head(SOME_URL, content_length=len(SOME_DATA)), want_data=[SOME_DATA]), ranges=GetCase( CONTENT_RANGE_HEADERS, Head(SOME_URL, content_length=18, ranges=rangeset("3-20"), document_length=128), @@ -116,17 +114,17 @@ def test_get(case: GetCase, s3_client) -> None: case.response.setdefault("Body", DummyBody(b"")) s3_client.get_object.return_value = case.response adapter = S3Adapter(s3_client=s3_client) - stream = create_autospec(Writable, spec_set=True, instance=True) - head = adapter.get(case.url, stream, head=Head(content_length=case.content_length, ranges=rangeset(case.ranges))) - assert head == case.want + head, chunks = adapter.get(case.url, check_head=Head(content_length=case.content_length, ranges=rangeset(case.ranges))) + assert head == case.want_head + assert case.want_data == list(chunks) + kwargs = {} if case.want_range: kwargs["Range"] = f"bytes={case.ranges}" s3_client.get_object.assert_called_once_with(Bucket=case.want_bucket, Key=case.want_key, **kwargs) - assert [call(data) for data in case.want_write_data] == stream.write.mock_calls -@dataclass() +@dataclasses.dataclass class HeadFromResponseCase: response: dict[str, Any] want_head: Head | None = None @@ -151,23 +149,27 @@ def test_head_from_response(case: HeadFromResponseCase): assert got == case.want_head -@dataclass() +@dataclasses.dataclass class FromConfigCase: - opts: dict[Key, str] - want_aws_profile: str = None - want_chunk_size: int = CHUNK_SIZE + cfg: StorageConfig + want_aws_profile: str | None = None + want_chunk_size: int = StorageConfig.DEFAULT_CHUNK_SIZE + + +def storage_config(**kwargs: Any) -> StorageConfig: + cfg = StorageConfig() + for k, v in kwargs.items(): + setattr(cfg, k, v) + return cfg @cases( - default=FromConfigCase({}), - chunk_size=FromConfigCase({"storage.http.chunk_size": "10"}, want_chunk_size=10), - aws_profile=FromConfigCase({"storage.aws.profile": "foo"}, 
want_aws_profile="foo"), + default=FromConfigCase(storage_config()), + chunk_size=FromConfigCase(storage_config(chunk_size=10), want_chunk_size=10), + aws_profile=FromConfigCase(storage_config(aws_profile="foo"), want_aws_profile="foo"), ) def test_from_config(case: FromConfigCase) -> None: - cfg = Config() - for k, v in case.opts.items(): - cfg.add(Scope.application, "storage", k, v) - adapter = S3Adapter.from_config(cfg) + adapter = S3Adapter.from_config(case.cfg) assert isinstance(adapter, S3Adapter) assert adapter.chunk_size == case.want_chunk_size assert adapter.aws_profile == case.want_aws_profile diff --git a/tests/storage/client_test.py b/tests/storage/client_test.py index c2782e0f3..ab2a5c0a4 100644 --- a/tests/storage/client_test.py +++ b/tests/storage/client_test.py @@ -16,23 +16,22 @@ # under the License. from __future__ import annotations +import collections +import dataclasses import json import os import random -from collections import defaultdict from collections.abc import Iterator -from dataclasses import dataclass from os import PathLike -from unittest.mock import create_autospec +from typing import Any import pytest -from esrally.config import Config, Scope -from esrally.storage._adapter import Head, Writable -from esrally.storage._client import MAX_CONNECTIONS, Client +from esrally.config import Config +from esrally.storage._adapter import DummyAdapter, Head +from esrally.storage._client import CachedHeadError, Client +from esrally.storage._config import StorageConfig from esrally.storage._range import NO_RANGE, RangeSet, rangeset -from esrally.storage.testing import DummyAdapter -from esrally.types import Key from esrally.utils.cases import cases BASE_URL = "https://example.com" @@ -78,9 +77,18 @@ } +def default_heads() -> dict[str, Head]: + return {h.url: h for h in HEADS if h.url is not None} + + +def default_data() -> dict[str, bytes]: + return collections.defaultdict(lambda: SOME_BODY) + + +@dataclasses.dataclass class 
StorageAdapter(DummyAdapter): - HEADS = HEADS - DATA: dict[str, bytes] = defaultdict(lambda: SOME_BODY) + heads: dict[str, Head] = dataclasses.field(default_factory=default_heads) + data: dict[str, bytes] = dataclasses.field(default_factory=default_data) @pytest.fixture(scope="function") @@ -95,10 +103,10 @@ def mirror_files(tmpdir: PathLike) -> Iterator[str]: @pytest.fixture def cfg(mirror_files: str) -> Config: - cfg = Config() - cfg.add(Scope.application, "storage", "storage.mirrors_files", mirror_files) - cfg.add(Scope.application, "storage", "storage.random_seed", 42) - cfg.add(Scope.application, "storage", "storage.adapters", f"{__name__}:StorageAdapter") + cfg = StorageConfig() + cfg.mirror_files = mirror_files + cfg.random_seed = 42 + cfg.adapters = (f"{__name__}:StorageAdapter",) return cfg @@ -107,52 +115,81 @@ def client(cfg: Config) -> Client: return Client.from_config(cfg) -@dataclass() -class HeadCase: - url: str - want: Head - ttl: float | None = None +def storage_config(**kwargs: Any) -> StorageConfig: + cfg = StorageConfig() + for k, v in kwargs.items(): + setattr(cfg, k, v) + return cfg -@dataclass() +@dataclasses.dataclass class FromConfigCase: - opts: dict[Key, str] - want_max_connections: int = MAX_CONNECTIONS + cfg: StorageConfig + want_max_connections: int = StorageConfig.DEFAULT_MAX_CONNECTIONS want_random: random.Random | None = None @cases( - default=FromConfigCase({}), - max_connections=FromConfigCase({"storage.max_connections": "2"}, want_max_connections=2), - random_seed=FromConfigCase({"storage.random_seed": "42"}, want_random=random.Random("42")), + default=FromConfigCase(storage_config()), + max_connections=FromConfigCase(storage_config(max_connections=42), want_max_connections=42), + random_seed=FromConfigCase(storage_config(random_seed="24"), want_random=random.Random("24")), ) def test_from_config(case: FromConfigCase) -> None: # pylint: disable=protected-access - cfg = Config() - for k, v in case.opts.items(): - 
cfg.add(Scope.application, "storage", k, v) - client = Client.from_config(cfg) + client = Client.from_config(case.cfg) assert isinstance(client, Client) assert client._connections["some"].max_count == case.want_max_connections if case.want_random is not None: assert client._random.random() == case.want_random.random() -@cases(default=HeadCase(SOME_URL, SOME_HEAD), ttl=HeadCase(SOME_URL, SOME_HEAD, ttl=300.0)) +@dataclasses.dataclass +class HeadCase: + url: str + want_head: Head | None = None + want_error: type[Exception] | None = None + cache_ttl: float | None = None + want_cached: bool | None = None + + +@cases( + default=HeadCase(SOME_URL, SOME_HEAD), + cache_ttl=HeadCase(SOME_URL, SOME_HEAD, cache_ttl=300.0, want_cached=True), + no_cache_ttl=HeadCase(SOME_URL, SOME_HEAD, cache_ttl=0.0, want_cached=False), + negative_cache_ttl=HeadCase(SOME_URL, SOME_HEAD, cache_ttl=-1.0, want_cached=False), + error=HeadCase(NOT_FOUND_BASE_URL, want_error=FileNotFoundError), + error_cache_ttl=HeadCase(NOT_FOUND_BASE_URL, want_error=FileNotFoundError, cache_ttl=300.0, want_cached=True), + error_no_cache_ttl=HeadCase(NOT_FOUND_BASE_URL, want_error=FileNotFoundError, cache_ttl=0.0, want_cached=False), + error_negagive_cache_ttl=HeadCase(NOT_FOUND_BASE_URL, want_error=FileNotFoundError, cache_ttl=-1.0, want_cached=False), +) def test_head(case: HeadCase, client: Client) -> None: - got = client.head(url=case.url, ttl=case.ttl) - assert got == case.want - if case.ttl is not None: - assert got is client.head(url=case.url, ttl=case.ttl) + # pylint: disable=catching-non-exception + err: Exception | None = None + want_error = tuple(filter(None, [case.want_error])) + try: + got = client.head(url=case.url, cache_ttl=case.cache_ttl) + except want_error as e: + got = None + err = e + else: + assert got == case.want_head + if case.want_cached is not None: + try: + assert (got is client.head(url=case.url, cache_ttl=case.cache_ttl)) == case.want_cached + except CachedHeadError as ex: + assert 
case.want_cached + assert err is ex.__cause__ + except want_error: + assert not case.want_cached -@dataclass() +@dataclasses.dataclass class ResolveCase: url: str want: list[Head] content_length: int | None = None accept_ranges: bool | None = None - ttl: float = 60.0 + cache_ttl: float = 60.0 @cases( @@ -164,41 +201,41 @@ class ResolveCase: mismatching_document_length=ResolveCase(url=MIRRORING_URL, content_length=10, want=[]), accept_ranges=ResolveCase(url=MIRRORING_URL, accept_ranges=True, want=[MIRRORING_HEAD, MIRRORED_HEAD]), reject_ranges=ResolveCase(url=NO_RANGES_URL, accept_ranges=True, want=[]), - zero_ttl=ResolveCase(url=SOME_URL, ttl=0.0, want=[SOME_HEAD]), + zero_ttl=ResolveCase(url=SOME_URL, cache_ttl=0.0, want=[SOME_HEAD]), ) def test_resolve(case: ResolveCase, client: Client) -> None: - check = Head(content_length=case.content_length, accept_ranges=case.accept_ranges) - got = sorted(client.resolve(case.url, check=check, ttl=case.ttl), key=lambda h: str(h.url)) + check_head = Head(content_length=case.content_length, accept_ranges=case.accept_ranges) + got = sorted(client.resolve(case.url, check_head=check_head, cache_ttl=case.cache_ttl), key=lambda h: str(h.url)) want = sorted(case.want, key=lambda h: str(h.url)) assert got == want, "unexpected resolve result" for g in got: assert g.url is not None, "unexpected resolve result" - if case.ttl > 0.0: - assert g is client.head(url=g.url, ttl=case.ttl), "obtained head wasn't cached" + if case.cache_ttl > 0.0: + assert g is client.head(url=g.url, cache_ttl=case.cache_ttl), "obtained head wasn't cached" else: - assert g is not client.head(url=g.url, ttl=case.ttl), "obtained head was cached" + assert g is not client.head(url=g.url, cache_ttl=case.cache_ttl), "obtained head was cached" -@dataclass() +@dataclasses.dataclass class GetCase: url: str want_any: list[Head] ranges: RangeSet = NO_RANGE document_length: int = None - want_data: bytes | None = None + want_data: list[bytes] = 
dataclasses.field(default_factory=list) @cases( regular=GetCase( SOME_URL, [Head(url=SOME_URL, content_length=len(SOME_BODY), document_length=len(SOME_BODY))], - want_data=SOME_BODY, + want_data=[SOME_BODY], ), range=GetCase( SOME_URL, [Head(SOME_URL, content_length=30, accept_ranges=True, ranges=rangeset("0-29"), document_length=len(SOME_BODY))], ranges=rangeset("0-29"), - want_data=SOME_BODY, + want_data=[SOME_BODY], ), mirrors=GetCase( MIRRORING_URL, @@ -206,15 +243,13 @@ class GetCase: MIRRORED_HEAD, MIRRORED_NO_RANGE_HEAD, ], - want_data=SOME_BODY, + want_data=[SOME_BODY], ), ) def test_get(case: GetCase, client: Client) -> None: - stream = create_autospec(Writable, spec_set=True, instance=True) - got = client.get(case.url, stream, head=Head(ranges=case.ranges, document_length=case.document_length)) - assert [] != check_any(got, case.want_any) - if case.want_data is not None: - stream.write.assert_called_once_with(case.want_data) + head, chunks = client.get(case.url, check_head=Head(ranges=case.ranges, document_length=case.document_length)) + assert [] != check_any(head, case.want_any) + assert case.want_data == list(chunks) def check_any(head: Head, any_head: list[Head]) -> list[Head]: diff --git a/tests/storage/http_test.py b/tests/storage/http_test.py index c66af7b2a..3a4d6e2ba 100644 --- a/tests/storage/http_test.py +++ b/tests/storage/http_test.py @@ -16,6 +16,7 @@ # under the License. 
from __future__ import annotations +import dataclasses import io from dataclasses import dataclass from typing import Any @@ -25,17 +26,10 @@ from requests import Response, Session from requests.structures import CaseInsensitiveDict -from esrally.config import Config, Scope -from esrally.storage._adapter import Head, Writable -from esrally.storage._http import ( - CHUNK_SIZE, - MAX_RETRIES, - HTTPAdapter, - head_from_headers, - ranges_to_headers, -) +from esrally.storage._adapter import Head +from esrally.storage._config import StorageConfig from esrally.storage._range import rangeset -from esrally.types import Key +from esrally.storage.http import HTTPAdapter, head_from_headers, ranges_to_headers from esrally.utils.cases import cases URL = "https://example.com" @@ -45,7 +39,6 @@ CONTENT_RANGE_HEADER = {"Content-Range": "bytes 3-20/128", "Content-Length": "18"} X_GOOG_HASH_CRC32C_HEADER = {"X-Goog-Hash": "crc32c=some-checksum"} X_AMZ_CHECKSUM_CRC32C_HEADER = {"x-amz-checksum-crc32c": "some-checksum"} -DATA = "some-data" @pytest.fixture() @@ -56,10 +49,10 @@ def session() -> Session: def response( headers: dict[str, str] | None = None, status_code: int = 200, - data: str = "", + data: bytes = b"", ): res = Response() - res.raw = io.StringIO(data) + res.raw = io.BytesIO(data) res.status_code = status_code res.headers = CaseInsensitiveDict() if headers is not None: @@ -92,13 +85,13 @@ def test_head(case: HeadCase, session: Session) -> None: assert head == case.want -@dataclass() +@dataclasses.dataclass class GetCase: response: Response - want: Head + want_head: Head url: str = URL ranges: str = "" - want_data: str = "" + want_data: list[bytes] = dataclasses.field(default_factory=list) want_request_range: str = "" @@ -106,7 +99,7 @@ class GetCase: default=GetCase(response(), Head(URL)), accept_ranges=GetCase(response(ACCEPT_RANGES_HEADER), Head(URL, accept_ranges=True)), content_length=GetCase(response(CONTENT_LENGTH_HEADER), Head(URL, content_length=512)), - 
read_data=GetCase(response(data="some_data"), Head(URL), want_data="some_data"), + read_data=GetCase(response(data=b"some_data"), Head(URL), want_data=[b"some_data"]), ranges=GetCase( response(CONTENT_RANGE_HEADER), Head(URL, content_length=18, ranges=rangeset("3-20"), document_length=128), @@ -119,13 +112,9 @@ class GetCase: def test_get(case: GetCase, session: Session) -> None: adapter = HTTPAdapter(session=session) session.get.return_value = case.response - stream = create_autospec(Writable, spec_set=True, instance=True) - head = adapter.get(case.url, stream, head=Head(ranges=rangeset(case.ranges))) - assert head == case.want - if case.want_data: - stream.write.assert_called_once_with(case.want_data) - else: - assert not stream.write.called + head, data = adapter.get(case.url, check_head=Head(ranges=rangeset(case.ranges))) + assert head == case.want_head + assert list(data) == case.want_data want_request_headers = {} if case.want_request_range: want_request_headers["range"] = case.want_request_range @@ -147,7 +136,6 @@ class RangesToHeadersCase: multipart=RangesToHeadersCase("1-5,7-10", want_errors=(NotImplementedError,)), ) def test_ranges_to_headers(case: RangesToHeadersCase) -> None: - # pylint: disable=protected-access got: dict[str, Any] = {} try: ranges_to_headers(rangeset(case.ranges), got) @@ -176,7 +164,6 @@ class HeadFromHeadersCase: x_amz_checksum=HeadFromHeadersCase(X_AMZ_CHECKSUM_CRC32C_HEADER, Head(URL, crc32c="some-checksum")), ) def test_head_from_headers(case: HeadFromHeadersCase): - # pylint: disable=protected-access try: got = head_from_headers(url=case.url, headers=case.headers) except Exception as ex: @@ -187,27 +174,31 @@ def test_head_from_headers(case: HeadFromHeadersCase): assert got == case.want +def storage_config(**kwargs: Any) -> StorageConfig: + cfg = StorageConfig() + for k, v in kwargs.items(): + setattr(cfg, k, v) + return cfg + + @dataclass() class FromConfigCase: - opts: dict[Key, str] - want_chunk_size: int = CHUNK_SIZE - 
want_max_retries: int = MAX_RETRIES + cfg: StorageConfig + want_chunk_size: int = StorageConfig.DEFAULT_CHUNK_SIZE + want_max_retries: int = StorageConfig.DEFAULT_MAX_RETRIES want_backoff_factor: int = 0 @cases( - default=FromConfigCase({}), - chunk_size=FromConfigCase({"storage.http.chunk_size": "10"}, want_chunk_size=10), - max_retries=FromConfigCase({"storage.http.max_retries": "3"}, want_max_retries=3), + default=FromConfigCase(storage_config()), + chunk_size=FromConfigCase(storage_config(chunk_size=10), want_chunk_size=10), + max_retries=FromConfigCase(storage_config(max_retries="3"), want_max_retries=3), max_retries_yml=FromConfigCase( - {"storage.http.max_retries": '{"total": 5, "backoff_factor": 5}'}, want_max_retries=5, want_backoff_factor=5 + storage_config(max_retries='{"total": 5, "backoff_factor": 5}'), want_max_retries=5, want_backoff_factor=5 ), ) def test_from_config(case: FromConfigCase) -> None: - cfg = Config() - for k, v in case.opts.items(): - cfg.add(Scope.application, "storage", k, v) - adapter = HTTPAdapter.from_config(cfg) + adapter = HTTPAdapter.from_config(case.cfg) assert isinstance(adapter, HTTPAdapter) assert adapter.chunk_size == case.want_chunk_size retry = adapter.session.adapters["https://"].max_retries diff --git a/tests/storage/manager_test.py b/tests/storage/manager_test.py index 6d3f0446a..db8372c83 100644 --- a/tests/storage/manager_test.py +++ b/tests/storage/manager_test.py @@ -16,42 +16,39 @@ # under the License. 
from __future__ import annotations +import dataclasses import os -from collections.abc import Iterator -from dataclasses import dataclass +from collections.abc import Generator from typing import Any -from unittest.mock import patch import pytest -from esrally import config, types -from esrally.storage._adapter import Head +from esrally import types +from esrally.storage import _manager +from esrally.storage._adapter import DummyAdapter, Head +from esrally.storage._config import StorageConfig +from esrally.storage._executor import DummyExecutor from esrally.storage._manager import ( TransferManager, + get_transfer_manager, init_transfer_manager, - quit_transfer_manager, - transfer_manager, + shutdown_transfer_manager, ) -from esrally.storage.testing import DummyAdapter, DummyExecutor from esrally.utils.cases import cases -@pytest.fixture -def cfg(tmpdir: os.PathLike) -> types.Config: - cfg = config.Config() - cfg.add( - config.Scope.application, - "storage", - "storage.adapters", - f"{__name__}:StorageAdapter", - ) - cfg.add(config.Scope.application, "storage", "storage.local_dir", str(tmpdir)) +@pytest.fixture(scope="function") +def cfg(request, tmpdir: os.PathLike) -> types.Config: + cfg = StorageConfig() + cfg.adapters = (f"{__name__}:StorageAdapter",) + cfg.local_dir = str(tmpdir) return cfg -@pytest.fixture -def executor() -> Iterator[DummyExecutor]: +@pytest.fixture(scope="function") +def dummy_executor(monkeypatch: pytest.MonkeyPatch) -> Generator[DummyExecutor]: executor = DummyExecutor() + monkeypatch.setattr(_manager, "executor_from_config", lambda cfg: executor) try: yield executor finally: @@ -59,12 +56,12 @@ def executor() -> Iterator[DummyExecutor]: @pytest.fixture -def manager(cfg: types.Config, executor: DummyExecutor) -> Iterator[TransferManager]: - manager = TransferManager.from_config(cfg, executor=executor) +def manager(cfg: types.Config, dummy_executor: DummyExecutor) -> Generator[TransferManager]: + manager = init_transfer_manager(cfg=cfg) 
try: yield manager finally: - manager.shutdown() + shutdown_transfer_manager() SIMPLE_URL = "http://example.com" @@ -72,15 +69,14 @@ def manager(cfg: types.Config, executor: DummyExecutor) -> Iterator[TransferMana SIMPLE_HEAD = Head(url=SIMPLE_URL, content_length=len(SIMPLE_DATA)) +@dataclasses.dataclass class StorageAdapter(DummyAdapter): - HEADS = (SIMPLE_HEAD,) - DATA = { - SIMPLE_URL: SIMPLE_DATA, - } + heads: dict[str, Head] = dataclasses.field(default_factory=lambda: {SIMPLE_URL: SIMPLE_HEAD}) + data: dict[str, bytes] = dataclasses.field(default_factory=lambda: {SIMPLE_URL: SIMPLE_DATA}) -@dataclass +@dataclasses.dataclass class GetCase: url: str path: os.PathLike | str | None = None @@ -95,7 +91,7 @@ class GetCase: document_length=GetCase(url=SIMPLE_URL, want_data=SIMPLE_DATA, document_length=len(SIMPLE_DATA)), mismach_document_length=GetCase(url=SIMPLE_URL, want_error=(ValueError,), document_length=len(SIMPLE_DATA) - 1), ) -def test_get(case: GetCase, manager: TransferManager, executor: DummyExecutor, tmpdir: os.PathLike) -> None: +def test_get(case: GetCase, manager: TransferManager, dummy_executor: DummyExecutor, tmpdir: os.PathLike) -> None: kwargs: dict[str, Any] = {} if case.path is not None: kwargs["path"] = os.path.join(tmpdir, case.path) @@ -115,7 +111,7 @@ def test_get(case: GetCase, manager: TransferManager, executor: DummyExecutor, t if case.path is not None: assert os.path.join(tmpdir, case.path) == tr.path - executor.execute_tasks() + dummy_executor.execute_tasks() got = tr.wait(timeout=0.0) assert got @@ -128,29 +124,25 @@ def test_get(case: GetCase, manager: TransferManager, executor: DummyExecutor, t assert os.path.getsize(tr.path) == case.document_length -@patch("esrally.storage._manager._MANAGER", None) -def test_global_transfer_manager(cfg: types.Config, tmpdir: os.PathLike) -> None: - with pytest.raises(RuntimeError) as excinfo: - assert transfer_manager() is None - assert str(excinfo.value) == "Transfer manager not initialized." 
+@pytest.fixture(scope="function", autouse=True) +def cleanup_transfer_manager(): + shutdown_transfer_manager() + yield + shutdown_transfer_manager() + - init_transfer_manager(cfg) - got = transfer_manager() - assert isinstance(got, TransferManager) +def test_transfer_manager(tmpdir: os.PathLike, cfg: StorageConfig) -> None: + manager = init_transfer_manager(cfg=cfg) + assert isinstance(manager, TransferManager) - tr = got.get(url=SIMPLE_URL, document_length=len(SIMPLE_DATA)) - assert tr.wait(timeout=60.0) + tr = manager.get(url=SIMPLE_URL, document_length=len(SIMPLE_DATA)) + assert tr.wait(timeout=10.0) assert os.path.exists(tr.path) + assert os.path.getsize(tr.path) == len(SIMPLE_DATA) - init_transfer_manager(cfg) - assert got is transfer_manager() + assert manager is get_transfer_manager() - quit_transfer_manager() - with pytest.raises(RuntimeError) as excinfo: - assert transfer_manager() is None - assert str(excinfo.value) == "Transfer manager not initialized." + shutdown_transfer_manager() - quit_transfer_manager() - with pytest.raises(RuntimeError) as excinfo: - assert transfer_manager() is None - assert str(excinfo.value) == "Transfer manager not initialized." + with pytest.raises(RuntimeError): + assert get_transfer_manager() diff --git a/tests/storage/mirror_test.py b/tests/storage/mirror_test.py index 1d056440e..8e0f1281a 100644 --- a/tests/storage/mirror_test.py +++ b/tests/storage/mirror_test.py @@ -16,13 +16,14 @@ # under the License. 
from __future__ import annotations +import dataclasses import os.path -from dataclasses import dataclass -from esrally.config import Config, Scope +import pytest + +from esrally.storage._config import StorageConfig from esrally.storage._mirror import MirrorList -from esrally.types import Key -from esrally.utils.cases import cases +from esrally.utils import cases BASE_URL = "https://rally-tracks.elastic.co" SOME_PATH = "some/file.json.bz2" @@ -35,27 +36,36 @@ MIRROR_FILES = os.path.join(os.path.dirname(__file__), "mirrors.json") -@dataclass() +def storage_config(**kwargs) -> StorageConfig: + cfg = StorageConfig() + for k, v in kwargs.items(): + setattr(cfg, k, v) + return cfg + + +@dataclasses.dataclass() class FromConfigCase: - opts: dict[Key, str] + cfg: StorageConfig want_error: type[Exception] | None = None + want_mirror_files: set[str] = dataclasses.field(default_factory=lambda: set(StorageConfig.DEFAULT_MIRROR_FILES)) want_urls: dict[str, set[str]] | None = None -@cases( - default=FromConfigCase({}, want_urls={}), +@pytest.fixture(autouse=True) +def patch_default_config(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(StorageConfig, "DEFAULT_MIRROR_FILES", tuple()) + + +@cases.cases( + default=FromConfigCase(storage_config(), want_urls={}), mirror_files=FromConfigCase( - {"storage.mirrors_files": MIRROR_FILES}, want_urls={f"{BASE_URL}/": {f"{MIRROR1_URL}/", f"{MIRROR2_URL}/"}} + storage_config(mirror_files=MIRROR_FILES), want_urls={f"{BASE_URL}/": {f"{MIRROR1_URL}/", f"{MIRROR2_URL}/"}} ), - invalid_mirror_files=FromConfigCase({"storage.mirrors_files": ""}, want_error=FileNotFoundError), + invalid_mirror_files=FromConfigCase(storage_config(mirror_files=""), want_urls={}), ) -def test_from_config(case: FromConfigCase): - cfg = Config() - for k, v in case.opts.items(): - cfg.add(Scope.application, "storage", k, v) - +def test_from_config(case: FromConfigCase, monkeypatch): try: - got_mirrors = MirrorList.from_config(cfg) + got_mirrors = 
MirrorList.from_config(case.cfg) got_error = None except Exception as ex: got_mirrors = None @@ -69,14 +79,14 @@ def test_from_config(case: FromConfigCase): assert isinstance(got_error, case.want_error) -@dataclass() +@dataclasses.dataclass() class ResolveCase: url: str want: list[str] | None = None want_error: type[Exception] | None = None -@cases( +@cases.cases( empty=ResolveCase("", want_error=Exception), simple=ResolveCase(URL, want=[f"{MIRROR1_URL}/{SOME_PATH}", f"{MIRROR2_URL}/{SOME_PATH}"]), normalized=ResolveCase("https://rally-tracks.elastic.co/", want=[f"{MIRROR1_URL}/", f"{MIRROR2_URL}/"]), diff --git a/tests/storage/transfer_test.py b/tests/storage/transfer_test.py index 7d3ed8bf5..c5003f661 100644 --- a/tests/storage/transfer_test.py +++ b/tests/storage/transfer_test.py @@ -25,11 +25,12 @@ import pytest from esrally.config import Config -from esrally.storage._adapter import Head, Writable +from esrally.storage._adapter import Head from esrally.storage._client import Client +from esrally.storage._config import StorageConfig +from esrally.storage._executor import DummyExecutor from esrally.storage._range import rangeset -from esrally.storage._transfer import MAX_CONNECTIONS, Transfer -from esrally.storage.testing import DummyExecutor +from esrally.storage._transfer import Transfer from esrally.utils.cases import cases URL = "https://rally-tracks.elastic.co/apm/span.json.bz2" @@ -41,16 +42,15 @@ class DummyClient(Client): - def head(self, url: str, ttl: float | None = None) -> Head: + def head(self, url: str, *, cache_ttl: float | None = None) -> Head: return Head(url, content_length=len(DATA), accept_ranges=True, crc32c=CRC32C) - def get(self, url: str, stream: Writable, head: Head | None = None) -> Head: + def get(self, url: str, *, check_head: Head | None = None) -> tuple[Head, Iterator[bytes]]: data = DATA - if head is not None and head.ranges: - data = data[head.ranges.start : head.ranges.end] - if data: - stream.write(data) - return Head(url, 
ranges=head.ranges, content_length=len(data), document_length=len(DATA), crc32c=CRC32C) + if check_head is not None and check_head.ranges: + data = data[check_head.ranges.start : check_head.ranges.end] + head = Head(url, ranges=check_head.ranges, content_length=len(data), document_length=len(DATA), crc32c=CRC32C) + return head, iter((data,)) @pytest.fixture @@ -69,7 +69,7 @@ class TransferCase: document_length: int | None = len(DATA) crc32c: str = CRC32C multipart_size: int | None = None - max_connections: int = MAX_CONNECTIONS + max_connections: int = StorageConfig.DEFAULT_MAX_CONNECTIONS want_init_error: type[Exception] | None = None want_init_todo: str = "0-1023" want_init_done: str = "" @@ -92,13 +92,12 @@ class TransferCase: ), # It tests multipart working when multipart_size < content_length. multipart_size=TransferCase(multipart_size=128, want_final_done="0-127", want_final_todo="128-1023", want_final_written="0-127"), - # It tests multipart working when multipart_size < content_length. + # It tests mismatching document length. mismatching_document_length=TransferCase( - want_final_done="0-49", + want_final_todo="0-49", want_init_todo="0-49", document_length=50, want_init_document_length=50, - want_final_written="0-49", want_final_document_length=50, want_final_error=RuntimeError, ), diff --git a/tests/types_test.py b/tests/types_test.py index 7305c094d..498333f1b 100644 --- a/tests/types_test.py +++ b/tests/types_test.py @@ -122,5 +122,5 @@ def assert_annotations(obj, ident, *expects): class TestConfigTypeHint: def test_esrally_module_annotations(self): for module in glob_modules(project_root, "esrally/**/*.py"): - assert_annotations(module, "cfg", types.Config, "types.Config", "types.Config | None") + assert_annotations(module, "cfg", types.Config, "types.Config", "types.Config | None", types.Config | None) assert_annotations(module, "config", types.Config, Optional[types.Config], ConfigParser)