diff --git a/.github/workflows/compatibility.yml b/.github/workflows/compatibility.yml index 3e33a07fa..0e5e982b4 100644 --- a/.github/workflows/compatibility.yml +++ b/.github/workflows/compatibility.yml @@ -41,7 +41,7 @@ jobs: exit 1 fi; test: - name: Test (Python ${{ matrix.python-version }}, Redis ${{ matrix.redis-version }}${{ matrix.uvloop == 'True' && ', uvloop' || ''}}${{ matrix.orjson == 'True' && ', orjson' || ''}}${{ matrix.extensions == 'True' && ', compiled' || ''}}${{ matrix.label && format(', {0}', matrix.label) || '' }}) + name: Test (Python ${{ matrix.python-version }}, Anyio ${{ matrix.anyio-backend || 'asyncio' }}, Redis ${{ matrix.redis-version }}${{ matrix.uvloop == 'True' && ', uvloop' || ''}}${{ matrix.orjson == 'True' && ', orjson' || ''}}${{ matrix.extensions == 'True' && ', compiled' || ''}}${{ matrix.label && format(', {0}', matrix.label) || '' }}) runs-on: ubuntu-latest continue-on-error: ${{ matrix.redis-version == 'next' }} strategy: @@ -52,6 +52,7 @@ jobs: test_params: ["-m '(not (dragonfly or valkey or redict))'"] orjson: ["False"] uvloop: ["False"] + anyio-backend: ["asyncio"] runtime_type_checks: ["True"] extensions: ["True"] label: [""] @@ -72,6 +73,12 @@ jobs: extensions: "False" runtime_type_checks: "True" label: "" + - python-version: "3.13" + redis-version: latest + test_params: "-m '(not (dragonfly or valkey or redict))'" + runtime_type_checks: "True" + anyio-backend: "trio" + label: "" - python-version: "3.13" redis-version: latest test_params: "-m '(not (dragonfly or valkey or redict))'" @@ -126,28 +133,22 @@ jobs: - name: Compile extensions if: ${{ matrix.extensions == 'True' }} run: uv run mypyc coredis/constants.py coredis/parser.py coredis/_packer.py coredis/_utils.py - - name: Install uvloop - if: ${{ matrix.uvloop == 'True' }} - run: - uv pip install uvloop - - name: Install orjson - if: ${{ matrix.orjson == 'True' }} - run: - uv pip install orjson - name: Tests with coverage env: COREDIS_UVLOOP: ${{ matrix.uvloop }} + COREDIS_ANYIO_BACKEND: ${{ matrix.anyio-backend }} HOST_OS: linux CI: "True" COREDIS_REDIS_VERSION: ${{matrix.redis-version}} COREDIS_RUNTIME_CHECKS: ${{matrix.runtime_type_checks}} PYTEST_SENTRY_DSN: ${{ matrix.extensions != 'True' && secrets.SENTRY_DSN || ''}} COMPOSE_PARALLEL_LIMIT: 1 + UV_GROUP: ${{ matrix.orjson == 'True' && 'orjson' || 'dev' }} run: | echo "Runtime checks: $COREDIS_RUNTIME_CHECKS" echo "UVLoop: $COREDIS_UVLOOP" echo "CI: $CI" - uv run pytest --reverse --reruns 2 --cov=coredis --cov-report=xml ${{ matrix.test_params }} + uv run --group $UV_GROUP pytest --timeout=60 --reverse --reruns 2 --cov=coredis --cov-report=xml ${{ matrix.test_params }} - name: Upload coverage to Codecov uses: codecov/codecov-action@v4.2.0 env: diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 4f82d126b..9c0668242 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -36,7 +36,7 @@ jobs: exit 1 fi; test: - name: Test (Python ${{ matrix.python-version }}, Redis ${{ matrix.redis-version }}${{ matrix.uvloop == 'True' && ', uvloop' || ''}}${{ matrix.orjson == 'True' && ', orjson' || ''}}${{ matrix.extensions == 'True' && ', compiled' || ''}}${{ matrix.label && format(', {0}', matrix.label) || '' }}) + name: Test (Python ${{ matrix.python-version }}, ${{ matrix.anyio-backend || 'asyncio' }}, Redis ${{ matrix.redis-version }}${{ matrix.uvloop == 'True' && ', uvloop' || ''}}${{ matrix.orjson == 'True' && ', orjson' || ''}}${{ matrix.extensions == 'True' && ', compiled' || ''}}${{ matrix.label && format(', {0}', matrix.label) || '' }}) runs-on: ubuntu-latest continue-on-error: ${{ matrix.redis-version == 'next' }} strategy: @@ -47,6 +47,7 @@ jobs: test_params: ["-m '(not (dragonfly or valkey or redict))'"] uvloop: ["False"] orjson: ["False"] + anyio-backend: ["asyncio"] runtime_type_checks: ["True"] extensions: ["True"] label: [""] @@ -66,6 +67,11 @@ jobs: test_params: "-m '(not (dragonfly or valkey or redict))'" runtime_type_checks: "True" uvloop: "True" + - python-version: "3.13" + redis-version: "latest" + test_params: "-m '(not (dragonfly or valkey or redict))'" + runtime_type_checks: "True" + anyio-backend: "trio" - python-version: "3.13" redis-version: "latest" test_params: "-m dragonfly" @@ -102,28 +108,22 @@ jobs: - name: Compile extensions if: ${{ matrix.extensions == 'True' }} run: uv run mypyc coredis/constants.py coredis/parser.py coredis/_packer.py coredis/_utils.py - - name: Install uvloop - if: ${{ matrix.uvloop == 'True' }} - run: - uv pip install uvloop - - name: Install orjson - if: ${{ matrix.orjson == 'True' }} - run: - uv pip install orjson - name: Tests env: COREDIS_UVLOOP: ${{ matrix.uvloop }} + COREDIS_ANYIO_BACKEND: ${{ matrix.anyio-backend }} HOST_OS: linux CI: "True" COREDIS_REDIS_VERSION: ${{matrix.redis-version}} COREDIS_RUNTIME_CHECKS: ${{matrix.runtime_type_checks}} PYTEST_SENTRY_DSN: ${{ matrix.extensions != 'True' && secrets.SENTRY_DSN || ''}} COMPOSE_PARALLEL_LIMIT: 1 + UV_GROUP: ${{ matrix.orjson == 'True' && 'orjson' || 'dev' }} run: | echo "Runtime checks: $COREDIS_RUNTIME_CHECKS" echo "UVLoop: $COREDIS_UVLOOP" echo "CI: $CI" - uv run pytest --reverse --reruns 2 --cov=coredis --cov-report=xml ${{ matrix.test_params }} + uv run --group $UV_GROUP pytest --timeout=60 --reverse --reruns 2 --cov=coredis --cov-report=xml ${{ matrix.test_params }} - name: Upload coverage to Codecov uses: codecov/codecov-action@v4.2.0 env: diff --git a/HISTORY.rst b/HISTORY.rst index 381e54ab4..15994e3bd 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -3,6 +3,34 @@ Changelog ========= +v6.0.0rc1 +--------- +Release Date: TBD + +* Feature + + * Migrates entire library to ``anyio``, adding structured concurrency and Trio support. + + * Almost all objects (clients, connection pools, PubSub, pipelines) now require use of + an async context manager for initialization/cleanup. + * Test suite now runs tests on both asyncio and Trio backends + * Caching is simplified, and users should replace ``TrackingCache`` instances with a + ``LRUCache`` instance instead. Cache no longer has a max byte size, so max keys + should be used instead. + * All connection types use ``anyio`` networking APIs. + * ``Pipeline.execute()`` no longer exists. Instead, pipelines auto-execute when leaving + their context manager. Results can be accessed afterwards in a type-safe way. + * RESP2 support has been dropped. + * All connection pools are now blocking. + * ``Library.wraps`` is now just ``wraps`` and supports callbacks. It also optimistically + calls FCALL in pipelines instead of checking the function exists first. + * When defining type stubs for FFI for Lua scripts or library functions, keys can only + be distinguished from arguments by annotating them with the ``KeyT`` type. + * EVALSHA and FCALL commands now support optional callbacks + * Removes ``Monitor`` wrapper + * Client now includes ``Redis.lock`` as a convenient way to access the ``LuaLock`` + recipe, and the class is now just called ``Lock``. + v5.5.0 ------ Release Date: 2026-01-12 diff --git a/README.md b/README.md index 7e6ad000e..87bc64fe7 100644 --- a/README.md +++ b/README.md @@ -1,156 +1,88 @@ -# coredis - [![docs](https://readthedocs.org/projects/coredis/badge/?version=stable)](https://coredis.readthedocs.org) [![codecov](https://codecov.io/gh/alisaifee/coredis/branch/master/graph/badge.svg)](https://codecov.io/gh/alisaifee/coredis) [![Latest Version in PyPI](https://img.shields.io/pypi/v/coredis.svg)](https://pypi.python.org/pypi/coredis/) [![ci](https://github.com/alisaifee/coredis/actions/workflows/main.yml/badge.svg?branch=master)](https://github.com/alisaifee/coredis/actions?query=branch%3Amaster+workflow%3ACI) [![Supported Python versions](https://img.shields.io/pypi/pyversions/coredis.svg)](https://pypi.python.org/pypi/coredis/) -______________________________________________________________________ - -coredis is an async redis client with support for redis server, cluster & sentinel. - -- The client API uses the specifications in the [Redis command documentation](https://redis.io/commands/) to define the API by using the following conventions: - - - Arguments retain naming from redis as much as possible - - Only optional variadic arguments are mapped to variadic positional or keyword arguments. - When the variable length arguments are not optional (which is almost always the case) the expected argument - is an iterable of type [Parameters](https://coredis.readthedocs.io/en/latest/api/typing.html#coredis.typing.Parameters) or `Mapping`. - - Pure tokens used as flags are mapped to boolean arguments - - `One of` arguments accepting pure tokens are collapsed and accept a [PureToken](https://coredis.readthedocs.io/en/latest/api/utilities.html#coredis.tokens.PureToken) - -- Responses are mapped between RESP and python types as closely as possible. - -- For higher level concepts such as Pipelines, LUA Scripts, PubSub & Streams - abstractions are provided to encapsulate recommended patterns. - See the [Handbook](https://coredis.readthedocs.io/en/latest/handbook/index.html) - and the [API Documentation](https://coredis.readthedocs.io/en/latest/api/index.html) - for more details. - -______________________________________________________________________ +# coredis - +Fast, async, fully-typed Redis client with support for cluster and sentinel -- [Installation](#installation) -- [Feature Summary](#feature-summary) - - [Deployment topologies](#deployment-topologies) - - [Application patterns](#application-patterns) - - [Server side scripting](#server-side-scripting) - - [Redis Modules](#redis-modules) - - [Miscellaneous](#miscellaneous) -- [Quick start](#quick-start) - - [Single Node or Cluster client](#single-node-or-cluster-client) - - [Sentinel](#sentinel) -- [Compatibility](#compatibility) - - [Supported python versions](#supported-python-versions) - - [Redis API compatible databases backends](#redis-api-compatible-databases) -- [References](#references) +## Features - +- Fully typed, even when using pipelines, Lua scripts, and libraries +- Redis [Cluster](https://coredis.readthedocs.org/en/latest/handbook/cluster.html#redis-cluster) and [Sentinel](https://coredis.readthedocs.org/en/latest/api/clients.html#sentinel) support +- Built with structured concurrency on `anyio`, supports both `asyncio` and `trio` +- Server-assisted [client-side caching](https://coredis.readthedocs.org/en/latest/handbook/caching.html) implementation +- [Redis Stack modules](https://coredis.readthedocs.org/en/latest/handbook/modules.html) support +- [Redis PubSub](https://coredis.readthedocs.org/en/latest/handbook/pubsub.html) +- [Pipelining](https://coredis.readthedocs.org/en/latest/handbook/pipelines.html) +- [Lua scripts](https://coredis.readthedocs.org/en/latest/handbook/scripting.html#lua_scripting) and [Redis functions](https://coredis.readthedocs.org/en/latest/handbook/scripting.html#library-functions) \[`>= Redis 7.0`\] support, with optional types +- Convenient [Stream Consumers](https://coredis.readthedocs.org/en/latest/handbook/streams.html) implementation +- Comprehensive documentation +- Optional [runtime type validation](https://coredis.readthedocs.org/en/latest/handbook/typing.html#runtime-type-checking) (via [beartype](https://github.com/beartype/beartype)) ## Installation -To install coredis: - -```bash +```console $ pip install coredis ``` -## Feature Summary - -### Deployment topologies - -- [Redis Cluster](https://coredis.readthedocs.org/en/latest/handbook/cluster.html#redis-cluster) -- [Sentinel](https://coredis.readthedocs.org/en/latest/api/clients.html#sentinel) - -### Application patterns - -- [Connection Pooling](https://coredis.readthedocs.org/en/latest/handbook/connections.html#connection-pools) -- [PubSub](https://coredis.readthedocs.org/en/latest/handbook/pubsub.html) -- [Sharded PubSub](https://coredis.readthedocs.org/en/latest/handbook/pubsub.html#sharded-pub-sub) \[`>= Redis 7.0`\] -- [Stream Consumers](https://coredis.readthedocs.org/en/latest/handbook/streams.html) -- [Pipelining](https://coredis.readthedocs.org/en/latest/handbook/pipelines.html) -- [Client side caching](https://coredis.readthedocs.org/en/latest/handbook/caching.html) - -### Server side scripting - -- [LUA Scripting](https://coredis.readthedocs.org/en/latest/handbook/scripting.html#lua_scripting) -- [Redis Libraries and functions](https://coredis.readthedocs.org/en/latest/handbook/scripting.html#library-functions) \[`>= Redis 7.0`\] - -### Redis Modules - -- [RedisJSON](https://coredis.readthedocs.org/en/latest/handbook/modules.html#redisjson) -- [RediSearch](https://coredis.readthedocs.org/en/latest/handbook/modules.html#redisearch) -- [RedisBloom](https://coredis.readthedocs.org/en/latest/handbook/modules.html#redisbloom) -- [RedisTimeSeries](https://coredis.readthedocs.org/en/latest/handbook/modules.html#redistimeseries) - -### Miscellaneous - -- Public API annotated with type annotations -- Optional [Runtime Type Validation](https://coredis.readthedocs.org/en/latest/handbook/typing.html#runtime-type-checking) (via [beartype](https://github.com/beartype/beartype)) - -## Quick start +## Getting started -### Single Node or Cluster client +To start, you'll need to connect to your `Redis` instance: ```python -import asyncio -from coredis import Redis, RedisCluster +import trio +from coredis import Redis -async def example(): - client = Redis(host='127.0.0.1', port=6379, db=0) - # or with redis cluster - # client = RedisCluster(startup_nodes=[{"host": "127.0.01", "port": 7001}]) +client = Redis(host='127.0.0.1', port=6379, db=0, decode_responses=True) +async with client: await client.flushdb() await client.set('foo', 1) assert await client.exists(['foo']) == 1 assert await client.incr('foo') == 2 assert await client.incrby('foo', increment=100) == 102 - assert int(await client.get('foo')) == 102 + assert int(await client.get('foo') or 0) == 102 assert await client.expire('foo', 1) - await asyncio.sleep(0.1) + await trio.sleep(0.1) assert await client.ttl('foo') == 1 assert await client.pttl('foo') < 1000 - await asyncio.sleep(1) + await trio.sleep(1) assert not await client.exists(['foo']) - -asyncio.run(example()) ``` -### Sentinel +Sentinel is also supported: ```python -import asyncio from coredis.sentinel import Sentinel -async def example(): - sentinel = Sentinel(sentinels=[("localhost", 26379)]) +sentinel = Sentinel(sentinels=[("localhost", 26379)]) +async with sentinel: primary = sentinel.primary_for("myservice") replica = sentinel.replica_for("myservice") - assert await primary.set("fubar", 1) - assert int(await replica.get("fubar")) == 1 - -asyncio.run(example()) + async with primary, replica: + assert await primary.set("fubar", 1) + assert int(await replica.get("fubar")) == 1 ``` -To see a full list of supported redis commands refer to the [Command -compatibility](https://coredis.readthedocs.io/en/latest/compatibility.html) -documentation - -Details about supported Redis modules and their commands can be found -[here](https://coredis.readthedocs.io/en/latest/handbook/modules.html) - ## Compatibility +To see a full list of supported Redis commands refer to the [Command +compatibility](https://coredis.readthedocs.io/en/latest/compatibility.html) +documentation. Details about supported Redis modules and their commands can be found +[here](https://coredis.readthedocs.io/en/latest/handbook/modules.html). + coredis is tested against redis versions >= `7.0` The test matrix status can be reviewed [here](https://github.com/alisaifee/coredis/actions/workflows/main.yml) coredis is additionally tested against: -- ` uvloop >= 0.15.0` +- `uvloop >= 0.15.0` +- `trio` ### Supported python versions diff --git a/coredis/__init__.py b/coredis/__init__.py index 42819cd88..d4bf14931 100644 --- a/coredis/__init__.py +++ b/coredis/__init__.py @@ -17,12 +17,8 @@ Connection, UnixDomainSocketConnection, ) -from coredis.pool import ( - BlockingClusterConnectionPool, - BlockingConnectionPool, - ClusterConnectionPool, - ConnectionPool, -) +from coredis.pool import ClusterConnectionPool, ConnectionPool +from coredis.sentinel import Sentinel from coredis.tokens import PureToken __all__ = [ @@ -33,10 +29,9 @@ "Connection", "UnixDomainSocketConnection", "ClusterConnection", - "BlockingConnectionPool", "ConnectionPool", - "BlockingClusterConnectionPool", "ClusterConnectionPool", "PureToken", + "Sentinel", "__version__", ] diff --git a/coredis/_concurrency.py b/coredis/_concurrency.py new file mode 100644 index 000000000..f36dff482 --- /dev/null +++ b/coredis/_concurrency.py @@ -0,0 +1,162 @@ +from __future__ import annotations + +from collections import deque +from typing import Any, Awaitable, Generic, TypeVar, overload + +from anyio import Event, Lock, create_task_group + +T1 = TypeVar("T1") +T2 = TypeVar("T2") +T3 = TypeVar("T3") +T4 = TypeVar("T4") +T5 = TypeVar("T5") +T6 = TypeVar("T6") + + +@overload +async def gather( + awaitable1: Awaitable[T1], + awaitable2: Awaitable[T2], + /, + *, + return_exceptions: bool = False, +) -> tuple[T1, T2]: ... + + +@overload +async def gather( + awaitable1: Awaitable[T1], + awaitable2: Awaitable[T2], + awaitable3: Awaitable[T3], + /, + *, + return_exceptions: bool = False, +) -> tuple[T1, T2, T3]: ... + + +@overload +async def gather( + awaitable1: Awaitable[T1], + awaitable2: Awaitable[T2], + awaitable3: Awaitable[T3], + awaitable4: Awaitable[T4], + /, + *, + return_exceptions: bool = False, +) -> tuple[T1, T2, T3, T4]: ... + + +@overload +async def gather( + awaitable1: Awaitable[T1], + awaitable2: Awaitable[T2], + awaitable3: Awaitable[T3], + awaitable4: Awaitable[T4], + awaitable5: Awaitable[T5], + /, + *, + return_exceptions: bool = False, +) -> tuple[T1, T2, T3, T4, T5]: ... + + +@overload +async def gather( + awaitable1: Awaitable[T1], + awaitable2: Awaitable[T2], + awaitable3: Awaitable[T3], + awaitable4: Awaitable[T4], + awaitable5: Awaitable[T5], + awaitable6: Awaitable[T6], + /, + *, + return_exceptions: bool = False, +) -> tuple[T1, T2, T3, T4, T5, T6]: ... + + +@overload +async def gather( + *awaitables: Awaitable[T1], + return_exceptions: bool = False, +) -> tuple[T1, ...]: ... + + +async def gather(*awaitables: Awaitable[Any], return_exceptions: bool = False) -> tuple[Any, ...]: + if not awaitables: + return () + results: list[Any] = [None] * len(awaitables) + + async def runner(awaitable: Awaitable[Any], i: int) -> None: + try: + results[i] = await awaitable + except Exception as exc: + if not return_exceptions: + raise + results[i] = exc + + async with create_task_group() as tg: + for i, awaitable in enumerate(awaitables): + tg.start_soon(runner, awaitable, i) + + return tuple(results) + + +class QueueEmpty(Exception): ... + + +class QueueFull(Exception): ... + + +class Queue(Generic[T1]): + def __init__(self, maxsize: int = 0): + self._maxsize = maxsize + self._queue: deque[T1 | None] = deque( + [None for _ in range(self._maxsize)], maxlen=self._maxsize + ) + self._getters: deque[Event] = deque() + self._putters: deque[Event] = deque() + self._lock = Lock() + + def empty(self) -> bool: + return not self._queue + + def full(self) -> bool: + return self._maxsize > 0 and len(self._queue) >= self._maxsize + + async def put(self, item: T1) -> None: + async with self._lock: + while self.full(): + ev = Event() + self._putters.append(ev) + await ev.wait() + self._queue.append(item) + if self._getters: + self._getters.popleft().set() + + def put_nowait(self, item: T1) -> None: + if self.full(): + raise QueueFull() + self._queue.append(item) + if self._getters: + ev = self._getters.popleft() + ev.set() + + async def get(self) -> T1 | None: + async with self._lock: + while self.empty(): + ev = Event() + self._getters.append(ev) + await ev.wait() + item = self._queue.pop() + if self._putters and not self.full(): + self._putters.popleft().set() + + return item + + def get_nowait(self) -> T1 | None: + if self.empty(): + raise QueueEmpty() + item = self._queue.pop() + if self._putters and not self.full(): + self._putters.popleft().set() + + return item diff --git a/coredis/_protocols.py b/coredis/_protocols.py index 7fec70e25..18d3ec95c 100644 --- a/coredis/_protocols.py +++ b/coredis/_protocols.py @@ -1,7 +1,6 @@ from __future__ import annotations -import asyncio - +from anyio.streams.memory import MemoryObjectSendStream from typing_extensions import runtime_checkable from coredis.response._callbacks import NoopCallback @@ -47,4 +46,4 @@ def create_request( class ConnectionP(Protocol): decode_responses: bool encoding: str - push_messages: asyncio.Queue[ResponseType] + push_messages: MemoryObjectSendStream[ResponseType] diff --git a/coredis/_sidecar.py b/coredis/_sidecar.py deleted file mode 100644 index c404e1835..000000000 --- a/coredis/_sidecar.py +++ /dev/null @@ -1,114 +0,0 @@ -from __future__ import annotations - -import asyncio -import time -import weakref -from typing import TYPE_CHECKING, Any - -from coredis.connection import BaseConnection, Connection -from coredis.exceptions import ConnectionError -from coredis.typing import ResponseType, TypeVar - -if TYPE_CHECKING: - import coredis.client - -SidecarT = TypeVar("SidecarT", bound="Sidecar") - - -class Sidecar: - """ - A sidecar to a redis client that reserves a single connection - and moves any responses from the socket to a FIFO queue - """ - - def __init__( - self, push_message_types: set[bytes], health_check_interval_seconds: int = 5 - ) -> None: - self._client: weakref.ReferenceType[coredis.client.Client[Any]] | None = None - self.messages: asyncio.Queue[ResponseType] = asyncio.Queue() - self.connection: Connection | None = None - self.client_id: int | None = None - self.read_task: asyncio.Task[None] | None = None - self.push_message_types = push_message_types - self.health_check_interval = health_check_interval_seconds - self.health_check_task: asyncio.Task[None] | None = None - self.last_checkin: float = 0 - - @property - def client(self) -> coredis.client.Client[Any] | None: - if self._client: - return self._client() - return None # noqa - - async def start(self: SidecarT, client: coredis.client.Client[Any]) -> SidecarT: - self._client = weakref.ref(client, lambda *_: self.stop()) - if not self.connection and self.client: - self.connection = await self.client.connection_pool.get_connection() - self.connection.register_connect_callback(self.on_reconnect) - await self.connection.connect() - if self.connection.tracking_client_id: # noqa - await self.connection.update_tracking_client(False) - if not self.read_task or self.read_task.done(): - self.read_task = asyncio.create_task(self.__read_loop()) - if not self.health_check_task or self.health_check_task.done(): - self.health_check_task = asyncio.create_task(self.__health_check()) - return self - - def process_message(self, message: ResponseType) -> tuple[ResponseType, ...]: - return (message,) # noqa - - def stop(self) -> None: - try: - asyncio.get_running_loop() - if self.read_task and not self.read_task.done(): - self.read_task.cancel() - if self.health_check_task and not self.health_check_task.done(): - self.health_check_task.cancel() - except RuntimeError: - pass - if self.connection: - self.connection.disconnect() - if self.client and self.connection: # noqa - self.client.connection_pool.release(self.connection) - self.connection = None - self.client_id = None - - def __del__(self) -> None: - self.stop() - - async def on_reconnect(self, connection: BaseConnection) -> None: - self.client_id = connection.client_id - self.last_checkin = time.monotonic() - - async def __health_check(self) -> None: - while True: - try: - if self.connection: - await self.connection.send_command(b"PING") - await asyncio.sleep(self.health_check_interval) - except asyncio.CancelledError: - break - - async def __read_loop(self) -> None: - while self.connection: - try: - response = await self.connection.fetch_push_message( - decode=False, push_message_types=self.push_message_types - ) - self.last_checkin = time.monotonic() - if response == b"PONG" or b"pong" in response: # type: ignore - continue - for m in self.process_message(response): - self.messages.put_nowait(m) - except asyncio.CancelledError: - break - except ConnectionError: - if self.client and self.connection: - self.client.connection_pool.release(self.connection) - self.connection = None - - if self.client: - asyncio.get_running_loop().call_soon( - asyncio.create_task, self.start(self.client) - ) - break diff --git a/coredis/_utils.py b/coredis/_utils.py index 86b8507d6..46720b8e9 100644 --- a/coredis/_utils.py +++ b/coredis/_utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging from collections import UserDict from typing import Any @@ -13,6 +14,9 @@ TypeVar, ) +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + T = TypeVar("T") U = TypeVar("U") @@ -138,7 +142,7 @@ def make_hashable(*args: Any) -> tuple[Hashable, ...]: ) -def query_param_to_bool(value: Any | None) -> bool | None: +def query_param_to_bool(value: Any) -> bool | None: if value is None or value in ("", b""): return None if isinstance(value, (int, float, bool, str, bytes)): diff --git a/coredis/cache.py b/coredis/cache.py index 230706f5a..437b9cb9b 100644 --- a/coredis/cache.py +++ b/coredis/cache.py @@ -1,36 +1,32 @@ from __future__ import annotations -import asyncio import dataclasses -import time -import weakref from abc import ABC, abstractmethod from collections import Counter -from typing import TYPE_CHECKING, Any - -from coredis._sidecar import Sidecar -from coredis._utils import b, make_hashable -from coredis.commands import PubSub -from coredis.connection import BaseConnection +from contextlib import AsyncExitStack +from typing import TYPE_CHECKING, Any, cast + +from anyio import ( + TASK_STATUS_IGNORED, + create_task_group, + current_time, + sleep, +) +from anyio.abc import TaskStatus +from exceptiongroup import catch + +from coredis._utils import b, logger, make_hashable +from coredis.commands.constants import CommandName +from coredis.exceptions import RETRYABLE +from coredis.pool.basic import ConnectionPool +from coredis.pool.cluster import ClusterConnectionPool from coredis.typing import ( - Generic, - Hashable, - Literal, - ModuleType, OrderedDict, RedisValueT, ResponseType, - TypeVar, + StringT, ) -asizeof: ModuleType | None = None - -try: - from pympler import asizeof -except (AttributeError, KeyError): - # Not available in pypy - pass - if TYPE_CHECKING: import coredis.client @@ -114,24 +110,6 @@ class AbstractCache(ABC): :class:`coredis.Redis` or :class:`coredis.RedisCluster` """ - @abstractmethod - async def initialize( - self, - client: coredis.client.Redis[Any] | coredis.client.RedisCluster[Any], - ) -> AbstractCache: - """ - Associate and initialize this cache with the provided client - """ - ... - - @property - @abstractmethod - def healthy(self) -> bool: - """ - Whether the cache is healthy and should be taken seriously - """ - ... - @abstractmethod def get(self, command: bytes, key: RedisValueT, *args: RedisValueT) -> ResponseType: """ @@ -155,6 +133,13 @@ def invalidate(self, *keys: RedisValueT) -> None: """ ... + @abstractmethod + def reset(self) -> None: + """ + Reset the cache + """ + ... + @property @abstractmethod def stats(self) -> CacheStats: @@ -181,621 +166,238 @@ def feedback(self, command: bytes, key: RedisValueT, *args: RedisValueT, match: """ ... - @abstractmethod - def get_client_id(self, connection: BaseConnection) -> int | None: - """ - If the cache supports receiving invalidation events from the server - return the ``client_id`` that the :paramref:`connection` should send - redirects to. - """ - ... - - @abstractmethod - def reset(self) -> None: - """ - Reset the cache - """ - ... - - @abstractmethod - def shutdown(self) -> None: - """ - Explicitly shutdown the cache - """ - ... - - -ET = TypeVar("ET") - - -class LRUCache(Generic[ET]): - def __init__(self, max_items: int = -1, max_bytes: int = -1): - self.max_items = max_items - self.max_bytes = max_bytes - self.__cache: OrderedDict[Hashable, ET] = OrderedDict() - - if self.max_bytes > 0 and asizeof is not None: - self.max_bytes += asizeof.asizeof(self.__cache) - elif self.max_bytes > 0: - raise RuntimeError("max_bytes not supported as dependency pympler not available") - - def get(self, key: Hashable) -> ET: - if key not in self.__cache: - raise KeyError(key) - self.__cache.move_to_end(key) - - return self.__cache[key] - - def insert(self, key: Hashable, value: ET) -> None: - self.__check_capacity() - self.__cache[key] = value - self.__cache.move_to_end(key) - - def setdefault(self, key: Hashable, value: ET) -> ET: - try: - self.__check_capacity() - - return self.get(key) - except KeyError: - self.insert(key, value) - - return self.get(key) - - def remove(self, key: Hashable) -> None: - if key in self.__cache: - self.__cache.pop(key) - - def clear(self) -> None: - self.__cache.clear() - - def popitem(self) -> tuple[Any, Any] | None: - """ - Recursively remove the oldest entry. If - the oldest entry is another LRUCache trigger - the removal of its oldest entry and if that - turns out to be an empty LRUCache, remove that. - """ - try: - oldest = next(iter(self.__cache)) - item = self.__cache[oldest] - except StopIteration: - return None - - if isinstance(item, LRUCache): - if popped := item.popitem(): - return popped - if entry := self.__cache.popitem(last=False): - return entry - return None - - def shrink(self) -> None: - """ - Remove old entries until the size of the cache - is less than :paramref:`LRUCache.max_bytes` or if - there is nothing left to remove. - """ - - if self.max_bytes > 0 and asizeof is not None: - cur_size = asizeof.asizeof(self.__cache) - while cur_size > self.max_bytes: - if (popped := self.popitem()) is None: - return - cur_size -= asizeof.asizeof(popped[0]) + asizeof.asizeof(popped[1]) - - def __repr__(self) -> str: - if asizeof is not None: - return ( - f"LruCache" - ) - else: - return f"LruCache None: - if len(self.__cache) == self.max_items: - self.__cache.popitem(last=False) - - -class NodeTrackingCache( - Sidecar, - AbstractCache, -): - """ - An LRU cache that uses server assisted client caching - to ensure local cache entries are invalidated if any - operations are performed on the keys by another client. - """ +class LRUCache(AbstractCache): def __init__( self, max_keys: int = 2**12, - max_size_bytes: int = 64 * 1024 * 1024, - max_idle_seconds: int = 5, confidence: float = 100, dynamic_confidence: bool = False, - cache: LRUCache[LRUCache[LRUCache[ResponseType]]] | None = None, - stats: CacheStats | None = None, ) -> None: - """ - :param max_keys: maximum keys to cache. A negative value represents - and unbounded cache. - :param max_size_bytes: maximum size in bytes for the local cache. - A negative value represents an unbounded cache. - :param max_idle_seconds: maximum duration to tolerate no updates - from the server. When the duration is exceeded the connection - and cache will be reset. - :param confidence: 0 - 100. Lower values will result in the client - discarding and / or validating the cached responses - :param dynamic_confidence: Whether to adjust the confidence based on - sampled validations. Tainted values drop the confidence by 0.1% and - confirmations of correct cached values will increase the confidence by 0.01% - upto 100. - """ - super().__init__({b"invalidate"}, max(1, max_idle_seconds - 1)) - self.__protocol_version: Literal[2, 3] | None = None - self.__invalidation_task: asyncio.Task[None] | None = None - self.__compact_task: asyncio.Task[None] | None = None - self.__max_idle_seconds = max_idle_seconds - self.__confidence = self.__original_confidence = confidence - self.__dynamic_confidence = dynamic_confidence - self.__stats = stats or CacheStats() - self.__cache: LRUCache[LRUCache[LRUCache[ResponseType]]] = cache or LRUCache( - max_keys, max_size_bytes - ) + self._confidence = self._original_confidence = confidence + self._dynamic_confidence = dynamic_confidence + self._stats = CacheStats() + self.max_keys = max_keys + # key -> (command, args) -> response + self._storage: OrderedDict[bytes, dict[tuple[bytes, Any], ResponseType]] = OrderedDict() - @property - def healthy(self) -> bool: - return bool( - self.connection - and self.connection.is_connected - and time.monotonic() - self.last_checkin < self.__max_idle_seconds - ) + def put( + self, command: bytes, key: RedisValueT, *args: RedisValueT, value: ResponseType + ) -> None: + key_bytes = b(key) + composite_key = (command, make_hashable(*args)) - @property - def confidence(self) -> float: - return self.__confidence + if key_bytes not in self._storage and len(self._storage) >= self.max_keys: + if self._storage: + self._storage.popitem(last=False) - @property - def stats(self) -> CacheStats: - return self.__stats + # Get or create the key's cache dict + if key_bytes not in self._storage: + self._storage[key_bytes] = {} + + self._storage[key_bytes][composite_key] = value + self._storage.move_to_end(key_bytes) def get(self, command: bytes, key: RedisValueT, *args: RedisValueT) -> ResponseType: - try: - cached = self.__cache.get(b(key)).get(command).get(make_hashable(*args)) - self.__stats.hit(key) + key_bytes = b(key) + if key_bytes not in self._storage: + self._stats.miss(key) + raise KeyError(key) - return cached - except KeyError: - self.__stats.miss(key) - raise + # Move to end for LRU + self._storage.move_to_end(key_bytes) + composite_key = (command, make_hashable(*args)) + if composite_key not in self._storage[key_bytes]: + self._stats.miss(key) + raise KeyError(key) - def put( - self, command: bytes, key: RedisValueT, *args: RedisValueT, value: ResponseType - ) -> None: - self.__cache.setdefault(b(key), LRUCache()).setdefault(command, LRUCache()).insert( - make_hashable(*args), value - ) + self._stats.hit(key) + return self._storage[key_bytes][composite_key] def invalidate(self, *keys: RedisValueT) -> None: for key in keys: - self.__stats.invalidate(key) - self.__cache.remove(b(key)) + self._stats.invalidate(key) + self._storage.pop(b(key), None) + + def reset(self) -> None: + self._storage.clear() + self._stats.compact() + self._confidence = self._original_confidence + + @property + def stats(self) -> CacheStats: + return self._stats + + @property + def confidence(self) -> float: + return self._confidence def feedback(self, command: bytes, key: RedisValueT, *args: RedisValueT, match: bool) -> None: if not match: - self.__stats.mark_dirty(key) + self._stats.mark_dirty(key) self.invalidate(key) - if self.__dynamic_confidence: - self.__confidence = min( + if self._dynamic_confidence: + self._confidence = min( 100.0, - max(0.0, self.__confidence * (1.0001 if match else 0.999)), + max(0.0, self._confidence * (1.0001 if match else 0.999)), ) - def reset(self) -> None: - self.__cache.clear() - self.__stats.compact() - self.__confidence = self.__original_confidence - - def process_message(self, message: ResponseType) -> tuple[ResponseType, ...]: - assert isinstance(message, list) - - if self.__protocol_version == 2: - assert isinstance(message[0], bytes) - - if b(message[0]) in PubSub.SUBUNSUB_MESSAGE_TYPES: - return () - elif message[2] is not None: - assert isinstance(message[2], list) - - return tuple(k for k in message[2]) - elif message[1] is not None: - assert isinstance(message[1], list) - - return tuple(k for k in message[1]) - - return () # noqa - - async def initialize( - self, - client: coredis.client.Redis[Any] | coredis.client.RedisCluster[Any], - ) -> NodeTrackingCache: - self.__protocol_version = client.protocol_version - await super().start(client) - - if not self.__invalidation_task or self.__invalidation_task.done(): - self.__invalidation_task = asyncio.create_task(self.__invalidate()) - - if not self.__compact_task or self.__compact_task.done(): - self.__compact_task = asyncio.create_task(self.__compact()) - - return self - async def on_reconnect(self, connection: BaseConnection) -> None: - self.__cache.clear() - await super().on_reconnect(connection) - - if self.__protocol_version == 2 and self.connection: - await self.connection.send_command(b"SUBSCRIBE", b"__redis__:invalidate") - - def shutdown(self) -> None: - try: - asyncio.get_running_loop() - - if self.__invalidation_task: - self.__invalidation_task.cancel() - - if self.__compact_task: - self.__compact_task.cancel() - super().stop() - except RuntimeError: - pass - - def get_client_id(self, client: BaseConnection) -> int | None: - if self.connection and self.connection.is_connected: - return self.client_id - - return None - - async def __compact(self) -> None: - while True: - try: - self.__cache.shrink() - self.__stats.compact() - await asyncio.sleep(max(1, self.__max_idle_seconds - 1)) - except asyncio.CancelledError: - break - - async def __invalidate(self) -> None: - while True: - try: - key = b(await self.messages.get()) - self.invalidate(key) - self.messages.task_done() - except asyncio.CancelledError: - break - except RuntimeError: # noqa - break - - -class ClusterTrackingCache(AbstractCache): +class TrackingCache(AbstractCache): """ - An LRU cache for redis cluster that uses server assisted client caching - to ensure local cache entries are invalidated if any operations are performed - on the keys by another client. - - The cache maintains an additional connection per node (including replicas) - in the cluster to listen to invalidation events + Abstract layout of a tracking cache to be used internally + by coredis clients (Redis/RedisCluster) """ - def __init__( - self, - max_keys: int = 2**12, - max_size_bytes: int = 64 * 1024 * 1024, - max_idle_seconds: int = 5, - confidence: float = 100, - dynamic_confidence: bool = False, - cache: LRUCache[LRUCache[LRUCache[ResponseType]]] | None = None, - stats: CacheStats | None = None, - ) -> None: - """ - :param max_keys: maximum keys to cache. A negative value represents - and unbounded cache. - :param max_size_bytes: maximum size in bytes for the local cache. - A negative value represents an unbounded cache. - :param max_idle_seconds: maximum duration to tolerate no updates - from the server. When the duration is exceeded the connection - and cache will be reset. - :param confidence: 0 - 100. Lower values will result in the client - discarding and / or validating the cached responses - :param dynamic_confidence: Whether to adjust the confidence based on - sampled validations. Tainted values drop the confidence by 0.1% and - confirmations of correct cached values will increase the confidence by 0.01% - upto 100. - """ - self.node_caches: dict[str, NodeTrackingCache] = {} - self.__protocol_version: Literal[2, 3] | None = None - self.__cache: LRUCache[LRUCache[LRUCache[ResponseType]]] = cache or LRUCache( - max_keys, max_size_bytes - ) - self.__nodes: list[coredis.client.Redis[Any]] = [] - self.__max_idle_seconds = max_idle_seconds - self.__confidence = self.__original_confidence = confidence - self.__dynamic_confidence = dynamic_confidence - self.__stats = stats or CacheStats() - self.__client: weakref.ReferenceType[coredis.client.RedisCluster[Any]] | None = None - - async def initialize( - self, - client: coredis.client.Redis[Any] | coredis.client.RedisCluster[Any], - ) -> ClusterTrackingCache: - import coredis.client - - assert isinstance(client, coredis.client.RedisCluster) - - self.__client = weakref.ref(client) - self.__cache.clear() - - for sidecar in self.node_caches.values(): - sidecar.shutdown() - self.node_caches.clear() - self.__nodes = list(client.all_nodes) - - for node in self.__nodes: - node_cache = NodeTrackingCache( - max_idle_seconds=self.__max_idle_seconds, - confidence=self.__confidence, - dynamic_confidence=self.__dynamic_confidence, - cache=self.__cache, - stats=self.__stats, - ) - await node_cache.initialize(node) - assert node_cache.connection - self.node_caches[node_cache.connection.location] = node_cache - - return self - - @property - def client(self) -> coredis.client.RedisCluster[Any] | None: - if self.__client: - return self.__client() - - return None # noqa + _cache: AbstractCache - @property - def healthy(self) -> bool: - return bool( - self.client - and self.client.connection_pool.initialized - and self.node_caches - and all(cache.healthy for cache in self.node_caches.values()) - ) - - @property - def confidence(self) -> float: - return self.__confidence - - @property - def stats(self) -> CacheStats: - return self.__stats + @abstractmethod + async def run( + self, pool: ConnectionPool, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED + ) -> None: + pass - def get_client_id(self, connection: BaseConnection) -> int | None: - try: - return self.node_caches[connection.location].get_client_id(connection) - except KeyError: - return None + @abstractmethod + def get_client_id( + self, + connection: coredis.connection.BaseConnection, + ) -> int | None: + pass def get(self, command: bytes, key: RedisValueT, *args: RedisValueT) -> ResponseType: - try: - cached = self.__cache.get(b(key)).get(command).get(make_hashable(*args)) - self.__stats.hit(key) - - return cached - except KeyError: - self.__stats.miss(key) - raise + return self._cache.get(command, key, *args) def put( self, command: bytes, key: RedisValueT, *args: RedisValueT, value: ResponseType ) -> None: - self.__cache.setdefault(b(key), LRUCache()).setdefault(command, LRUCache()).insert( - make_hashable(*args), value - ) + self._cache.put(command, key, *args, value=value) def invalidate(self, *keys: RedisValueT) -> None: - for key in keys: - self.__stats.invalidate(key) - self.__cache.remove(b(key)) - - def feedback(self, command: bytes, key: RedisValueT, *args: RedisValueT, match: bool) -> None: - if not match: - self.__stats.mark_dirty(key) - self.invalidate(key) - - if self.__dynamic_confidence: - self.__confidence = min( - 100.0, - max(0.0, self.__confidence * (1.0001 if match else 0.999)), - ) + self._cache.invalidate(*keys) def reset(self) -> None: - self.__cache.clear() - self.__stats.compact() - self.__confidence = self.__original_confidence + self._cache.reset() - def shutdown(self) -> None: - if self.node_caches: - for sidecar in self.node_caches.values(): - sidecar.shutdown() - self.node_caches.clear() - self.__nodes.clear() + @property + def stats(self) -> CacheStats: + return self._cache.stats + + @property + def confidence(self) -> float: + return self._cache.confidence - def __del__(self) -> None: - self.shutdown() + def feedback(self, command: bytes, key: RedisValueT, *args: RedisValueT, match: bool) -> None: + self._cache.feedback(command, key, *args, match=match) -class TrackingCache(AbstractCache): +class NodeTrackingCache(TrackingCache): """ - An LRU cache that uses server assisted client caching to ensure local cache entries - are invalidated if any operations are performed on the keys by another client. - - This class proxies to either :class:`~coredis.cache.NodeTrackingCache` - or :class:`~coredis.cache.ClusterTrackingCache` depending on which type of client - it is passed into. + Wraps an AbstractCache instance to use server assisted client caching + to ensure local cache entries are invalidated if any operations are + performed on the keys by another client. """ - def __init__( + def __init__(self, cache: AbstractCache | None = None) -> None: + """ + :param cache: AbstractCache instance to wrap + :param compact_interval_seconds: frequency to check if cache is too big and shrink it + """ + self._cache = cache or LRUCache() + self.client_id: int | None = None + + def get_client_id( self, - max_keys: int = 2**12, - max_size_bytes: int = 64 * 1024 * 1024, - max_idle_seconds: int = 5, - confidence: float = 100.0, - dynamic_confidence: bool = False, - cache: LRUCache[LRUCache[LRUCache[ResponseType]]] | None = None, - stats: CacheStats | None = None, + connection: coredis.connection.BaseConnection, + ) -> int | None: + return self.client_id + + async def run( + self, pool: ConnectionPool, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED ) -> None: """ - :param max_keys: maximum keys to cache. A negative value represents - and unbounded cache. - :param max_size_bytes: maximum size in bytes for the local cache. - A negative value represents an unbounded cache. - :param max_idle_seconds: maximum duration to tolerate no updates - from the server. When the duration is exceeded the connection - and cache will be reset. - :param confidence: 0 - 100. Lower values will result in the client - discarding and / or validating the cached responses - :param dynamic_confidence: Whether to adjust the confidence based on - sampled validations. Tainted values drop the confidence by 0.1% and - confirmations of correct cached values will increase the confidence by 0.01% - upto 100. + Run a single connection that listens for invalidation messages, + with reconnection logic. """ - self.instance: ClusterTrackingCache | NodeTrackingCache | None = None - self.__max_keys = max_keys - self.__max_size_bytes = max_size_bytes - self.__max_idle_seconds = max_idle_seconds - self.__confidence = confidence - self.__dynamic_confidence = dynamic_confidence - self.__cache: LRUCache[LRUCache[LRUCache[ResponseType]]] = cache or LRUCache( - max_keys, max_size_bytes - ) - self.__client: ( - None - | (weakref.ReferenceType[coredis.client.Redis[Any] | coredis.client.RedisCluster[Any],]) - ) = None - self.__stats = stats or CacheStats() + start_time, started, tries = current_time(), False, 0 - async def initialize( - self, - client: coredis.client.Redis[Any] | coredis.client.RedisCluster[Any], - ) -> TrackingCache: - import coredis.client - - if self.__client and self.__client() != client: - copy = self.share() - - return await copy.initialize(client) - - self.__client = weakref.ref(client) - - if not self.instance: - if isinstance(client, coredis.client.RedisCluster): - self.instance = ClusterTrackingCache( - self.__max_keys, - self.__max_size_bytes, - self.__max_idle_seconds, - confidence=self.__confidence, - dynamic_confidence=self.__dynamic_confidence, - cache=self.__cache, - stats=self.__stats, - ) + def handle_error(*args: Any) -> None: + nonlocal tries, start_time + if current_time() - start_time > 10: + tries = 0 else: - self.instance = NodeTrackingCache( - self.__max_keys, - self.__max_size_bytes, - self.__max_idle_seconds, - confidence=self.__confidence, - dynamic_confidence=self.__dynamic_confidence, - cache=self.__cache, - stats=self.__stats, - ) - await self.instance.initialize(client) - - return self + tries += 1 + logger.warning("Cache connection lost, retrying...") - @property - def healthy(self) -> bool: - return bool(self.instance and self.instance.healthy) + while True: + # retry with exponential backoff + await sleep(min(tries**2, 300)) + with catch({RETRYABLE: handle_error}): + async with pool.acquire() as self._connection: + if self._connection.tracking_client_id: + await self._connection.update_tracking_client(False) + self.client_id = self._connection.client_id + start_time = current_time() + async with create_task_group() as self._tg: + self._tg.start_soon(self._consumer) + self._tg.start_soon(self._keepalive) + if not started: + task_status.started() + started = True + else: # flush cache + self.reset() + + async def _keepalive(self) -> None: + while True: + await self._connection.send_command(CommandName.PING) + await sleep(15) - @property - def confidence(self) -> float: - if not self.instance: - return self.__confidence + async def _consumer(self) -> None: + while True: + response = await self._connection.fetch_push_message(True) + messages = cast(list[StringT], response[1] or []) + for key in messages: + self._cache.invalidate(key) - return self.instance.confidence - @property - def stats(self) -> CacheStats: - return self.__stats +class ClusterTrackingCache(TrackingCache): + """ + An LRU cache for redis cluster that uses server assisted client caching + to ensure local cache entries are invalidated if any operations are performed + on the keys by another client. - def get_client_id(self, connection: BaseConnection) -> int | None: - if self.instance: - return self.instance.get_client_id(connection) + The cache maintains an additional connection per node (including replicas) + in the cluster to listen to invalidation events + """ + def get_client_id(self, connection: coredis.connection.BaseConnection) -> int | None: + if cache := self.node_caches.get(connection.location): + return cache.client_id return None - def get(self, command: bytes, key: RedisValueT, *args: RedisValueT) -> ResponseType: - assert self.instance - - return self.instance.get(command, key, *args) + def __init__(self, cache: AbstractCache | None = None) -> None: + """ """ + self.node_caches: dict[str, NodeTrackingCache] = {} + self._cache = cache or LRUCache() + self._nodes: list[coredis.client.Redis[Any]] = [] - def put( - self, command: bytes, key: RedisValueT, *args: RedisValueT, value: ResponseType + async def run( + self, pool: ConnectionPool, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED ) -> None: - if self.instance: - self.instance.put(command, key, *args, value=value) - - def invalidate(self, *keys: RedisValueT) -> None: - if self.instance: - self.instance.invalidate(*keys) - - def feedback(self, command: bytes, key: RedisValueT, *args: RedisValueT, match: bool) -> None: - if self.instance: - self.instance.feedback(command, key, *args, match=match) - - def reset(self) -> None: - if self.instance: - self.instance.reset() - - def shutdown(self) -> None: - if self.instance: - self.instance.shutdown() - self.__client = None - - def share(self) -> TrackingCache: - """ - Create a copy of this cache that can be used to share - memory with another client. - - In the example below ``c1`` and ``c2`` have their own - instances of :class:`~coredis.cache.TrackingCache` but - share the same in-memory local cached responses:: - - c1 = await coredis.Redis(cache=TrackingCache()) - c2 = await coredis.Redis(cache=c1.cache.share()) - """ - copy = self.__class__( - self.__max_keys, - self.__max_size_bytes, - self.__max_idle_seconds, - self.__confidence, - self.__dynamic_confidence, - self.__cache, - self.__stats, - ) - - return copy - - def __del__(self) -> None: - self.shutdown() + assert isinstance(pool, ClusterConnectionPool) + self._nodes = [ + pool.nodes.get_redis_link(node.host, node.port) for node in pool.nodes.all_nodes() + ] + async with AsyncExitStack() as stack: + nodes = [] + for node in self._nodes: + nodes.append(await stack.enter_async_context(node)) + + async with create_task_group() as tg: + self._task_group = tg + + for node in nodes: + node_cache = NodeTrackingCache(cache=self._cache) + await tg.start(node_cache.run, node.connection_pool) + self.node_caches[node_cache._connection.location] = node_cache + task_status.started() diff --git a/coredis/client/basic.py b/coredis/client/basic.py index ac852ba4b..757d9ec9c 100644 --- a/coredis/client/basic.py +++ b/coredis/client/basic.py @@ -1,27 +1,27 @@ from __future__ import annotations -import asyncio import contextlib import contextvars -import functools import random import warnings from collections import defaultdict from ssl import SSLContext -from typing import TYPE_CHECKING, Any, cast, overload +from typing import TYPE_CHECKING, Any, Coroutine, cast, overload -from deprecated.sphinx import deprecated, versionadded +from anyio import AsyncContextManagerMixin, sleep +from deprecated.sphinx import versionadded +from exceptiongroup import catch from packaging import version from packaging.version import InvalidVersion, Version +from typing_extensions import Self -from coredis._utils import EncodingInsensitiveDict, nativestr -from coredis.cache import AbstractCache +from coredis._utils import EncodingInsensitiveDict, logger, nativestr +from coredis.cache import AbstractCache, NodeTrackingCache, TrackingCache from coredis.commands import CommandRequest from coredis.commands._key_spec import KeySpec from coredis.commands.constants import CommandFlag, CommandName from coredis.commands.core import CoreCommands from coredis.commands.function import Library -from coredis.commands.monitor import Monitor from coredis.commands.pubsub import PubSub, SubscriptionCallback from coredis.commands.script import Script from coredis.commands.sentinel import SentinelCommands @@ -37,7 +37,6 @@ AuthorizationError, ConnectionError, PersistenceError, - RedisError, ReplicationError, ResponseError, TimeoutError, @@ -52,16 +51,14 @@ NoopCallback, ResponseCallback, ) -from coredis.response.types import MonitorResult, ScoredMember +from coredis.response.types import ScoredMember from coredis.retry import ConstantRetryPolicy, NoRetryPolicy, RetryPolicy from coredis.typing import ( AnyStr, AsyncGenerator, AsyncIterator, Callable, - Coroutine, ExecutionParameters, - Generator, Generic, Iterator, KeyT, @@ -71,7 +68,6 @@ ParamSpec, RedisCommandP, RedisValueT, - ResponseType, StringT, T_co, TypeAdapter, @@ -85,24 +81,25 @@ if TYPE_CHECKING: import coredis.pipeline + from coredis.recipes import Lock ClientT = TypeVar("ClientT", bound="Client[Any]") RedisT = TypeVar("RedisT", bound="Redis[Any]") class Client( + AsyncContextManagerMixin, Generic[AnyStr], CoreCommands[AnyStr], ModuleMixin[AnyStr], SentinelCommands[AnyStr], ): - cache: AbstractCache | None + cache: TrackingCache | None connection_pool: ConnectionPool decode_responses: bool encoding: str - protocol_version: Literal[2, 3] server_version: Version | None - callback_storage: dict[type[ResponseCallback[Any, Any, Any]], dict[str, Any]] + callback_storage: dict[type[ResponseCallback[Any, Any]], dict[str, Any]] type_adapter: TypeAdapter def __init__( @@ -128,10 +125,8 @@ def __init__( ssl_check_hostname: bool | None = None, ssl_ca_certs: str | None = None, max_connections: int | None = None, - max_idle_time: float = 0, - idle_check_interval: float = 1, + max_idle_time: int | None = None, client_name: str | None = None, - protocol_version: Literal[2, 3] = 3, verify_version: bool = True, noreply: bool = False, retry_policy: RetryPolicy = NoRetryPolicy(), @@ -152,9 +147,7 @@ def __init__( "max_connections": max_connections, "decode_responses": decode_responses, "max_idle_time": max_idle_time, - "idle_check_interval": idle_check_interval, "client_name": client_name, - "protocol_version": protocol_version, "noreply": noreply, "noevict": noevict, "notouch": notouch, @@ -187,16 +180,6 @@ def __init__( self.connection_pool = connection_pool self.encoding = connection_pool.encoding self.decode_responses = connection_pool.decode_responses - connection_protocol_version = ( - connection_pool.connection_kwargs.get("protocol_version") or protocol_version - ) - assert connection_protocol_version in { - 2, - 3, - }, "Protocol version can only be one of {2,3}" - if connection_protocol_version == 2: - warnings.warn("Support for RESP2 will be removed in version 6.x", DeprecationWarning) - self.protocol_version = connection_protocol_version self.server_version: Version | None = None self.verify_version = verify_version self.__noreply = noreply @@ -262,11 +245,7 @@ def get_server_module_version(self, module: str) -> version.Version | None: return (self._module_info or {}).get(module) def _ensure_server_version(self, version: str | None) -> None: - if not self.verify_version or Config.optimized: - return - if not version: - return - if not self.server_version and version: + if self.verify_version and not Config.optimized and not self.server_version and version: try: self.server_version = Version(nativestr(version)) except InvalidVersion: @@ -281,53 +260,28 @@ def _ensure_server_version(self, version: str | None) -> None: self.verify_version = False self.server_version = None - async def _ensure_wait( + async def _ensure_wait_and_persist( self, command: RedisCommandP, connection: BaseConnection - ) -> asyncio.Future[None]: - maybe_wait: asyncio.Future[None] = asyncio.get_running_loop().create_future() + ) -> None: wait = self._waitcontext.get() + waitaof = self._waitaof_context.get() + wait_request = None + aof_request = None if wait and wait[0] > 0: + wait_request = await connection.create_request(CommandName.WAIT, *wait, decode=False) - def check_wait(wait: tuple[int, int], response: asyncio.Future[ResponseType]) -> None: - exc = response.exception() - if exc: - maybe_wait.set_exception(exc) - elif not cast(int, response.result()) >= wait[0]: - maybe_wait.set_exception(ReplicationError(command.name, wait[0], wait[1])) - else: - maybe_wait.set_result(None) - - request = await connection.create_request(CommandName.WAIT, *wait, decode=False) - request.add_done_callback(functools.partial(check_wait, wait)) - else: - maybe_wait.set_result(None) - return maybe_wait - - async def _ensure_persistence( - self, command: RedisCommandP, connection: BaseConnection - ) -> asyncio.Future[None]: - maybe_wait: asyncio.Future[None] = asyncio.get_running_loop().create_future() - waitaof = self._waitaof_context.get() if waitaof and waitaof[0] > 0: - - def check_wait( - waitaof: tuple[int, int, int], response: asyncio.Future[ResponseType] - ) -> None: - exc = response.exception() - if exc: - maybe_wait.set_exception(exc) - else: - res = cast(tuple[int, int], response.result()) - if not (res[0] >= waitaof[0] and res[1] >= waitaof[1]): - maybe_wait.set_exception(PersistenceError(command.name, *waitaof)) - else: - maybe_wait.set_result(None) - - request = await connection.create_request(CommandName.WAITAOF, *waitaof, decode=False) - request.add_done_callback(functools.partial(check_wait, waitaof)) - else: - maybe_wait.set_result(None) - return maybe_wait + aof_request = await connection.create_request( + CommandName.WAITAOF, *waitaof, decode=False + ) + if wait_request and wait: + wait_result = await wait_request + if not cast(int, wait_result) >= wait[0]: + raise ReplicationError(command.name, wait[0], wait[1]) + if aof_request and waitaof: + aof_result = cast(tuple[int, int], await aof_request) + if not (aof_result[0] >= waitaof[0] and aof_result[1] >= waitaof[1]): + raise PersistenceError(command.name, *waitaof) async def _populate_module_versions(self) -> None: if self.noreply or getattr(self, "_module_info", None) is not None: @@ -352,14 +306,6 @@ async def _populate_module_versions(self) -> None: ) self._module_info = {} - async def initialize(self: ClientT) -> ClientT: - await self.connection_pool.initialize() - await self._populate_module_versions() - return self - - def __await__(self: ClientT) -> Generator[Any, None, ClientT]: - return self.initialize().__await__() - def __repr__(self) -> str: return f"{type(self).__name__}<{repr(self.connection_pool)}>" @@ -613,10 +559,8 @@ def __init__( ssl_check_hostname: bool | None = ..., ssl_ca_certs: str | None = ..., max_connections: int | None = ..., - max_idle_time: float = ..., - idle_check_interval: float = ..., + max_idle_time: int | None = ..., client_name: str | None = ..., - protocol_version: Literal[2, 3] = ..., verify_version: bool = ..., cache: AbstractCache | None = ..., noreply: bool = ..., @@ -652,10 +596,8 @@ def __init__( ssl_check_hostname: bool | None = ..., ssl_ca_certs: str | None = ..., max_connections: int | None = ..., - max_idle_time: float = ..., - idle_check_interval: float = ..., + max_idle_time: int | None = ..., client_name: str | None = ..., - protocol_version: Literal[2, 3] = ..., verify_version: bool = ..., cache: AbstractCache | None = ..., noreply: bool = ..., @@ -690,10 +632,8 @@ def __init__( ssl_check_hostname: bool | None = None, ssl_ca_certs: str | None = None, max_connections: int | None = None, - max_idle_time: float = 0, - idle_check_interval: float = 1, + max_idle_time: int | None = None, client_name: str | None = None, - protocol_version: Literal[2, 3] = 3, verify_version: bool = True, cache: AbstractCache | None = None, noreply: bool = False, @@ -706,6 +646,11 @@ def __init__( """ Changes + - .. versionremoved:: 6.0.0 + - :paramref:`protocol_version` removed (and therefore support for RESP2) + + - .. versionadded:: 6.0.0 + - TODO: Add stuff - .. versionadded:: 4.12.0 - :paramref:`retry_policy` @@ -801,12 +746,7 @@ def __init__( :paramref:`connection_pool` is not ``None``. :param max_idle_time: Maximum number of a seconds an unused connection is cached before it is disconnected. - :param idle_check_interval: Periodicity of idle checks (seconds) to release idle - connections. :param client_name: The client name to identifiy with the redis server - :param protocol_version: Whether to use the RESP (``2``) or RESP3 (``3``) - protocol for parsing responses from the server (Default ``3``). - (See :ref:`handbook/response:redis response`) :param verify_version: Validate redis server version against the documented version introduced before executing a command and raises a :exc:`CommandNotSupportedError` error if the required version is higher than @@ -848,9 +788,7 @@ def __init__( ssl_ca_certs=ssl_ca_certs, max_connections=max_connections, max_idle_time=max_idle_time, - idle_check_interval=idle_check_interval, client_name=client_name, - protocol_version=protocol_version, verify_version=verify_version, noreply=noreply, noevict=noevict, @@ -859,7 +797,7 @@ def __init__( type_adapter=type_adapter, **kwargs, ) - self.cache = cache + self.cache = NodeTrackingCache(cache=cache) if cache else None self._decodecontext: contextvars.ContextVar[bool | None,] = contextvars.ContextVar( "decode", default=None ) @@ -875,7 +813,6 @@ def from_url( db: int | None = ..., *, decode_responses: Literal[False] = ..., - protocol_version: Literal[2, 3] = ..., verify_version: bool = ..., noreply: bool = ..., noevict: bool = ..., @@ -893,7 +830,6 @@ def from_url( db: int | None = ..., *, decode_responses: Literal[True] = ..., - protocol_version: Literal[2, 3] = ..., verify_version: bool = ..., noreply: bool = ..., noevict: bool = ..., @@ -910,7 +846,6 @@ def from_url( db: int | None = None, *, decode_responses: bool = False, - protocol_version: Literal[2, 3] = 3, verify_version: bool = True, noreply: bool = False, noevict: bool = False, @@ -938,7 +873,6 @@ def from_url( if decode_responses: return cls( decode_responses=True, - protocol_version=protocol_version, verify_version=verify_version, noreply=noreply, retry_policy=retry_policy, @@ -948,7 +882,6 @@ def from_url( url, db=db, decode_responses=decode_responses, - protocol_version=protocol_version, noreply=noreply, noevict=noevict, notouch=notouch, @@ -958,7 +891,6 @@ def from_url( else: return cls( decode_responses=False, - protocol_version=protocol_version, verify_version=verify_version, noreply=noreply, retry_policy=retry_policy, @@ -968,7 +900,6 @@ def from_url( url, db=db, decode_responses=decode_responses, - protocol_version=protocol_version, noreply=noreply, noevict=noevict, notouch=notouch, @@ -976,12 +907,13 @@ def from_url( ), ) - async def initialize(self) -> Redis[AnyStr]: - if not self.connection_pool.initialized: - await super().initialize() + @contextlib.asynccontextmanager + async def __asynccontextmanager__(self) -> AsyncGenerator[Self]: + async with self.connection_pool: + await self._populate_module_versions() if self.cache: - self.cache = await self.cache.initialize(self) - return self + await self.connection_pool._task_group.start(self.cache.run, self.connection_pool) + yield self async def execute_command( self, @@ -995,7 +927,6 @@ async def execute_command( """ return await self.retry_policy.call_with_retries( lambda: self._execute_command(command, callback=callback, **options), - before_hook=self.initialize, ) async def _execute_command( @@ -1005,84 +936,70 @@ async def _execute_command( **options: Unpack[ExecutionParameters], ) -> R: pool = self.connection_pool - quick_release = self.should_quick_release(command) - connection = await pool.get_connection( - command.name, - *command.arguments, - acquire=not quick_release or self.requires_wait or self.requires_waitaof, - ) - try: - keys = KeySpec.extract_keys(command.name, *command.arguments) - cacheable = ( - command.name in CACHEABLE_COMMANDS - and len(keys) == 1 - and not self.noreply - and self._decodecontext.get() is None - ) - cached_reply = None - cache_hit = False - use_cached = False - reply = None - if self.cache: - if connection.tracking_client_id != self.cache.get_client_id(connection): - self.cache.reset() - await connection.update_tracking_client( - True, self.cache.get_client_id(connection) - ) - if command.name not in READONLY_COMMANDS: - self.cache.invalidate(*keys) - elif cacheable: - try: - cached_reply = cast( - R, - self.cache.get( - command.name, - keys[0], - *command.arguments, - ), - ) - use_cached = random.random() * 100.0 < min(100.0, self.cache.confidence) - cache_hit = True - except KeyError: - pass - if not (use_cached and cached_reply): - request = await connection.create_request( - command.name, - *command.arguments, - noreply=self.noreply, - decode=options.get("decode", self._decodecontext.get()), - encoding=self._encodingcontext.get(), + async with pool.acquire() as connection: + try: + keys = KeySpec.extract_keys(command.name, *command.arguments) + cacheable = ( + command.name in CACHEABLE_COMMANDS + and len(keys) == 1 + and not self.noreply + and self._decodecontext.get() is None ) - maybe_wait = [ - await self._ensure_wait(command, connection), - await self._ensure_persistence(command, connection), - ] - reply = await request - await asyncio.gather(*maybe_wait) - if self.noreply: - return None # type: ignore - if isinstance(callback, AsyncPreProcessingCallback): - await callback.pre_process(self, reply) - if self.cache and cacheable: - if cache_hit and not use_cached: - self.cache.feedback( - command.name, keys[0], *command.arguments, match=cached_reply == reply - ) - if not cache_hit: - self.cache.put( + cached_reply = None + cache_hit = False + use_cached = False + reply = None + if self.cache: + if connection.tracking_client_id != self.cache.get_client_id(connection): + self.cache.reset() + await connection.update_tracking_client( + True, self.cache.get_client_id(connection) + ) + if command.name not in READONLY_COMMANDS: + self.cache.invalidate(*keys) + elif cacheable: + try: + cached_reply = cast( + R, + self.cache.get( + command.name, + keys[0], + *command.arguments, + ), + ) + use_cached = random.random() * 100.0 < min(100.0, self.cache.confidence) + cache_hit = True + except KeyError: + pass + if not (use_cached and cached_reply): + request = await connection.create_request( command.name, - keys[0], *command.arguments, - value=reply, + noreply=self.noreply, + decode=options.get("decode", self._decodecontext.get()), + encoding=self._encodingcontext.get(), ) - return callback(cached_reply if cache_hit else reply, version=self.protocol_version) - except RedisError: - connection.disconnect() - raise - finally: - self._ensure_server_version(connection.server_version) - if not quick_release or self.requires_wait or self.requires_waitaof: - pool.release(connection) + reply = await request + await self._ensure_wait_and_persist(command, connection) + if self.noreply: + return None # type: ignore + if isinstance(callback, AsyncPreProcessingCallback): + await callback.pre_process(self, reply) + if self.cache and cacheable: + if cache_hit and not use_cached: + self.cache.feedback( + command.name, keys[0], *command.arguments, match=cached_reply == reply + ) + if not cache_hit: + self.cache.put( + command.name, + keys[0], + *command.arguments, + value=reply, + ) + return callback(cached_reply if cache_hit else reply) + finally: + self._ensure_server_version(connection.server_version) @overload def decoding( @@ -1127,25 +1044,6 @@ def decoding(self, mode: bool, encoding: str | None = None) -> Iterator[Redis[An self._decodecontext.set(prev_decode) self._encodingcontext.set(prev_encoding) - @deprecated("The implementation of a monitor will be removed in 6.0", version="5.2.0") - def monitor( - self, - response_handler: Callable[[MonitorResult], None] | None = None, - ) -> Monitor[AnyStr]: - """ - :param response_handler: Optional callback to be triggered whenever - a command is received by this monitor. - - Return an instance of a :class:`~coredis.commands.monitor.Monitor` - - The monitor can be used as an async iterator or individual commands - can be fetched via :meth:`~coredis.commands.monitor.Monitor.get_command`. - When a :paramref:`response_handler` is provided it will simply by called - for every command received. - - """ - return Monitor[AnyStr](self, response_handler) - def pubsub( self, ignore_subscribe_messages: bool = False, @@ -1187,10 +1085,11 @@ def pubsub( **kwargs, ) - async def pipeline( + def pipeline( self, - transaction: bool | None = True, - watches: Parameters[KeyT] | None = None, + transaction: bool = True, + *, + raise_on_error: bool = True, timeout: float | None = None, ) -> coredis.pipeline.Pipeline[AnyStr]: """ @@ -1198,23 +1097,57 @@ async def pipeline( batch execution. :param transaction: indicates whether all commands should be executed atomically. - :param watches: If :paramref:`transaction` is True these keys are watched for external - changes during the transaction. + :param raise_on_error: Whether to raise errors upon executing the pipeline. + If set to `False` errors will be accumulated and retrievable from the individual + commands that had errors. :param timeout: If specified this value will take precedence over :paramref:`Redis.stream_timeout` """ from coredis.pipeline import Pipeline - return Pipeline[AnyStr](self, transaction, watches, timeout) + return Pipeline[AnyStr](self, transaction, raise_on_error, timeout) + + def lock( + self, + name: StringT, + timeout: float | None = None, + sleep: float = 0.1, + blocking: bool = True, + blocking_timeout: float | None = None, + ) -> Lock[AnyStr]: + """ + Return a lock instance which can be used to guard resource access across + multiple machines. + + :param name: key for the lock + :param timeout: indicates a maximum life for the lock. + By default, it will remain locked until :meth:`release` is called. + ``timeout`` can be specified as a float or integer, both representing + the number of seconds to wait. + + :param sleep: indicates the amount of time to sleep per loop iteration + when the lock is in blocking mode and another client is currently + holding the lock. + + :param blocking: indicates whether calling :meth:`acquire` should block until + the lock has been acquired or to fail immediately, causing :meth:`acquire` + to return ``False`` and the lock not being acquired. Defaults to ``True``. + + :param blocking_timeout: indicates the maximum amount of time in seconds to + spend trying to acquire the lock. A value of ``None`` indicates + continue trying forever. ``blocking_timeout`` can be specified as a + :class:`float` or :class:`int`, both representing the number of seconds to wait. + """ + from coredis.recipes import Lock + + return Lock(self, name, timeout, sleep, blocking, blocking_timeout) async def transaction( self, - func: Callable[[coredis.pipeline.Pipeline[AnyStr]], Coroutine[Any, Any, Any]], + func: Callable[[coredis.pipeline.Pipeline[AnyStr]], Coroutine[Any, Any, R]], *watches: KeyT, - value_from_callable: bool = False, watch_delay: float | None = None, - **kwargs: Any, - ) -> Any | None: + ) -> R: """ Convenience method for executing the callable :paramref:`func` as a transaction while watching all keys specified in :paramref:`watches`. @@ -1223,18 +1156,14 @@ async def transaction( :class:`coredis.pipeline.Pipeline` object retrieved by calling :meth:`~coredis.Redis.pipeline`. :param watches: The keys to watch during the transaction - :param value_from_callable: Whether to return the result of transaction or the value - returned from :paramref:`func` + :param watch_delay: Time in seconds to wait after each watch error before retrying """ - async with await self.pipeline(True) as pipe: - while True: - try: + msg = "Caught WatchError in transaction, retrying..." + while True: + with catch({WatchError: lambda _: logger.warning(msg)}): + async with self.pipeline(transaction=False) as pipe: if watches: await pipe.watch(*watches) - func_value = await func(pipe) - exec_value = await pipe.execute() - return func_value if value_from_callable else exec_value - except WatchError: - if watch_delay is not None and watch_delay > 0: - await asyncio.sleep(watch_delay) - continue + return await func(pipe) + if watch_delay: + await sleep(watch_delay) diff --git a/coredis/client/cluster.py b/coredis/client/cluster.py index c4bad6851..e6de1c377 100644 --- a/coredis/client/cluster.py +++ b/coredis/client/cluster.py @@ -1,6 +1,5 @@ from __future__ import annotations -import asyncio import contextlib import contextvars import functools @@ -11,10 +10,12 @@ from ssl import SSLContext from typing import TYPE_CHECKING, Any, cast, overload +from anyio import get_cancelled_exc_class, sleep from deprecated.sphinx import versionadded +from coredis._concurrency import gather from coredis._utils import b, hash_slot -from coredis.cache import AbstractCache +from coredis.cache import AbstractCache, ClusterTrackingCache from coredis.client.basic import Client, Redis from coredis.commands._key_spec import KeySpec from coredis.commands.constants import CommandName, NodeFlag @@ -30,7 +31,6 @@ RedisClusterException, TimeoutError, TryAgainError, - WatchError, ) from coredis.globals import CACHEABLE_COMMANDS, MODULE_GROUPS, READONLY_COMMANDS from coredis.pool import ClusterConnectionPool @@ -39,6 +39,7 @@ from coredis.retry import CompositeRetryPolicy, ConstantRetryPolicy, RetryPolicy from coredis.typing import ( AnyStr, + AsyncGenerator, AsyncIterator, Awaitable, Callable, @@ -55,6 +56,7 @@ RedisCommandP, RedisValueT, ResponseType, + Self, StringT, TypeAdapter, TypeVar, @@ -201,7 +203,6 @@ def __init__( decode_responses: Literal[False] = ..., connection_pool: ClusterConnectionPool | None = ..., connection_pool_cls: type[ClusterConnectionPool] = ..., - protocol_version: Literal[2, 3] = ..., verify_version: bool = ..., non_atomic_cross_slot: bool = ..., cache: AbstractCache | None = ..., @@ -240,7 +241,6 @@ def __init__( decode_responses: Literal[True] = ..., connection_pool: ClusterConnectionPool | None = ..., connection_pool_cls: type[ClusterConnectionPool] = ..., - protocol_version: Literal[2, 3] = ..., verify_version: bool = ..., non_atomic_cross_slot: bool = ..., cache: AbstractCache | None = ..., @@ -278,7 +278,6 @@ def __init__( decode_responses: bool = False, connection_pool: ClusterConnectionPool | None = None, connection_pool_cls: type[ClusterConnectionPool] = ClusterConnectionPool, - protocol_version: Literal[2, 3] = 3, verify_version: bool = True, non_atomic_cross_slot: bool = True, cache: AbstractCache | None = None, @@ -302,6 +301,11 @@ def __init__( """ Changes + - .. versionremoved:: 6.0.0 + - :paramref:`protocol_version` removed (and therefore support for RESP2) + + - .. versionadded:: 6.0.0 + - - .. versionadded:: 4.12.0 - :paramref:`retry_policy` @@ -410,9 +414,6 @@ def __init__( a new pool will be assigned to this client. :param connection_pool_cls: The connection pool class to use when constructing a connection pool for this instance. - :param protocol_version: Whether to use the RESP (``2``) or RESP3 (``3``) - protocol for parsing responses from the server (Default ``3``). - (See :ref:`handbook/response:redis response`) :param verify_version: Validate redis server version against the documented version introduced before executing a command and raises a :exc:`CommandNotSupportedError` error if the required version is higher than @@ -474,7 +475,6 @@ def __init__( read_from_replicas=readonly or read_from_replicas, encoding=encoding, decode_responses=decode_responses, - protocol_version=protocol_version, noreply=noreply, noevict=noevict, notouch=notouch, @@ -491,7 +491,6 @@ def __init__( encoding=encoding, decode_responses=decode_responses, verify_version=verify_version, - protocol_version=protocol_version, noreply=noreply, noevict=noevict, notouch=notouch, @@ -507,7 +506,7 @@ def __init__( self.__class__.RESULT_CALLBACKS.copy() ) self.non_atomic_cross_slot = non_atomic_cross_slot - self.cache = cache + self.cache = ClusterTrackingCache(cache=cache) if cache else None self._decodecontext: contextvars.ContextVar[bool | None,] = contextvars.ContextVar( "decode", default=None ) @@ -524,7 +523,6 @@ def from_url( db: int | None = ..., skip_full_coverage_check: bool = ..., decode_responses: Literal[False] = ..., - protocol_version: Literal[2, 3] = ..., verify_version: bool = ..., noreply: bool = ..., noevict: bool = ..., @@ -544,7 +542,6 @@ def from_url( db: int | None = ..., skip_full_coverage_check: bool = ..., decode_responses: Literal[True], - protocol_version: Literal[2, 3] = ..., verify_version: bool = ..., noreply: bool = ..., noevict: bool = ..., @@ -563,7 +560,6 @@ def from_url( db: int | None = None, skip_full_coverage_check: bool = False, decode_responses: bool = False, - protocol_version: Literal[2, 3] = 3, verify_version: bool = True, noreply: bool = False, noevict: bool = False, @@ -599,7 +595,6 @@ def from_url( if decode_responses: return cls( decode_responses=True, - protocol_version=protocol_version, verify_version=verify_version, noreply=noreply, retry_policy=retry_policy, @@ -610,7 +605,6 @@ def from_url( db=db, skip_full_coverage_check=skip_full_coverage_check, decode_responses=decode_responses, - protocol_version=protocol_version, noreply=noreply, noevict=noevict, notouch=notouch, @@ -620,7 +614,6 @@ def from_url( else: return cls( decode_responses=False, - protocol_version=protocol_version, verify_version=verify_version, noreply=noreply, retry_policy=retry_policy, @@ -631,7 +624,6 @@ def from_url( db=db, skip_full_coverage_check=skip_full_coverage_check, decode_responses=decode_responses, - protocol_version=protocol_version, noreply=noreply, noevict=noevict, notouch=notouch, @@ -639,15 +631,16 @@ def from_url( ), ) - async def initialize(self) -> RedisCluster[AnyStr]: + @contextlib.asynccontextmanager + async def __asynccontextmanager__(self) -> AsyncGenerator[Self]: if self.refresh_table_asap: self.connection_pool.initialized = False - await self.connection_pool.initialize() - self.refresh_table_asap = False - await self._populate_module_versions() - if self.cache: - self.cache = await self.cache.initialize(self) - return self + async with self.connection_pool: + self.refresh_table_asap = False + await self._populate_module_versions() + if self.cache: + await self.connection_pool._task_group.start(self.cache.run, self.connection_pool) + yield self def __repr__(self) -> str: servers = list( @@ -694,7 +687,7 @@ def num_replicas_per_shard(self) -> int: async def _ensure_initialized(self) -> None: if not self.connection_pool.initialized or self.refresh_table_asap: - await self + await self.connection_pool.initialize() def _determine_slots( self, command: bytes, *args: RedisValueT, **options: Unpack[ExecutionParameters] @@ -729,7 +722,7 @@ def _merge_result( assert command in self.result_callbacks return cast( R, - self.result_callbacks[command](res, version=self.protocol_version, **kwargs), + self.result_callbacks[command](res, **kwargs), ) def determine_node( @@ -758,12 +751,10 @@ def determine_node( return None async def on_connection_error(self, _: BaseException) -> None: - self.connection_pool.disconnect() self.connection_pool.reset() self.refresh_table_asap = True async def on_cluster_down_error(self, _: BaseException) -> None: - self.connection_pool.disconnect() self.connection_pool.reset() self.refresh_table_asap = True @@ -816,13 +807,10 @@ async def _execute_command( **kwargs, ) - results = await asyncio.gather(*tasks.values(), return_exceptions=True) + results = await gather(*tasks.values(), return_exceptions=True) if self.noreply: return None # type: ignore - return cast( - R, - self._merge_result(command.name, dict(zip(tasks.keys(), results))), - ) + return self._merge_result(command.name, dict(zip(tasks.keys(), results))) else: node = None slots = None @@ -907,7 +895,8 @@ async def _execute_command_on_single_node( while remaining_attempts > 0: remaining_attempts -= 1 if self.refresh_table_asap and not slots: - await self + # await self + pass if asking and redirect_addr: node = self.connection_pool.nodes.nodes[redirect_addr] r = await self.connection_pool.get_connection_by_node(node) @@ -983,11 +972,7 @@ async def _execute_command_on_single_node( self.connection_pool.release(r) reply = await request - maybe_wait = [ - await self._ensure_wait(command, r), - await self._ensure_persistence(command, r), - ] - await asyncio.gather(*maybe_wait) + await self._ensure_wait_and_persist(command, r) if self.noreply: return # type: ignore else: @@ -998,7 +983,6 @@ async def _execute_command_on_single_node( ) response = callback( cached_reply if cache_hit else reply, - version=self.protocol_version, ) if self.cache and cacheable: if cache_hit and not use_cached: @@ -1016,7 +1000,7 @@ async def _execute_command_on_single_node( value=reply, ) return response - except (RedisClusterException, BusyLoadingError, asyncio.CancelledError): + except (RedisClusterException, BusyLoadingError, get_cancelled_exc_class()): raise except MovedError as e: # Reinitialize on ever x number of MovedError. @@ -1031,7 +1015,7 @@ async def _execute_command_on_single_node( self.connection_pool.nodes.slots[e.slot_id][0] = node except TryAgainError: if remaining_attempts < self.MAX_RETRIES / 2: - await asyncio.sleep(0.05) + await sleep(0.05) except AskError as e: redirect_addr, asking = f"{e.host}:{e.port}", True finally: @@ -1167,9 +1151,11 @@ def sharded_pubsub( **kwargs, ) - async def pipeline( + def pipeline( self, - transaction: bool | None = None, + transaction: bool = False, + *, + raise_on_error: bool = True, watches: Parameters[StringT] | None = None, timeout: float | None = None, ) -> coredis.pipeline.ClusterPipeline[AnyStr]: @@ -1186,70 +1172,26 @@ async def pipeline( part of the pipeline. :param transaction: indicates whether all commands should be executed atomically. + :param raise_on_error: Whether to raise errors upon executing the pipeline. + If set to `False` errors will be accumulated and retrievable from the individual + commands that had errors. :param watches: If :paramref:`transaction` is True these keys are watched for external changes during the transaction. :param timeout: If specified this value will take precedence over :paramref:`RedisCluster.stream_timeout` """ - await self.connection_pool.initialize() from coredis.pipeline import ClusterPipeline return ClusterPipeline[AnyStr]( client=self, + raise_on_error=raise_on_error, transaction=transaction, watches=watches, timeout=timeout, ) - async def transaction( - self, - func: Callable[ - [coredis.pipeline.ClusterPipeline[AnyStr]], - Coroutine[Any, Any, Any], - ], - *watches: StringT, - value_from_callable: bool = False, - watch_delay: float | None = None, - **kwargs: Any, - ) -> Any: - """ - Convenience method for executing the callable :paramref:`func` as a - transaction while watching all keys specified in :paramref:`watches`. - - :param func: callable should expect a single argument which is a - :class:`coredis.pipeline.ClusterPipeline` object retrieved by calling - :meth:`~coredis.RedisCluster.pipeline`. - :param watches: The keys to watch during the transaction. The keys should route - to the same node as the keys touched by the commands in :paramref:`func` - :param value_from_callable: Whether to return the result of transaction or the value - returned from :paramref:`func` - - .. warning:: Cluster transactions can only be run with commands that - route to the same slot. - - .. versionchanged:: 4.9.0 - - When the transaction is started with :paramref:`watches` the - :class:`~coredis.pipeline.ClusterPipeline` instance passed to :paramref:`func` - will not start queuing commands until a call to - :meth:`~coredis.pipeline.ClusterPipeline.multi` is made. This makes the cluster - implementation consistent with :meth:`coredis.Redis.transaction` - """ - async with await self.pipeline(True) as pipe: - while True: - try: - if watches: - await pipe.watch(*watches) - func_value = await func(pipe) - exec_value = await pipe.execute() - return func_value if value_from_callable else exec_value - except WatchError: - if watch_delay is not None and watch_delay > 0: - await asyncio.sleep(watch_delay) - continue - async def scan_iter( self, match: StringT | None = None, @@ -1258,8 +1200,9 @@ async def scan_iter( ) -> AsyncIterator[AnyStr]: await self._ensure_initialized() for node in self.primaries: - cursor = None - while cursor != 0: - cursor, data = await node.scan(cursor or 0, match, count, type_) - for item in data: - yield item + async with node: + cursor = None + while cursor != 0: + cursor, data = await node.scan(cursor or 0, match, count, type_) + for item in data: + yield item diff --git a/coredis/commands/__init__.py b/coredis/commands/__init__.py index 2fa98d810..36370152c 100644 --- a/coredis/commands/__init__.py +++ b/coredis/commands/__init__.py @@ -24,8 +24,7 @@ # Command wrappers from .bitfield import BitFieldOperation -from .function import Function, Library -from .monitor import Monitor +from .function import Function, Library, wraps from .pubsub import ClusterPubSub, PubSub, ShardedPubSub from .request import CommandRequest, CommandResponseT from .script import Script @@ -57,8 +56,8 @@ def create_request( "ClusterPubSub", "Function", "Library", - "Monitor", "PubSub", "Script", "ShardedPubSub", + "wraps", ] diff --git a/coredis/commands/core.py b/coredis/commands/core.py index f53f9fa05..dce988027 100644 --- a/coredis/commands/core.py +++ b/coredis/commands/core.py @@ -2,6 +2,7 @@ import datetime import itertools +from collections.abc import Callable from typing import overload from deprecated.sphinx import versionadded @@ -157,6 +158,7 @@ ResponsePrimitive, ResponseType, StringT, + T_co, ValueT, ) @@ -6167,16 +6169,15 @@ def _evalsha( sha1: StringT, keys: Parameters[KeyT] | None = None, args: Parameters[ValueT] | None = None, - ) -> CommandRequest[ResponseType]: + callback: Callable[..., T_co] = NoopCallback(), + ) -> CommandRequest[T_co]: _keys: list[KeyT] = list(keys) if keys else [] command_arguments: CommandArgList = [sha1, len(_keys), *_keys] if args: command_arguments.extend(args) - return self.create_request( - command, *command_arguments, callback=NoopCallback[ResponseType]() - ) + return self.create_request(command, *command_arguments, callback=callback) @redis_command(CommandName.EVALSHA, group=CommandGroup.SCRIPTING) def evalsha( @@ -6184,7 +6185,8 @@ def evalsha( sha1: StringT, keys: Parameters[KeyT] | None = None, args: Parameters[ValueT] | None = None, - ) -> CommandRequest[ResponseType]: + callback: Callable[..., T_co] = NoopCallback(), + ) -> CommandRequest[T_co]: """ Execute the Lua script cached by it's :paramref:`sha` ref with the key names and argument values in :paramref:`keys` and :paramref:`args`. @@ -6193,7 +6195,7 @@ def evalsha( :return: The result of the script as redis returns it """ - return self._evalsha(CommandName.EVALSHA, sha1, keys, args) + return self._evalsha(CommandName.EVALSHA, sha1, keys, args, callback=callback) @versionadded(version="3.0.0") @redis_command( @@ -6207,7 +6209,8 @@ def evalsha_ro( sha1: StringT, keys: Parameters[KeyT] | None = None, args: Parameters[ValueT] | None = None, - ) -> CommandRequest[ResponseType]: + callback: Callable[..., T_co] = NoopCallback(), + ) -> CommandRequest[T_co]: """ Read-only variant of :meth:`~Redis.evalsha` that cannot execute commands that modify data. @@ -6215,7 +6218,7 @@ def evalsha_ro( :return: The result of the script as redis returns it """ - return self._evalsha(CommandName.EVALSHA_RO, sha1, keys, args) + return self._evalsha(CommandName.EVALSHA_RO, sha1, keys, args, callback=callback) @versionadded(version="3.0.0") @redis_command( @@ -6308,7 +6311,6 @@ def script_load(self, script: StringT) -> CommandRequest[AnyStr]: :return: The SHA1 digest of the script added into the script cache """ - return self.create_request( CommandName.SCRIPT_LOAD, script, callback=AnyStrCallback[AnyStr]() ) @@ -6324,7 +6326,8 @@ def fcall( function: StringT, keys: Parameters[KeyT] | None = None, args: Parameters[ValueT] | None = None, - ) -> CommandRequest[ResponseType]: + callback: Callable[..., T_co] = NoopCallback(), + ) -> CommandRequest[T_co]: """ Invoke a function """ @@ -6336,9 +6339,7 @@ def fcall( *(args or []), ] - return self.create_request( - CommandName.FCALL, *command_arguments, callback=NoopCallback[ResponseType]() - ) + return self.create_request(CommandName.FCALL, *command_arguments, callback=callback) @versionadded(version="3.1.0") @redis_command( @@ -6352,7 +6353,8 @@ def fcall_ro( function: StringT, keys: Parameters[KeyT] | None = None, args: Parameters[ValueT] | None = None, - ) -> CommandRequest[ResponseType]: + callback: Callable[..., T_co] = NoopCallback(), + ) -> CommandRequest[T_co]: """ Read-only variant of :meth:`~coredis.Redis.fcall` """ @@ -6364,11 +6366,7 @@ def fcall_ro( *(args or []), ] - return self.create_request( - CommandName.FCALL_RO, - *command_arguments, - callback=NoopCallback[ResponseType](), - ) + return self.create_request(CommandName.FCALL_RO, *command_arguments, callback=callback) @versionadded(version="3.1.0") @redis_command( diff --git a/coredis/commands/function.py b/coredis/commands/function.py index 1b0e75dc8..b44f7fe28 100644 --- a/coredis/commands/function.py +++ b/coredis/commands/function.py @@ -4,17 +4,17 @@ import inspect import itertools import weakref -from typing import Any, ClassVar, cast +from typing import Any, ClassVar, cast, get_args, overload from deprecated.sphinx import versionadded from coredis._utils import EncodingInsensitiveDict, nativestr from coredis.commands.request import CommandRequest from coredis.exceptions import FunctionError +from coredis.response._callbacks import NoopCallback from coredis.typing import ( TYPE_CHECKING, AnyStr, - Awaitable, Callable, Generator, Generic, @@ -23,8 +23,8 @@ P, Parameters, R, - ResponseType, StringT, + T_co, TypeVar, ValueT, add_runtime_checks, @@ -139,206 +139,221 @@ def __await__(self: LibraryT) -> Generator[Any, None, LibraryT]: def __getitem__(self, function: str) -> Function[AnyStr] | None: return cast(Function[AnyStr] | None, self._functions.get(function)) - @classmethod - @versionadded(version="3.5.0") - def wraps( - cls, - function_name: str, - key_spec: list[KeyT] | None = None, - param_is_key: Callable[[inspect.Parameter], bool] = lambda p: ( - p.annotation in {"KeyT", KeyT} - ), - runtime_checks: bool = False, - readonly: bool | None = None, - ) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, CommandRequest[R]]]: - """ - Decorator for wrapping methods of subclasses of :class:`Library` - as entry points to the functions contained in the library. This allows - exposing a strict signature instead of that which :meth:`Function.__call__` - provides. The callable being decorated should **not** have an implementation as - it will never be called. - - The main objective of the decorator is to allow you to represent a lua library of - functions as a python class having strict (and type safe) methods as entry points. - Internally the decorator separates ``keys`` from ``args`` before calling - :meth:`coredis.Redis.fcall`. - - Mapping the decorated method's arguments to key providers is done either by - using :paramref:`key_spec` or :paramref:`param_is_key`. All other parameters of the - decorated method are assumed to be ``args`` consumed by the lua function. - - - The following example demonstrates most of the functionality provided by the - decorator:: - - import coredis - from coredis.commands import Library - from coredis.typing import KeyT, RedisValueT - from typing import List - - class MyAwesomeLibrary(Library): - NAME = "mylib" - CODE = \"\"\" - #!lua name=mylib - - redis.register_function('echo', function(k, a) - return a[1] - end) - redis.register_function('ping', function() - return "PONG" - end) - redis.register_function('get', function(k, a) - return redis.call("GET", k[1]) - end) - redis.register_function('hmget', function(k, a) - local values = {} - local fields = {} - local response = {} - local i = 1 - local j = 1 - - while a[i] do - fields[j] = a[i] - i = i + 2 - j = j + 1 - end - - for idx, key in ipairs(k) do - values = redis.call("HMGET", key, unpack(fields)) - for idx, value in ipairs(values) do - if not response[idx] and value then - response[idx] = value - end - end - end - for idx, value in ipairs(fields) do - if not response[idx] then - response[idx] = a[idx*2] - end - end - return response - end) - \"\"\" - - @Library.wraps("echo") - def echo(self, value: ValueT) -> CommandRequest[RedisValueT]: ... - - @Library.wraps("ping"print(c) - ) - def ping(self) -> CommandRequest[str]: ... - - @Library.wraps("get") - def get(self, key: KeyT) -> CommandRequest[ValueT]: ... - - @Library.wraps("hmmget") - def hmmget(self, *keys: KeyT, **fields_with_values: RedisValueT): - \"\"\" - Return values of ``fields_with_values`` on a first come first serve - basis from the hashes at ``keys``. Since ``fields_with_values`` is a mapping - the keys are mapped to hash fields and the values are used - as defaults if they are not found in any of the hashes at ``keys`` - \"\"\" - ... - - client = coredis.Redis() + +@overload +def wraps( + callback: None = None, + runtime_checks: bool = ..., + readonly: bool = ..., +) -> Callable[[Callable[P, R]], Callable[P, CommandRequest[R]]]: ... + + +@overload +def wraps( + callback: Callable[..., T_co], + runtime_checks: bool = ..., + readonly: bool = ..., +) -> Callable[[Callable[P, Any]], Callable[P, CommandRequest[T_co]]]: ... + + +@versionadded(version="3.5.0") +def wraps( + callback: Callable[..., T_co] | None = None, + runtime_checks: bool = False, + readonly: bool = False, +) -> Callable[[Callable[P, Any]], Callable[P, CommandRequest[Any]]]: + """ + Decorator for wrapping methods of subclasses of :class:`Library` + as entry points to the functions contained in the library. This allows + exposing a strict signature instead of that which :meth:`Function.__call__` + provides. The callable being decorated should **not** have an implementation as + it will never be called. The name of the function decorated must match the foreign + (Lua) function's name. + + The main objective of the decorator is to allow you to represent a lua library of + functions as a python class having strict (and type safe) methods as entry points. + Internally the decorator separates ``keys`` from ``args`` before calling + :meth:`coredis.Redis.fcall`. + + Mapping the decorated method's arguments to key providers is done by type + annotations: all parameters annotated as `KeyT` will be passed as keys, and the + rest will be passed as arguments. + + The following example demonstrates most of the functionality provided by the + decorator:: + + import coredis + from coredis.commands import Library, wraps + from coredis.typing import KeyT, ValueT + + class MyAwesomeLibrary(Library): + NAME = "mylib" + CODE = \"\"\" + #!lua name=mylib + + redis.register_function('echo', function(k, a) + return a[1] + end) + redis.register_function('ping', function() + return "PONG" + end) + redis.register_function { + function_name = 'get', + callback = function(k, a) + return redis.call("GET", k[1]) + end, + flags = { 'no-writes' } -- mark as read-only + } + redis.register_function('hmmget', function(k, a) + local values = {} + local fields = {} + local response = {} + local i = 1 + local j = 1 + + while a[i] do + fields[j] = a[i] + i = i + 2 + j = j + 1 + end + + for idx, key in ipairs(k) do + values = redis.call("HMGET", key, unpack(fields)) + for idx, value in ipairs(values) do + if not response[idx] and value then + response[idx] = value + end + end + end + for idx, value in ipairs(fields) do + if not response[idx] then + response[idx] = a[idx*2] + end + end + return response + end) + \"\"\" + + @wraps() + def echo(self, value: ValueT) -> ValueT: ... + + @wraps() + def ping(self) -> bytes: ... + + @wraps(readonly=True) + def get(self, key: KeyT) -> ValueT: ... + + @wraps() + def hmmget(self, *keys: KeyT, **fields_with_values: int) -> list[ValueT]: ... + + client = coredis.Redis() + async with client: lib = await MyAwesomeLibrary(client, replace=True) await client.set("hello", "world") # True await lib.echo("hello world") # b"hello world" await lib.ping() - # b"pong" + # b"PONG" await lib.get("hello") - # b"hello" - await client.hset("k1", {"c": 3, "d": 4}) - await client.hset("k2", {"a": 1, "b": 2}) - await lib.hmmget("k1", "k2", a=-1, b=-2, c=-3, d=-4, e=-5) - # [b"1", b"2", b"3", b"4", b"-5"] + # b"world" - :param key_spec: list of parameters of the decorated method that will - be passed as the :paramref:`keys` argument to :meth:`__call__`. If provided - this parameter takes precedence over using :paramref:`param_is_key` to - determine if a parameter is a key provider. - :param param_is_key: a callable that accepts a single argument of type - :class:`inspect.Parameter` and returns ``True`` if the parameter points to a key - that should be appended to the :paramref:`__call__.keys` argument of - :meth:`__call__`. The default implementation marks a parameter as a key - provider if it is of type :data:`coredis.typing.KeyT` and is only used - if :paramref:`key_spec` is ``None``. - :param runtime_checks: Whether to enable runtime type checking of input arguments - and return values. (requires :pypi:`beartype`). If :data:`False` the function will - still get runtime type checking if the environment configuration ``COREDIS_RUNTIME_CHECKS`` - is set - for details see :ref:`handbook/typing:runtime type checking`. - :param readonly: If ``True`` forces this function to use :meth:`coredis.Redis.fcall_ro` - - :return: A function that has a signature mirroring the decorated function. - """ + async with client.pipeline(transaction=False) as pipe: + pipe.hset("k1", {"c": 3, "d": 4}) + pipe.hset("k2", {"a": 1, "b": 2}) + res = MyAwesomeLibrary(pipe).hmmget("k1", "k2", a=-1, b=-2, c=-3, d=-4, e=-5) + print(await res) + # [b"1", b"2", b"3", b"4", b"-5"] - def wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P, CommandRequest[R]]: - sig = inspect.signature(func) - first_arg: str = list(sig.parameters.keys())[0] - runtime_check_wrapper = add_runtime_checks if not runtime_checks else safe_beartype - key_params = ( - key_spec if key_spec else [n for n, p in sig.parameters.items() if param_is_key(p)] + :param callback: a custom callback to execute on the returned value. When provided, + the callback's type will be inferred as the return type instead of the type from + the stub. + :param runtime_checks: Whether to enable runtime type checking of input arguments + and return values. (requires :pypi:`beartype`). If :data:`False` the function will + still get runtime type checking if the environment configuration ``COREDIS_RUNTIME_CHECKS`` + is set - for details see :ref:`handbook/typing:runtime type checking`. + :param readonly: If ``True`` forces this function to use :meth:`coredis.Redis.fcall_ro` + + :return: A function that has a signature mirroring the decorated function. + """ + callback = callback or NoopCallback() + + def wrapper(func: Callable[P, Any]) -> Callable[P, CommandRequest[T_co]]: + sig = inspect.signature(func) + first_arg: str = list(sig.parameters.keys())[0] + runtime_check_wrapper = add_runtime_checks if not runtime_checks else safe_beartype + key_params = [ + n + for n, p in sig.parameters.items() + if p.annotation == "KeyT" or "KeyT" in get_args(p.annotation) + ] + arg_fetch: dict[str, Callable[..., Parameters[Any]]] = { + n: ( + (lambda v: [v]) + if p.kind + in { + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.KEYWORD_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + } + else ( + (lambda v: list(itertools.chain.from_iterable(v.items()))) + if p.kind == inspect.Parameter.VAR_KEYWORD + else lambda v: list(v) + ) ) - arg_fetch: dict[str, Callable[..., Parameters[Any]]] = { - n: ( - (lambda v: [v]) - if p.kind - in { - inspect.Parameter.POSITIONAL_ONLY, - inspect.Parameter.KEYWORD_ONLY, - inspect.Parameter.POSITIONAL_OR_KEYWORD, - } - else ( - (lambda v: list(itertools.chain.from_iterable(v.items()))) - if p.kind == inspect.Parameter.VAR_KEYWORD - else lambda v: list(v) - ) + for n, p in sig.parameters.items() + } + + def split_args( + *a: P.args, **k: P.kwargs + ) -> tuple[Library[AnyStr], Parameters[KeyT], Parameters[ValueT]]: + bound_arguments = sig.bind(*a, **k) + bound_arguments.apply_defaults() + arguments: dict[str, Any] = bound_arguments.arguments + instance = arguments.pop(first_arg) + if not isinstance(instance, Library): + raise RuntimeError( + f"{instance.__class__.__name__} is not a subclass of" + " coredis.commands.function.Library therefore it's methods cannot be bound " + " to a redis library using ``Library.wrap``." + " Please refer to the documentation at https://coredis.readthedocs.org/" + " for instructions on how to bind a class to a redis library." ) - for n, p in sig.parameters.items() - } - - def split_args( - *a: P.args, **k: P.kwargs - ) -> tuple[Library[AnyStr], Parameters[KeyT], Parameters[ValueT]]: - bound_arguments = sig.bind(*a, **k) - bound_arguments.apply_defaults() - arguments: dict[str, Any] = bound_arguments.arguments - instance: Library[AnyStr] = arguments.pop(first_arg) - if not isinstance(instance, Library): - raise RuntimeError( - f"{instance.__class__.__name__} is not a subclass of" - " coredis.commands.function.Library therefore it's methods cannot be bound " - " to a redis library using ``Library.wrap``." - " Please refer to the documentation at https://coredis.readthedocs.org/" - " for instructions on how to bind a class to a redis library." - ) - keys: list[KeyT] = [] - args: list[ValueT] = [] - for name in sig.parameters: - if name == first_arg: - continue - values = arg_fetch[name](arguments[name]) - if name in key_params: - keys.extend(values) - else: - args.extend(values) - return instance, keys, args - - @runtime_check_wrapper - @functools.wraps(func) - def _inner(*args: P.args, **kwargs: P.kwargs) -> CommandRequest[R]: - instance, keys, arguments = split_args(*args, **kwargs) - if (func := instance.functions.get(function_name, None)) is None: + keys: list[KeyT] = [] + args: list[ValueT] = [] + for name in sig.parameters: + if name == first_arg: + continue + values = arg_fetch[name](arguments[name]) + if name in key_params: + keys.extend(values) + else: + args.extend(values) + return instance, keys, args + + @runtime_check_wrapper + @functools.wraps(func) + def _inner(*args: P.args, **kwargs: P.kwargs) -> CommandRequest[T_co]: + instance, keys, arguments = split_args(*args, **kwargs) + if (fn := instance.functions.get(func.__name__, None)) is None: + if not hasattr(instance.client, "clear"): raise AttributeError( - f"Library {instance.name} has no registered function {function_name}" + f"Library {instance.name} has no registered function {func.__name__}" ) - return cast(CommandRequest[R], func(keys, arguments, readonly=readonly)) + # for pipelines, optimistically assume the function is registered + if readonly: + return instance.client.fcall_ro( + func.__name__, keys or [], arguments or [], callback=callback + ) + return instance.client.fcall( + func.__name__, keys or [], arguments or [], callback=callback + ) + return fn(keys, arguments, readonly=readonly, callback=callback) - return _inner + return _inner - return wrapper + return wrapper class Function(Generic[AnyStr]): @@ -373,21 +388,15 @@ def client(self) -> coredis.client.Client[AnyStr]: assert c return c - async def initialize(self) -> Function[AnyStr]: - await self.library - return self - - def __await__(self) -> Generator[Any, None, Function[AnyStr]]: - return self.initialize().__await__() - def __call__( self, keys: Parameters[KeyT] | None = None, args: Parameters[ValueT] | None = None, + callback: Callable[..., T_co] = NoopCallback(), *, client: coredis.client.Client[AnyStr] | None = None, readonly: bool | None = None, - ) -> CommandRequest[ResponseType]: + ) -> CommandRequest[T_co]: """ Wrapper to call :meth:`~coredis.Redis.fcall` with the function named :paramref:`Function.name` registered under @@ -403,6 +412,6 @@ def __call__( readonly = self.readonly if readonly: - return client.fcall_ro(self.name, keys or [], args or []) + return client.fcall_ro(self.name, keys or [], args or [], callback=callback) else: - return client.fcall(self.name, keys or [], args or []) + return client.fcall(self.name, keys or [], args or [], callback=callback) diff --git a/coredis/commands/monitor.py b/coredis/commands/monitor.py deleted file mode 100644 index 244361537..000000000 --- a/coredis/commands/monitor.py +++ /dev/null @@ -1,168 +0,0 @@ -from __future__ import annotations - -import asyncio -from types import TracebackType -from typing import TYPE_CHECKING, Any - -from deprecated.sphinx import deprecated - -from coredis.commands.constants import CommandName -from coredis.exceptions import ConnectionError, RedisError -from coredis.response.types import MonitorResult -from coredis.typing import AnyStr, Callable, Generator, Generic, Self, TypeVar - -if TYPE_CHECKING: - import coredis.client - import coredis.connection - -MonitorT = TypeVar("MonitorT", bound="Monitor[Any]") - - -@deprecated("The implementation of a monitor will be removed in 6.0", version="5.2.0") -class Monitor(Generic[AnyStr]): - """ - Monitor is useful for handling the ``MONITOR`` command to the redis server. - - It can be used as an infinite async iterator:: - - async with client.monitor() as monitor: - async for command in monitor: - print(command.time, command.client_type, command.command, command.args) - - Alternatively, each command can be fetched explicitly:: - - monitor = client.monitor() - command1 = await monitor.get_command() - command2 = await monitor.get_command() - await monitor.aclose() - - If you are only interested in triggering callbacks when a command is received - by the monitor:: - def monitor_handler(result: MonitorResult) -> None: - .... - - monitor = await client.monitor(response_handler=monitor_handler) - # when done - await monitor.aclose() - """ - - def __init__( - self, - client: coredis.client.Client[AnyStr], - response_handler: Callable[[MonitorResult], None] | None = None, - ): - """ - :param client: a Redis client - :param response_handler: optional callback to call whenever a - command is received by the monitor - """ - self.client: coredis.client.Client[AnyStr] = client - self.encoding = client.encoding - self.connection: coredis.connection.Connection | None = None - self.monitoring = False - self._monitor_results: asyncio.Queue[MonitorResult] = asyncio.Queue() - self._monitor_task: asyncio.Task[None] | None = None - self._response_handler = response_handler - - def __aiter__(self) -> Monitor[AnyStr]: - return self - - async def __anext__(self) -> MonitorResult: - """ - Infinite iterator that streams back the next command processed by the - monitored server. - """ - return await self.get_command() - - def __await__(self: MonitorT) -> Generator[Any, None, MonitorT]: - return self.__start_monitor().__await__() - - async def __aenter__(self) -> Self: - await self.__start_monitor() - return self - - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_value: BaseException | None, - traceback: TracebackType | None, - ) -> None: - await self.aclose() - - async def get_command(self) -> MonitorResult: - """ - Wait for the next command issued and return the details - """ - await self.__start_monitor() - return await self._monitor_results.get() - - async def aclose(self) -> None: - """ - Stop monitoring by issuing a ``RESET`` command - and release the connection. - """ - return await self.__stop_monitoring() - - @deprecated("Use :meth:`aclose` instead", version="4.21.0") - async def stop(self) -> None: - """ - Stop monitoring by issuing a ``RESET`` command - and release the connection. - """ - return await self.aclose() - - async def __connect(self) -> None: - if self.connection is None: - self.connection = await self.client.connection_pool.get_connection() - - async def __start_monitor(self: MonitorT) -> MonitorT: - if self.monitoring: - return self - await self.__connect() - assert self.connection - request = await self.connection.create_request(CommandName.MONITOR, decode=False) - response = await request - if not response == b"OK": # noqa - raise RedisError(f"Failed to start MONITOR {response!r}") - if not self._monitor_task or self._monitor_task.done(): - self._monitor_task = asyncio.create_task(self._monitor()) - self.monitoring = True - return self - - async def __stop_monitoring(self) -> None: - if self.connection: - request = await self.connection.create_request(CommandName.RESET, decode=False) - response = await request - if not response == CommandName.RESET: # noqa - raise RedisError("Failed to reset connection") - self.__reset() - - def __reset(self) -> None: - if self.connection: - self.connection.disconnect() - self.client.connection_pool.release(self.connection) - if self._monitor_task and not self._monitor_task.done(): - try: - self._monitor_task.cancel() - except RuntimeError: # noqa - pass - self.monitoring = False - self.connection = None - - async def _monitor(self) -> None: - while self.connection: - try: - response = await self.connection.fetch_push_message(block=True) - if isinstance(response, bytes): - response = response.decode(self.encoding) - assert isinstance(response, str) - result = MonitorResult.parse_response_string(response) - if self._response_handler: - self._response_handler(result) - else: - self._monitor_results.put_nowait(result) - except asyncio.CancelledError: - break - except ConnectionError: - break - self.__reset() diff --git a/coredis/commands/pubsub.py b/coredis/commands/pubsub.py index 4fedf14df..f33b04dc6 100644 --- a/coredis/commands/pubsub.py +++ b/coredis/commands/pubsub.py @@ -1,21 +1,37 @@ from __future__ import annotations -import asyncio import inspect -from asyncio import CancelledError -from contextlib import suppress -from functools import partial -from types import TracebackType -from typing import TYPE_CHECKING, Any, cast - -import async_timeout +from contextlib import asynccontextmanager +from typing import TYPE_CHECKING, Any, AsyncGenerator, cast + +from anyio import ( + TASK_STATUS_IGNORED, + AsyncContextManagerMixin, + ConnectionFailed, + EndOfStream, + Event, + create_memory_object_stream, + create_task_group, + current_time, + fail_after, + move_on_after, + sleep, +) +from anyio.abc import TaskStatus +from anyio.streams.stapled import StapledObjectStream from deprecated.sphinx import versionadded +from exceptiongroup import catch -from coredis._enum import CaseAndEncodingInsensitiveEnum -from coredis._utils import b, hash_slot, nativestr +from coredis._utils import b, hash_slot, logger, nativestr from coredis.commands.constants import CommandName from coredis.connection import BaseConnection, Connection -from coredis.exceptions import ConnectionError, PubSubError, TimeoutError +from coredis.exceptions import RETRYABLE, ConnectionError, PubSubError, TimeoutError +from coredis.parser import ( + PUBLISH_MESSAGE_TYPES, + SUBUNSUB_MESSAGE_TYPES, + UNSUBSCRIBE_MESSAGE_TYPES, + PubSubMessageTypes, +) from coredis.response.types import PubSubMessage from coredis.retry import ( CompositeRetryPolicy, @@ -27,13 +43,11 @@ AnyStr, Awaitable, Callable, - Generator, Generic, Mapping, MutableMapping, Parameters, RedisValueT, - ResponsePrimitive, ResponseType, Self, StringT, @@ -41,48 +55,16 @@ ) if TYPE_CHECKING: - import coredis.client - import coredis.connection import coredis.pool T = TypeVar("T") - - PoolT = TypeVar("PoolT", bound="coredis.pool.ConnectionPool") - #: Callables for message handler callbacks. The callbacks #: can be sync or async. SubscriptionCallback = Callable[[PubSubMessage], Awaitable[None]] | Callable[[PubSubMessage], None] -class PubSubMessageTypes(CaseAndEncodingInsensitiveEnum): - MESSAGE = b"message" - PMESSAGE = b"pmessage" - SMESSAGE = b"smessage" - SUBSCRIBE = b"subscribe" - UNSUBSCRIBE = b"unsubscribe" - PSUBSCRIBE = b"psubscribe" - PUNSUBSCRIBE = b"punsubscribe" - SSUBSCRIBE = b"ssubscribe" - SUNSUBSCRIBE = b"sunsubscribe" - - -class BasePubSub(Generic[AnyStr, PoolT]): - PUBLISH_MESSAGE_TYPES = { - PubSubMessageTypes.MESSAGE.value, - PubSubMessageTypes.PMESSAGE.value, - } - SUBUNSUB_MESSAGE_TYPES = { - PubSubMessageTypes.SUBSCRIBE.value, - PubSubMessageTypes.PSUBSCRIBE.value, - PubSubMessageTypes.UNSUBSCRIBE.value, - PubSubMessageTypes.PUNSUBSCRIBE.value, - } - UNSUBSCRIBE_MESSAGE_TYPES = { - PubSubMessageTypes.UNSUBSCRIBE.value, - PubSubMessageTypes.PUNSUBSCRIBE.value, - } - +class BasePubSub(AsyncContextManagerMixin, Generic[AnyStr, PoolT]): channels: MutableMapping[StringT, SubscriptionCallback | None] patterns: MutableMapping[StringT, SubscriptionCallback | None] @@ -98,11 +80,11 @@ def __init__( channel_handlers: Mapping[StringT, SubscriptionCallback] | None = None, patterns: Parameters[StringT] | None = None, pattern_handlers: Mapping[StringT, SubscriptionCallback] | None = None, + max_buffer_size: int = 1024, ): - self.initialized = False self.connection_pool = connection_pool self.ignore_subscribe_messages = ignore_subscribe_messages - self.connection: coredis.connection.Connection | None = None + self._connection: coredis.BaseConnection | None = None self._retry_policy = retry_policy or NoRetryPolicy() self._initial_channel_subscriptions = { **{nativestr(channel): None for channel in channels or []}, @@ -112,44 +94,97 @@ def __init__( **{nativestr(pattern): None for pattern in patterns or []}, **{nativestr(k): v for k, v in (pattern_handlers or {}).items()}, } - self._message_queue: asyncio.Queue[PubSubMessage | None] = asyncio.Queue() - self._consumer_task: asyncio.Task[None] | None = None - self._subscribed = asyncio.Event() - self.reset() + self._send_stream, self._receive_stream = create_memory_object_stream[PubSubMessage | None]( + max_buffer_size=max_buffer_size + ) + self._subscribed = Event() + self.channels = {} + self.patterns = {} + self.tries = 0 + + @property + def connection(self) -> BaseConnection: + if not self._connection: + raise Exception("Connection not initialized correctly!") + return self._connection @property def subscribed(self) -> bool: """Indicates if there are subscriptions to any channels or patterns""" return bool(self.channels or self.patterns) - async def initialize(self) -> Self: - """ - Ensures the pubsub instance is ready to consume messages - by establishing a connection to the redis server, setting up any - initial channel or pattern subscriptions that were specified during - instantiation and starting the consumer background task. - - The method can be called multiple times without any - risk as it will skip initialization if the consumer is already - initialized. + def __aiter__(self) -> Self: + return self - .. important:: This method doesn't need to be called explicitly - as it will always be called internally before any relevant - documented interaction. + async def __anext__(self) -> PubSubMessage: + while self._subscribed.is_set(): + if message := await self.get_message(): + return message + raise StopAsyncIteration() - :return: the instance itself - """ - if not self.initialized: - self.connection = await self.connection_pool.get_connection() - self.initialized = True + @asynccontextmanager + async def __asynccontextmanager__(self) -> AsyncGenerator[Self]: + # auto-reconnection for long-lived pubsub instances + async with create_task_group() as tg: + await tg.start(self.run) + # initialize subscriptions if self._initial_channel_subscriptions: await self.subscribe(**self._initial_channel_subscriptions) if self._initial_pattern_subscriptions: await self.psubscribe(**self._initial_pattern_subscriptions) - self.connection.register_connect_callback(self.on_connect) - if not self._consumer_task or self._consumer_task.done(): - self._consumer_task = asyncio.create_task(self._consumer()) - return self + yield self + # cleanup + await self.unsubscribe() + await self.punsubscribe() + self.channels.clear() + self.patterns.clear() + self._current_scope.cancel() + + async def run(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED) -> None: + start_time, started, tries = current_time(), False, 0 + + def handle_error(*args: Any) -> None: + nonlocal tries, start_time + if current_time() - start_time > 10: + tries = 0 + else: + tries += 1 + logger.warning("Cache connection lost, retrying...") + + while True: + # retry with exponential backoff + await sleep(min(tries**2, 300)) + with catch({RETRYABLE: handle_error}): + async with self.connection_pool.acquire() as self._connection: + async with create_task_group() as tg: + self._current_scope = tg.cancel_scope + tg.start_soon(self._consumer) + tg.start_soon(self._keepalive) + if not started: + task_status.started() + started = True + else: # resubscribe + if self.channels: + await self.subscribe(*self.channels.keys()) + if self.patterns: + await self.psubscribe(*self.patterns.keys()) + break + + async def _keepalive(self) -> None: + while True: + await self.connection.send_command(CommandName.PING) + await sleep(15) + + async def _consumer(self) -> None: + while True: + if self._subscribed.is_set(): + if response := await self._retry_policy.call_with_retries( + lambda: self.parse_response(block=True), + ): + msg = await self.handle_message(response) + self._send_stream.send_nowait(msg) + else: + await self._subscribed.wait() async def psubscribe( self, @@ -201,9 +236,6 @@ async def subscribe( for channel, handler in channel_handlers.items(): new_channels[self.encode(channel)] = handler await self.execute_command(CommandName.SUBSCRIBE, *new_channels.keys()) - # update the channels dict AFTER we send the command. we don't want to - # subscribe twice to these channels, once for the command and again - # for the reconnection. self.channels.update(new_channels) self._subscribed.set() @@ -229,37 +261,11 @@ async def get_message( on the connection. If the ``None`` the command will block forever. """ - try: - await self.initialize() - async with async_timeout.timeout(timeout): - return self._filter_ignored_messages( - await self._message_queue.get(), ignore_subscribe_messages - ) - except asyncio.TimeoutError: - return None - - async def on_connect(self, connection: BaseConnection) -> None: - """ - Re-subscribe to any channels and patterns previously subscribed to - - :meta private: - """ - - if self.channels: - await self.subscribe( - **{ - k.decode(self.connection_pool.encoding) if isinstance(k, bytes) else k: v - for k, v in self.channels.items() - } - ) - - if self.patterns: - await self.psubscribe( - **{ - k.decode(self.connection_pool.encoding) if isinstance(k, bytes) else k: v - for k, v in self.patterns.items() - } + with move_on_after(timeout): + return self._filter_ignored_messages( + await self._receive_stream.receive(), ignore_subscribe_messages ) + return None def encode(self, value: StringT) -> StringT: """ @@ -278,46 +284,27 @@ def encode(self, value: StringT) -> StringT: async def execute_command( self, command: bytes, *args: RedisValueT, **options: RedisValueT - ) -> ResponseType | None: + ) -> None: """ Executes a publish/subscribe command :meta private: """ - await self.initialize() - - if self.connection is None: - self.connection = await self.connection_pool.get_connection() - self.connection.register_connect_callback(self.on_connect) - assert self.connection - return await self._execute(self.connection, self.connection.send_command, command, *args) + await self.connection.send_command(command, *args) async def parse_response( self, block: bool = True, timeout: float | None = None - ) -> ResponseType: + ) -> list[ResponseType]: """ Parses the response from a publish/subscribe command :meta private: """ - await self.initialize() - - assert self.connection - coro = self._execute( - self.connection, - partial( - self.connection.fetch_push_message, - block=block, - push_message_types=self.SUBUNSUB_MESSAGE_TYPES | self.PUBLISH_MESSAGE_TYPES, - ), - ) - - try: - return await asyncio.wait_for(coro, timeout if (timeout and timeout > 0) else None) - except asyncio.TimeoutError: - return None + timeout = timeout if timeout and timeout > 0 else None + with fail_after(timeout): + return await self.connection.fetch_push_message(block=block) - async def handle_message(self, response: ResponseType) -> PubSubMessage | None: + async def handle_message(self, response: list[ResponseType]) -> PubSubMessage | None: """ Parses a pub/sub message. If the channel or pattern was subscribed to with a message handler, the handler is invoked instead of a parsed @@ -325,49 +312,48 @@ async def handle_message(self, response: ResponseType) -> PubSubMessage | None: :meta private: """ - r = cast(list[ResponsePrimitive], response) - message_type = b(r[0]) - message_type_str = nativestr(r[0]) + message_type = b(response[0]) + message_type_str = nativestr(response[0]) message: PubSubMessage - if message_type in self.SUBUNSUB_MESSAGE_TYPES: + if message_type in SUBUNSUB_MESSAGE_TYPES: message = PubSubMessage( type=message_type_str, - pattern=cast(StringT, r[1]) if message_type[0] == ord(b"p") else None, + pattern=cast(StringT, response[1]) if message_type[0] == ord(b"p") else None, # This field is populated in all cases for backward compatibility # as older versions were incorrectly populating the channel # with the pattern on psubscribe/punsubscribe responses. - channel=cast(StringT, r[1]), - data=cast(int, r[2]), + channel=cast(StringT, response[1]), + data=cast(int, response[2]), ) - elif message_type in self.PUBLISH_MESSAGE_TYPES: + elif message_type in PUBLISH_MESSAGE_TYPES: if message_type == PubSubMessageTypes.PMESSAGE: message = PubSubMessage( type="pmessage", - pattern=cast(StringT, r[1]), - channel=cast(StringT, r[2]), - data=cast(StringT, r[3]), + pattern=cast(StringT, response[1]), + channel=cast(StringT, response[2]), + data=cast(StringT, response[3]), ) else: message = PubSubMessage( type="message", pattern=None, - channel=cast(StringT, r[1]), - data=cast(StringT, r[2]), + channel=cast(StringT, response[1]), + data=cast(StringT, response[2]), ) else: raise PubSubError(f"Unknown message type {message_type_str}") # noqa # if this is an unsubscribe message, remove it from memory - if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES: + if message_type in UNSUBSCRIBE_MESSAGE_TYPES: if message_type == PubSubMessageTypes.PUNSUBSCRIBE: subscribed_dict = self.patterns else: subscribed_dict = self.channels subscribed_dict.pop(message["channel"], None) - if message_type in self.PUBLISH_MESSAGE_TYPES: + if message_type in PUBLISH_MESSAGE_TYPES: handler = None if message_type == PubSubMessageTypes.PMESSAGE and message["pattern"]: handler = self.patterns.get(message["pattern"], None) @@ -380,24 +366,10 @@ async def handle_message(self, response: ResponseType) -> PubSubMessage | None: await handler_response return None if not (self.channels or self.patterns): - self._subscribed.clear() + self._subscribed = Event() return message - async def _consumer(self) -> None: - while self.initialized: - try: - if self.subscribed: - if response := await self._retry_policy.call_with_retries( - lambda: self.parse_response(block=True), - failure_hook=self.reset_connections, - ): - self._message_queue.put_nowait(await self.handle_message(response)) - else: - await self._subscribed.wait() - except ConnectionError: - await asyncio.sleep(0) - def _filter_ignored_messages( self, message: PubSubMessage | None, @@ -405,96 +377,12 @@ def _filter_ignored_messages( ) -> PubSubMessage | None: if ( message - and b(message["type"]) in self.SUBUNSUB_MESSAGE_TYPES + and b(message["type"]) in SUBUNSUB_MESSAGE_TYPES and (self.ignore_subscribe_messages or ignore_subscribe_messages) ): return None return message - async def _execute( - self, - connection: BaseConnection, - command: Callable[..., Awaitable[None]] | Callable[..., Awaitable[ResponseType]], - *args: RedisValueT, - ) -> ResponseType | None: - try: - return await command(*args) - except asyncio.CancelledError: - # do not retry if coroutine is cancelled - if await connection.can_read(): # noqa - connection.disconnect() - raise - - def __await__(self) -> Generator[Any, None, Self]: - return self.initialize().__await__() - - def __aiter__(self) -> Self: - return self - - async def __anext__(self) -> PubSubMessage: - await self.initialize() - while self.subscribed: - if message := await self.get_message(): - return message - else: - continue - raise StopAsyncIteration() - - async def __aenter__(self) -> Self: - await self.initialize() - return self - - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_value: BaseException | None, - traceback: TracebackType | None, - ) -> None: - await self.aclose() - - async def aclose(self) -> None: - """ - Unsubscribe from any channels or patterns & close and return - connections to the pool - """ - if self.connection: - await self.unsubscribe() - await self.punsubscribe() - self.close() - - def close(self) -> None: - self.reset() - - def __del__(self) -> None: - self.reset() - - def reset(self) -> None: - """ - Clear subscriptions and disconnect and release any - connection(s) back to the connection pool. - - :meta private: - """ - if self.connection: - self.connection.disconnect() - self.connection.clear_connect_callbacks() - self.connection_pool.release(self.connection) - self.connection = None - if self._consumer_task: - try: - self._consumer_task.cancel() - except RuntimeError: # noqa - pass - self._consumer_task = None - - self.channels = {} - self.patterns = {} - self.initialized = False - self._subscribed.clear() - - async def reset_connections(self, exc: BaseException | None = None) -> None: - pass - class PubSub(BasePubSub[AnyStr, "coredis.pool.ConnectionPool"]): """ @@ -558,51 +446,36 @@ class ClusterPubSub(BasePubSub[AnyStr, "coredis.pool.ClusterConnectionPool"]): """ - async def execute_command( - self, command: bytes, *args: RedisValueT, **options: RedisValueT - ) -> ResponseType | None: - await self.initialize() - assert self.connection - return await self._execute(self.connection, self.connection.send_command, command, *args) - - async def initialize(self) -> Self: - """ - Ensures the pubsub instance is ready to consume messages - by establishing a connection to a random cluster node, setting up any - initial channel or pattern subscriptions that were specified during - instantiation and starting the consumer background task. - - The method can be called multiple times without any - risk as it will skip initialization if the consumer is already - initialized. - - .. important:: This method doesn't need to be called explicitly - as it will always be called internally before any relevant - documented interaction. + async def run(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED) -> None: + start_time, started, tries = current_time(), False, 0 - :return: the instance itself - """ - if not self.initialized: - if self.connection is None: - await self.reset_connections(None) - self.initialized = True - if self._initial_channel_subscriptions: - await self.subscribe(**self._initial_channel_subscriptions) - if self._initial_pattern_subscriptions: - await self.psubscribe(**self._initial_pattern_subscriptions) - if not self._consumer_task or self._consumer_task.done(): - self._consumer_task = asyncio.create_task(self._consumer()) - return self - - async def reset_connections(self, exc: BaseException | None = None) -> None: - if self.connection: - self.connection.disconnect() - self.connection_pool.initialized = False - - await self.connection_pool.initialize() + def handle_error(*args: Any) -> None: + nonlocal tries, start_time + if current_time() - start_time > 10: + tries = 0 + else: + tries += 1 - self.connection = await self.connection_pool.get_connection(b"pubsub") - self.connection.register_connect_callback(self.on_connect) + while True: + await sleep(min(tries**2, 300)) + with catch({(ConnectionError, ConnectionFailed, EndOfStream): handle_error}): + self._connection = await self.connection_pool.get_connection( + command_name=b"pubsub", acquire=True + ) + async with create_task_group() as tg: + self._current_scope = tg.cancel_scope + tg.start_soon(self._consumer) + tg.start_soon(self._keepalive) + if not started: + task_status.started() + started = True + else: # resubscribe + if self.channels: + await self.subscribe(*self.channels.keys()) + if self.patterns: + await self.psubscribe(*self.patterns.keys()) + self.connection_pool.release(self._connection) + break @versionadded(version="3.6.0") @@ -626,16 +499,6 @@ class ShardedPubSub(BasePubSub[AnyStr, "coredis.pool.ClusterConnectionPool"]): For more details see :ref:`handbook/pubsub:sharded pub/sub` """ - PUBLISH_MESSAGE_TYPES = { - PubSubMessageTypes.MESSAGE.value, - PubSubMessageTypes.SMESSAGE.value, - } - SUBUNSUB_MESSAGE_TYPES = { - PubSubMessageTypes.SSUBSCRIBE.value, - PubSubMessageTypes.SUNSUBSCRIBE.value, - } - UNSUBSCRIBE_MESSAGE_TYPES = {PubSubMessageTypes.SUNSUBSCRIBE.value} - def __init__( self, connection_pool: coredis.pool.ClusterConnectionPool, @@ -646,9 +509,11 @@ def __init__( channel_handlers: Mapping[StringT, SubscriptionCallback] | None = None, ): self.shard_connections: dict[str, Connection] = {} - self.channel_connection_mapping: dict[StringT, Connection] = {} - self.pending_tasks: dict[str, asyncio.Task[ResponseType]] = {} + self.node_channel_mapping: dict[str, list[StringT]] = {} self.read_from_replicas = read_from_replicas + self._shard_messages = StapledObjectStream( + *create_memory_object_stream[list[ResponseType]]() + ) super().__init__( connection_pool, ignore_subscribe_messages, @@ -671,14 +536,13 @@ async def subscribe( :meth:`get_message`. """ - await self.initialize() new_channels: MutableMapping[StringT, SubscriptionCallback | None] = {} new_channels.update(dict.fromkeys(map(self.encode, channels))) for channel, handler in channel_handlers.items(): new_channels[self.encode(channel)] = handler for new_channel in new_channels.keys(): - await self.execute_command(CommandName.SSUBSCRIBE, new_channel, sharded=True) + await self.execute_command(CommandName.SSUBSCRIBE, new_channel) self.channels.update(new_channels) self._subscribed.set() @@ -690,7 +554,7 @@ async def unsubscribe(self, *channels: StringT) -> None: """ for channel in channels or list(self.channels.keys()): - await self.execute_command(CommandName.SUNSUBSCRIBE, channel, sharded=True) + await self.execute_command(CommandName.SUNSUBSCRIBE, channel) async def psubscribe( self, @@ -714,9 +578,7 @@ async def punsubscribe(self, *patterns: StringT) -> None: async def execute_command( self, command: bytes, *args: RedisValueT, **options: RedisValueT - ) -> ResponseType | None: - await self.initialize() - + ) -> None: assert isinstance(args[0], (bytes, str)) channel = nativestr(args[0]) slot = hash_slot(b(channel)) @@ -729,177 +591,50 @@ async def execute_command( channel=channel, node_type="replica" if self.read_from_replicas else "primary", ) - # register a callback that re-subscribes to any channels we - # were listening to when we were disconnected - self.shard_connections[key].register_connect_callback(self.on_connect) - - self.channel_connection_mapping[args[0]] = self.shard_connections[key] - assert self.shard_connections[key] - return await self._execute( - self.shard_connections[key], - self.shard_connections[key].send_command, - command, - *args, - ) + self._task_group.start_soon(self._shard_listener, key) + self.node_channel_mapping.setdefault(key, []).append(args[0]) + return await self.shard_connections[key].send_command(command, *args) raise PubSubError(f"Unable to determine shard for channel {args[0]!r}") - async def initialize(self) -> Self: - """ - Ensures the sharded pubsub instance is ready to consume messages - by ensuring the connection pool is initialized, setting up any - initial channel subscriptions that were specified during - instantiation and starting the consumer background task. - - The method can be called multiple times without any - risk as it will skip initialization if the consumer is already - initialized. - - .. important:: This method doesn't need to be called explicitly - as it will always be called internally before any relevant - documented interaction. - - :return: the instance itself - """ - if not self.initialized: - await self.connection_pool.initialize() - self.initialized = True + @asynccontextmanager + async def __asynccontextmanager__(self) -> AsyncGenerator[Self]: + async with create_task_group() as self._task_group: if self._initial_channel_subscriptions: await self.subscribe(**self._initial_channel_subscriptions) - if not self._consumer_task or self._consumer_task.done(): - self._consumer_task = asyncio.create_task(self._consumer()) - return self - - async def reset_connections(self, exc: BaseException | None = None) -> None: - for connection in self.shard_connections.values(): - connection.disconnect() - connection.clear_connect_callbacks() - self.connection_pool.release(connection) - self.shard_connections.clear() - for _, task in self.pending_tasks.items(): - if not task.done(): - task.cancel() - with suppress(CancelledError): - await task - self.pending_tasks.clear() - self.connection_pool.disconnect() - self.connection_pool.reset() - self.connection_pool.initialized = False - await self.connection_pool.initialize() - for channel in self.channels: - slot = hash_slot(b(channel)) - node = self.connection_pool.nodes.node_from_slot(slot) - if node and node.node_id: - key = node.node_id - self.shard_connections[key] = await self.connection_pool.get_connection( - b"pubsub", - channel=channel, - node_type="replica" if self.read_from_replicas else "primary", - ) - # register a callback that re-subscribes to any channels we - # were listening to when we were disconnected - self.shard_connections[key].register_connect_callback(self.on_connect) - self.channel_connection_mapping[channel] = self.shard_connections[key] + self._task_group.start_soon(self._consumer) + yield self + await self.unsubscribe() + self._task_group.cancel_scope.cancel() + self.reset() + + async def _shard_listener(self, node_id: str) -> None: + while True: + connection = self.shard_connections.get(node_id, None) + if not connection: + break + try: + with move_on_after(2): + message = await connection.fetch_push_message(True) + await self._shard_messages.send(message) + except (ConnectionError, ConnectionFailed, EndOfStream): + self.shard_connections.pop(node_id) + if active_channels := set(self.channels) & set(self.node_channel_mapping[node_id]): + self._task_group.start_soon(self.subscribe, *active_channels) + break async def parse_response( self, block: bool = True, timeout: float | None = None - ) -> ResponseType: - if not self.shard_connections: - raise RuntimeError( - "pubsub connection not set: did you forget to call subscribe() or psubscribe()?" - ) - result = None - # Check any stashed results first. - if self.pending_tasks: - for node_id, task in list(self.pending_tasks.items()): - self.pending_tasks.pop(node_id) - if task.done(): - result = task.result() - break - else: - done, pending = await asyncio.wait( - [task], - timeout=0.001, - return_when=asyncio.FIRST_COMPLETED, - ) - if done: - result = done.pop().result() - break - else: - task.cancel() - with suppress(CancelledError): - await task - # If there were no pending results check the shards - if not result: - broken_connections = [c for c in self.shard_connections.values() if not c.is_connected] - if broken_connections: - for connection in broken_connections: - try: - await connection.connect() - except: # noqa - raise ConnectionError("Shard connections not stable") - tasks: dict[str, asyncio.Task[ResponseType]] = { - node_id: asyncio.create_task( - connection.fetch_push_message( - push_message_types=self.SUBUNSUB_MESSAGE_TYPES | self.PUBLISH_MESSAGE_TYPES, - ), - ) - for node_id, connection in self.shard_connections.items() - if node_id not in self.pending_tasks - } - if tasks: - done, pending = await asyncio.wait( - tasks.values(), - timeout=timeout if (timeout and timeout > 0) else None, - return_when=asyncio.FIRST_COMPLETED, - ) - if done: - done_task = done.pop() - result = done_task.result() - - # Stash any other tasks for the next iteration - for task in list(done) + list(pending): - for node_id, scheduled in tasks.items(): - if task == scheduled: - self.pending_tasks[node_id] = task - return result - - async def on_connect(self, connection: BaseConnection) -> None: - """ - Re-subscribe to any channels previously subscribed to - - :meta private: - """ - for channel, handler in self.channels.items(): - if self.channel_connection_mapping[channel] == connection: - await self.subscribe( - **{ - ( - channel.decode(self.connection_pool.encoding) - if isinstance(channel, bytes) - else channel - ): handler - } - ) + ) -> list[ResponseType]: + timeout = timeout if timeout and timeout > 0 else None + with fail_after(timeout): + return await self._shard_messages.receive() def reset(self) -> None: for connection in self.shard_connections.values(): - connection.disconnect() connection.clear_connect_callbacks() self.connection_pool.release(connection) - for _, task in self.pending_tasks.items(): - task.cancel() - self.pending_tasks.clear() self.shard_connections.clear() self.channels = {} self.patterns = {} self.initialized = False - self._subscribed.clear() - - async def aclose(self) -> None: - """ - Unsubscribe from any channels & close and return - connections to the pool - """ - if self.shard_connections: - await self.unsubscribe() - self.close() + self._subscribed = Event() diff --git a/coredis/commands/script.py b/coredis/commands/script.py index eb67d3b82..4f284f8d4 100644 --- a/coredis/commands/script.py +++ b/coredis/commands/script.py @@ -4,12 +4,13 @@ import hashlib import inspect import itertools -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, cast, get_args, overload from deprecated.sphinx import versionadded from coredis._utils import b from coredis.exceptions import NoScriptError +from coredis.response._callbacks import NoopCallback from coredis.retry import ConstantRetryPolicy, retryable from coredis.typing import ( AnyStr, @@ -23,6 +24,7 @@ RedisValueT, ResponseType, StringT, + T_co, ValueT, add_runtime_checks, safe_beartype, @@ -79,7 +81,8 @@ def __call__( args: Parameters[ValueT] | None = None, client: coredis.client.Client[AnyStr] | None = None, readonly: bool | None = None, - ) -> Awaitable[ResponseType]: + callback: Callable[..., T_co] = NoopCallback(), + ) -> Awaitable[T_co]: """ Executes the script registered in :paramref:`Script.script` using :meth:`coredis.Redis.evalsha`. Additionally, if the script was not yet @@ -110,12 +113,12 @@ def __call__( if isinstance(client, Pipeline): # make sure this script is good to go on pipeline cast(Pipeline[AnyStr], client).scripts.add(self) - return method(self.sha, keys=keys, args=args) + return method(self.sha, keys=keys, args=args, callback=callback) else: return retryable( ConstantRetryPolicy((NoScriptError,), 1, 0), failure_hook=lambda _: client.script_load(self.script), - )(method)(self.sha, keys=keys, args=args) + )(method)(self.sha, keys=keys, args=args, callback=callback) async def execute( self, @@ -131,17 +134,32 @@ async def execute( """ return await self(keys, args, client, readonly) + @overload + def wraps( + self, + callback: None = None, + client_arg: str | None = ..., + runtime_checks: bool = ..., + readonly: bool = ..., + ) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]: ... + + @overload + def wraps( + self, + callback: Callable[..., T_co], + client_arg: str | None = ..., + runtime_checks: bool = ..., + readonly: bool = ..., + ) -> Callable[[Callable[P, Awaitable[Any]]], Callable[P, Awaitable[T_co]]]: ... + @versionadded(version="3.5.0") def wraps( self, - key_spec: list[str] | None = None, - param_is_key: Callable[[inspect.Parameter], bool] = lambda p: ( - p.annotation in {"KeyT", KeyT} - ), + callback: Callable[..., T_co] | None = None, client_arg: str | None = None, runtime_checks: bool = False, - readonly: bool | None = None, - ) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]: + readonly: bool = False, + ) -> Any: """ Decorator for wrapping a regular python function, method or classmethod signature with a :class:`~coredis.commands.script.Script`. This allows @@ -151,10 +169,11 @@ def wraps( The main objective of the decorator is to allow you to have strict (and type safe) signatures for wrappers for lua scripts. Internally the decorator separates - ``keys`` from ``args`` before calling :meth:`coredis.Redis.evalsha`. Mapping the - decorated methods arguments to key providers is done either by using :paramref:`key_spec` - or :paramref:`param_is_key`. All other paramters of the decorated function are assumed - to be ``args`` consumed by the lua script. + ``keys`` from ``args`` before calling :meth:`coredis.Redis.evalsha`. + + Mapping the decorated method's arguments to key providers is done by type + annotations: all parameters annotated as `KeyT` will be passed as keys, and the + rest will be passed as arguments. By default the decorated method is bound to the :class:`coredis.client.Redis` or :class:`coredis.client.RedisCluster` instance that the :class:`Script` instance @@ -168,15 +187,14 @@ def wraps( passed to redis as an ``arg``:: import coredis - from coredis.typing import KeyT, RedisValueT - from typing import List + from coredis.typing import KeyT, ValueT client = coredis.Redis() @client.register_script("return {KEYS[1], ARGV[1]}").wraps() - async def echo_key_value(key: KeyT, value: RedisValueT) -> List[RedisValueT]: ... + async def echo_key_value(key: KeyT, value: ValueT) -> list[ValueT]: ... - k, v = await echo_key_value("co", "redis") - # (b"co", b"redis") + res = await echo_key_value("co", "redis") + # [b"co", b"redis"] Alternatively, the following example builds a class method that requires the ``client`` to be passed in explicitly:: @@ -203,16 +221,28 @@ def echo_arg(cls, client, value): ... echoed = await ScriptProvider.echo_value(Redis(), "coredis") # b"coredis" - :param key_spec: list of parameters of the decorated method that will - be passed as the :paramref:`keys` argument to :meth:`__call__`. If provided - this parameter takes precedence over using :paramref:`param_is_key` to determine if - a parameter is a key provider. - :param param_is_key: a callable that accepts a single argument of type - :class:`inspect.Parameter` and returns ``True`` if the parameter points - to a key that should be appended to the :paramref:`__call__.keys` argument - of :meth:`__call__`. The default implementation marks a parameter as a key - provider if it is of type :data:`coredis.typing.KeyT` and is only used if - :paramref:`key_spec` is ``None``. + You can also pass a custom callback to execute on the return type, which will + be inferred as the return type rather than the annotation:: + + class MyCallback(ResponseCallback[Any, Any, int]): + def transform(self, response: ResponseType) -> int: + return sum([ord(c) for c in str(response)]) + + client = coredis.Redis(decode_responses=True) + async with client: + script = client.register_script("return {KEYS[1], ARGV[1]}") + + # we use Any since return type will come from callback + @script.wraps(callback=MyCallback()) + async def echo_key_value(key: KeyT, value: ValueT) -> Any: ... + + res = await echo_key_value("co", "redis") + reveal_type(res) # int + # 1161 + + :param callback: a custom callback to execute on the returned value. When provided, + the callback's type will be inferred as the return type instead of the type from + the stub. :param client_arg: The parameter of the decorator that will contain a client instance to be used to execute the script. :param runtime_checks: Whether to enable runtime type checking of input arguments @@ -224,15 +254,18 @@ def echo_arg(cls, client, value): ... :return: A function that has a signature mirroring the decorated function. """ + callback = callback or NoopCallback() - def wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: + def wrapper(func: Callable[P, Awaitable[Any]]) -> Callable[P, Awaitable[T_co]]: sig = inspect.signature(func) first_arg = list(sig.parameters.keys())[0] runtime_check_wrapper = add_runtime_checks if not runtime_checks else safe_beartype script_instance = self - key_params = ( - key_spec if key_spec else [n for n, p in sig.parameters.items() if param_is_key(p)] - ) + key_params = [ + n + for n, p in sig.parameters.items() + if p.annotation == "KeyT" or "KeyT" in get_args(p.annotation) + ] arg_fetch: dict[str, Callable[..., Parameters[Any]]] = { n: ( (lambda v: [v]) @@ -285,11 +318,9 @@ def split_args( async def __inner( *args: P.args, **kwargs: P.kwargs, - ) -> R: + ) -> T_co: keys, arguments, client = split_args(sig.bind(*args, **kwargs)) - # TODO: atleast lie with a cast. - # mypy doesn't like the cast - return await script_instance(keys, arguments, client, readonly) # type: ignore + return await script_instance(keys, arguments, client, readonly, callback=callback) # type: ignore return __inner diff --git a/coredis/connection.py b/coredis/connection.py index 1c196f287..80cfd9ace 100644 --- a/coredis/connection.py +++ b/coredis/connection.py @@ -1,31 +1,40 @@ from __future__ import annotations -import asyncio import dataclasses -import functools import inspect -import itertools +import math import os import socket import ssl -import time -import warnings -import weakref +from abc import abstractmethod from collections import defaultdict, deque -from contextlib import suppress -from typing import TYPE_CHECKING, Any, cast - -import async_timeout +from typing import TYPE_CHECKING, Any, Generator, cast + +from anyio import ( + TASK_STATUS_IGNORED, + ClosedResourceError, + Event, + Lock, + connect_tcp, + connect_unix, + create_memory_object_stream, + create_task_group, + fail_after, + move_on_after, +) +from anyio.abc import ByteStream, SocketAttribute, TaskStatus +from typing_extensions import override import coredis from coredis._packer import Packer -from coredis._utils import nativestr +from coredis._utils import logger, nativestr from coredis.credentials import ( AbstractCredentialProvider, UserPass, UserPassCredentialProvider, ) from coredis.exceptions import ( + RETRYABLE, AuthenticationRequiredError, ConnectionError, RedisError, @@ -33,17 +42,22 @@ UnknownCommandError, ) from coredis.parser import NotEnoughData, Parser +from coredis.retry import ExponentialBackoffRetryPolicy from coredis.tokens import PureToken from coredis.typing import ( Awaitable, Callable, ClassVar, - Literal, RedisValueT, ResponseType, TypeVar, ) +CERT_REQS = { + "none": ssl.CERT_NONE, + "optional": ssl.CERT_OPTIONAL, + "required": ssl.CERT_REQUIRED, +} R = TypeVar("R") if TYPE_CHECKING: @@ -52,28 +66,46 @@ @dataclasses.dataclass class Request: - connection: weakref.ProxyType[Connection] command: bytes decode: bool encoding: str | None = None raise_exceptions: bool = True - future: asyncio.Future[ResponseType] = dataclasses.field( - default_factory=lambda: asyncio.get_running_loop().create_future() - ) - created_at: float = dataclasses.field(default_factory=lambda: time.time()) - - def __post_init__(self) -> None: - self.future.add_done_callback(self.cleanup) - - def cleanup(self, future: asyncio.Future[ResponseType]) -> None: - if future.cancelled() and self.connection and self.connection.is_connected: - self.connection.disconnect() - - def enforce_deadline(self, timeout: float) -> None: - if not self.future.done(): - self.future.set_exception( - TimeoutError(f"command {nativestr(self.command)} timed out after {timeout} seconds") + response_timeout: float | None = None + _event: Event = dataclasses.field(default_factory=Event) + _exc: BaseException | None = None + _result: ResponseType | None = None + + def __await__(self) -> Generator[Any, None, ResponseType]: + return self.get_result().__await__() + + def resolve(self, response: ResponseType) -> None: + self._result = response + self._event.set() + + def fail(self, error: BaseException) -> None: + if not self._event.is_set(): + self._exc = error + self._event.set() + + async def get_result(self) -> ResponseType: + # return now if response available + if self._event.is_set(): + return self._result_or_exc() + # add response timeout + with move_on_after(self.response_timeout) as scope: + await self._event.wait() + if scope.cancelled_caught and not self._event.is_set(): + self._exc = TimeoutError( + f"command {nativestr(self.command)} timed out after {self.response_timeout} seconds" ) + return self._result_or_exc() + + def _result_or_exc(self) -> ResponseType: + if self._exc is not None: + if self.raise_exceptions: + raise self._exc + return self._exc # type: ignore + return self._result @dataclasses.dataclass @@ -101,12 +133,6 @@ def __init__( if cert_reqs is None: self.cert_reqs = ssl.CERT_OPTIONAL elif isinstance(cert_reqs, str): - CERT_REQS = { - "none": ssl.CERT_NONE, - "optional": ssl.CERT_OPTIONAL, - "required": ssl.CERT_REQUIRED, - } - self.cert_reqs = CERT_REQS[cert_reqs] else: self.cert_reqs = cert_reqs @@ -127,29 +153,15 @@ def get(self) -> ssl.SSLContext: return self.context -class BaseConnection(asyncio.BaseProtocol): +class BaseConnection: """ - Base connection class which implements - :class:`asyncio.BaseProtocol` to interact - with the underlying connection established - with the redis server. + Base connection class which interacts with the underlying connection + established with the redis server. """ - #: id for this connection as returned by the redis server - client_id: int | None - #: Queue that collects any unread push message types - push_messages: asyncio.Queue[ResponseType] - #: client id that the redis server should send any redirected notifications to - tracking_client_id: int | None - #: Whether the connection should use RESP or RESP3 - protocol_version: Literal[2, 3] - description: ClassVar[str] = "BaseConnection" locator: ClassVar[str] = "" - #: average response time of requests made on this connection - average_response_time: float - def __init__( self, stream_timeout: float | None = None, @@ -157,10 +169,10 @@ def __init__( decode_responses: bool = False, *, client_name: str | None = None, - protocol_version: Literal[2, 3] = 3, noreply: bool = False, noevict: bool = False, notouch: bool = False, + max_idle_time: int | None = None, ): self._stream_timeout = stream_timeout self.username: str | None = None @@ -174,38 +186,34 @@ def __init__( ] = list() self.encoding = encoding self.decode_responses = decode_responses - self.protocol_version = protocol_version self.server_version: str | None = None self.client_name = client_name - self.client_id = None - self.tracking_client_id = None - - self.last_active_at: float = time.time() - self.last_request_processed_at: float | None = None - - self._transport: asyncio.Transport | None = None - self._parser = Parser() - self._read_flag = asyncio.Event() - self._read_waiters: set[asyncio.Task[bool]] = set() + #: id for this connection as returned by the redis server + self.client_id: int | None = None + #: client id that the redis server should send any redirected notifications to + self.tracking_client_id: int | None = None + + self._connection: ByteStream | None = None + #: Queue that collects any unread push message types + push_messages, self._receive_messages = create_memory_object_stream[list[ResponseType]]( + math.inf + ) + self._parser = Parser(push_messages) self.packer: Packer = Packer(self.encoding) - self.push_messages: asyncio.Queue[ResponseType] = asyncio.Queue() + self.max_idle_time = max_idle_time - self.noreply: bool = noreply - self.noreply_set: bool = False + self.noreply = noreply + self.noreply_set = False - self.noevict: bool = noevict - self.notouch: bool = notouch + self.noevict = noevict + self.notouch = notouch - self.needs_handshake: bool = True + self.needs_handshake = True self._last_error: BaseException | None = None - self._connection_error: BaseException | None = None + self._connected = False self._requests: deque[Request] = deque() - - self.average_response_time: float = 0 - self.requests_processed: int = 0 - self._write_ready: asyncio.Event = asyncio.Event() - self._transport_lock: asyncio.Lock = asyncio.Lock() + self._write_lock = Lock() def __repr__(self) -> str: return self.describe(self._description_args()) @@ -219,18 +227,10 @@ def location(self) -> str: return self.locator.format_map(defaultdict(lambda: None, self._description_args())) @property - def estimated_time_to_idle(self) -> float: - """ - Estimated time till the pending request queue of this connection - has been cleared - """ - return self.requests_pending * self.average_response_time - - def __del__(self) -> None: - try: - self.disconnect() - except Exception: # noqa - pass + def connection(self) -> ByteStream: + if not self._connection: + raise ConnectionError("Connection not initialized correctly!") + return self._connection @property def is_connected(self) -> bool: @@ -238,27 +238,7 @@ def is_connected(self) -> bool: Whether the connection is established and initial handshakes were performed without error """ - return self._transport is not None and self._connection_error is None - - @property - def requests_pending(self) -> int: - """ - Number of requests pending response on this connection - """ - return len(self._requests) - - @property - def lag(self) -> float: - """ - Returns the amount of seconds since the last request was processed - if there are still in flight requests pending on this connection - """ - if not self._requests: - return 0 - elif self.last_request_processed_at is None: - return time.time() - else: - return time.time() - self.last_request_processed_at + return self._connected def register_connect_callback( self, @@ -269,114 +249,74 @@ def register_connect_callback( def clear_connect_callbacks(self) -> None: self._connect_callbacks = list() - async def can_read(self) -> bool: - """Checks for data that can be read""" - assert self._parser - - if not self.is_connected: - await self.connect() - - return self._parser.can_read() + @abstractmethod + async def _connect(self) -> ByteStream: ... - async def connect(self) -> None: + async def run(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED) -> None: """ Establish a connnection to the redis server - and initiate any post connect callbacks - """ - self._connection_error = None - try: - await self._connect() - except (asyncio.CancelledError, RedisError) as err: - self._connection_error = err - raise - except Exception as err: - self._connection_error = err - raise ConnectionError(str(err)) from err - - # run any user callbacks. right now the only internal callback - # is for pubsub channel/pattern resubscription - for callback in self._connect_callbacks: - task = callback(self) - if inspect.isawaitable(task): - await task - - def connection_made(self, transport: asyncio.BaseTransport) -> None: + and initiate any post connect callbacks. """ - :meta private: - """ - self._transport = cast(asyncio.Transport, transport) - self._write_ready.set() - - def connection_lost(self, exc: BaseException | None) -> None: - """ - :meta private: - """ - if exc: - self._last_error = exc - - self.disconnect() - - def pause_writing(self) -> None: - """ - :meta private: - """ - self._write_ready.clear() - def resume_writing(self) -> None: - """ - :meta private: - """ - self._write_ready.set() + retry = ExponentialBackoffRetryPolicy(RETRYABLE, 3, 0.5) + self._connection = await retry.call_with_retries(self._connect) + try: + async with self.connection, self._parser.push_messages, create_task_group() as tg: + tg.start_soon(self.listen_for_responses) + # setup connection + await self.on_connect() + # run any user callbacks. right now the only internal callback + # is for pubsub channel/pattern resubscription + for callback in self._connect_callbacks: + task = callback(self) + if inspect.isawaitable(task): + await task + self._connected = True + task_status.started() + except Exception as e: + logger.exception("Connection closed unexpectedly!") + self._last_error = e + # swallow the error unless connection hasn't been established; + # it will usually be raised when accessing command results. + # we want the connection to die, but we don't always want to + # raise it and corrupt the connection pool. + if not self._connected: + raise + finally: + self._parser.on_disconnect() + disconnect_exc = self._last_error or ConnectionError("Connection lost!") + while self._requests: + request = self._requests.popleft() + request.fail(disconnect_exc) + self._connection = None - def data_received(self, data: bytes) -> None: + async def listen_for_responses(self) -> None: """ - :meta private: + Listen on the socket and run the parser, completing pending requests in + FIFO order. """ - self._parser.feed(data) - self._read_flag.set() - if not self._requests: - return - - request = self._requests.popleft() - response = self._parser.get_response(request.decode, request.encoding) - while not isinstance( - response, - NotEnoughData, - ): - if not (request.future.cancelled() or request.future.done()): + while True: + decode = self._requests[0].decode if self._requests else self.decode_responses + # Try to parse a complete response from already-fed bytes + response = self._parser.get_response( + decode, self._requests[0].encoding if self._requests else self.encoding + ) + if isinstance(response, NotEnoughData): + # Need more bytes; read once, feed, and retry + with move_on_after(self.max_idle_time) as scope: + data = await self.connection.receive() + self._parser.feed(data) + if scope.cancelled_caught: # this will cleanup the connection gracefully + break + continue # loop back and try parsing again + + # We have a full response for `head`; now pop and complete it + if self._requests: + request = self._requests.popleft() if request.raise_exceptions and isinstance(response, RedisError): - request.future.set_exception(response) + request.fail(response) else: - request.future.set_result(response) - - self.last_request_processed_at = time.time() - self.requests_processed += 1 - response_time = time.time() - request.created_at - - self.average_response_time = ( - (self.average_response_time * (self.requests_processed - 1)) + response_time - ) / self.requests_processed - - try: - request = self._requests.popleft() - except IndexError: - return - - response = self._parser.get_response(request.decode, request.encoding) - - # In case the first request pulled from the queue doesn't have enough data - # to process, put it back to the start of the queue for the next iteration - if request: - self._requests.appendleft(request) - - def eof_received(self) -> None: - """ - :meta private: - """ - self.disconnect() - - async def _connect(self) -> None: - raise NotImplementedError + request.resolve(response) async def update_tracking_client(self, enabled: bool, client_id: int | None = None) -> bool: """ @@ -418,7 +358,7 @@ async def perform_handshake(self) -> None: if not self.needs_handshake: return - hello_command_args: list[int | str | bytes] = [self.protocol_version] + hello_command_args: list[int | str | bytes] = [3] if creds := ( await self.credential_provider.get_credentials() if self.credential_provider @@ -440,27 +380,20 @@ async def perform_handshake(self) -> None: await self.create_request(b"HELLO", *hello_command_args, decode=False) ) assert isinstance(hello_resp, (list, dict)) - if self.protocol_version == 3: - resp3 = cast(dict[bytes, RedisValueT], hello_resp) - assert resp3[b"proto"] == 3 - self.server_version = nativestr(resp3[b"version"]) - self.client_id = int(resp3[b"id"]) - else: - resp = cast(list[RedisValueT], hello_resp) - self.server_version = nativestr(resp[3]) - self.client_id = int(resp[7]) + resp3 = cast(dict[bytes, RedisValueT], hello_resp) + assert resp3[b"proto"] == 3 + self.server_version = nativestr(resp3[b"version"]) + self.client_id = int(resp3[b"id"]) if self.server_version >= "7.2": - await asyncio.gather( - await self.create_request( - b"CLIENT SETINFO", - b"LIB-NAME", - b"coredis", - ), - await self.create_request( - b"CLIENT SETINFO", - b"LIB-VER", - coredis.__version__, - ), + await self.create_request( + b"CLIENT SETINFO", + b"LIB-NAME", + b"coredis", + ) + await self.create_request( + b"CLIENT SETINFO", + b"LIB-VER", + coredis.__version__, ) self.needs_handshake = False except AuthenticationRequiredError: @@ -468,24 +401,11 @@ async def perform_handshake(self) -> None: self.server_version = None self.client_id = None except UnknownCommandError: # noqa - # This should only happen for redis servers < 6 or forks of redis - # that are not > 6 compliant. - warning = ( - "The server responded with no support for the `HELLO` command" - " and therefore a handshake could not be performed" + raise ConnectionError( + "Unable to use RESP3 due to missing `HELLO` implementation the server." ) - if self.protocol_version == 3: - raise ConnectionError( - "Unable to use RESP3 due to missing `HELLO` implementation " - "the server. Use `protocol_version=2` when constructing the client." - ) - else: - warnings.warn(warning, category=UserWarning) - await self.try_legacy_auth() - self.needs_handshake = False async def on_connect(self) -> None: - self._parser.on_connect(self) await self.perform_handshake() if self.db: @@ -509,55 +429,16 @@ async def on_connect(self) -> None: await (await self.create_request(b"CLIENT REPLY", b"OFF", noreply=True)) self.noreply_set = True - self.last_active_at = time.time() - - async def fetch_push_message( - self, - decode: RedisValueT | None = None, - push_message_types: set[bytes] | None = None, - block: bool | None = False, - ) -> ResponseType: + async def fetch_push_message(self, block: bool = False) -> list[ResponseType]: """ Read the next pending response """ - if not self.is_connected: - await self.connect() - - if len(self._requests) > 0: - raise ConnectionError( - f"Invalid request for push messages. {len(self._requests)} requests still pending" - ) + if block: + timeout = self._stream_timeout if not block else None + with fail_after(timeout): + return await self._receive_messages.receive() - message = self._parser.get_response( - bool(decode) if decode is not None else self.decode_responses, - self.encoding, - push_message_types, - ) - while isinstance( - message, - NotEnoughData, - ): - self._read_flag.clear() - try: - timeout = self._stream_timeout if not block else None - read_ready_task = asyncio.create_task(self._read_flag.wait()) - read_ready_task.add_done_callback( - lambda _: self._read_waiters.discard(read_ready_task) - ) - self._read_waiters.add(read_ready_task) - await asyncio.wait_for(read_ready_task, timeout) - except asyncio.TimeoutError: - raise TimeoutError - except asyncio.CancelledError: - if not self.is_connected: - raise ConnectionError("Connection lost") - raise - message = self._parser.get_response( - bool(decode) if decode is not None else self.decode_responses, - self.encoding, - push_message_types, - ) - return message + return self._receive_messages.receive_nowait() async def _send_packed_command( self, command: list[bytes], timeout: float | None = None @@ -565,16 +446,14 @@ async def _send_packed_command( """ Sends an already packed command to the Redis server """ - - assert self._transport - try: - async with async_timeout.timeout(timeout): - await self._write_ready.wait() - except asyncio.TimeoutError: - if self._transport: - self.disconnect() - raise TimeoutError(f"Unable to write after waiting for socket for {timeout} seconds") - self._transport.writelines(command) + with fail_after(timeout): + data = b"".join(command) + try: + await self.connection.send(data) + except ClosedResourceError as err: + self._last_error = err + self._connection = None + raise ConnectionError(f"Failed to send data: {data.decode()}!") from err async def send_command( self, @@ -584,13 +463,8 @@ async def send_command( """ Send a command to the redis server """ - - if not self.is_connected: - await self.connect() - - await self._send_packed_command(self.packer.pack_command(command, *args)) - - self.last_active_at = time.time() + async with self._write_lock: + await self._send_packed_command(self.packer.pack_command(command, *args)) async def create_request( self, @@ -601,114 +475,58 @@ async def create_request( encoding: str | None = None, raise_exceptions: bool = True, timeout: float | None = None, - ) -> asyncio.Future[ResponseType]: + ) -> Request: """ Send a command to the redis server """ from coredis.commands.constants import CommandName - if not self.is_connected: - await self.connect() - cmd_list = [] - request_timeout: float | None = timeout or self._stream_timeout if self.is_connected and noreply and not self.noreply: cmd_list = self.packer.pack_command(CommandName.CLIENT_REPLY, PureToken.SKIP) cmd_list.extend(self.packer.pack_command(command, *args)) - await self._send_packed_command(cmd_list, timeout=request_timeout) - - self.last_active_at = time.time() - - if not (self.noreply_set or noreply): - request = Request( - weakref.proxy(self), - command, - bool(decode) if decode is not None else self.decode_responses, - encoding or self.encoding, - raise_exceptions, - ) - self._requests.append(request) - if request_timeout is not None: - asyncio.get_running_loop().call_later( - request_timeout, - functools.partial( - request.enforce_deadline, - request_timeout, - ), - ) - return request.future - else: - none: asyncio.Future[ResponseType] = asyncio.Future() - none.set_result(None) - return none + request_timeout: float | None = timeout or self._stream_timeout + request = Request( + command, + bool(decode) if decode is not None else self.decode_responses, + encoding or self.encoding, + raise_exceptions, + request_timeout, + ) + async with self._write_lock: + if not (self.noreply_set or noreply): + self._requests.append(request) + else: + request.resolve(None) + await self._send_packed_command(cmd_list, timeout=request_timeout) + return request async def create_requests( self, commands: list[CommandInvocation], raise_exceptions: bool = True, timeout: float | None = None, - ) -> list[asyncio.Future[ResponseType]]: + ) -> list[Request]: """ Send multiple commands to the redis server """ - - if not self.is_connected: - await self.connect() - request_timeout: float | None = timeout or self._stream_timeout - - await self._send_packed_command( - self.packer.pack_commands( - list(itertools.chain((cmd.command, *cmd.args) for cmd in commands)) - ), - timeout=request_timeout, - ) - - self.last_active_at = time.time() - requests: list[asyncio.Future[ResponseType]] = [] - for cmd in commands: - request = Request( - weakref.proxy(self), + requests = [ + Request( cmd.command, bool(cmd.decode) if cmd.decode is not None else self.decode_responses, cmd.encoding or self.encoding, raise_exceptions, + request_timeout, ) - self._requests.append(request) - if request_timeout is not None: - asyncio.get_running_loop().call_later( - request_timeout, - functools.partial(request.enforce_deadline, request_timeout), - ) - requests.append(request.future) + for cmd in commands + ] + packed = self.packer.pack_commands([(cmd.command, *cmd.args) for cmd in commands]) + async with self._write_lock: + self._requests.extend(requests) + await self._send_packed_command(packed, timeout=request_timeout) return requests - def disconnect(self) -> None: - """ - Disconnect from the Redis server - """ - self.needs_handshake = True - self.noreply_set = False - self._parser.on_disconnect() - if self._transport: - with suppress(RuntimeError): - self._transport.close() - - disconnect_exc = self._last_error or ConnectionError("connection lost") - while self._read_waiters: - waiter = self._read_waiters.pop() - if not waiter.done(): - with suppress(RuntimeError): - waiter.cancel() - while True: - try: - request = self._requests.popleft() - if not request.future.done(): - request.future.set_exception(disconnect_exc) - except IndexError: - break - self._transport = None - class Connection(BaseConnection): description: ClassVar[str] = "Connection" @@ -731,20 +549,20 @@ def __init__( socket_keepalive_options: dict[int, int | bytes] | None = None, *, client_name: str | None = None, - protocol_version: Literal[2, 3] = 3, noreply: bool = False, noevict: bool = False, notouch: bool = False, + max_idle_time: int | None = None, ): super().__init__( stream_timeout, encoding, decode_responses, client_name=client_name, - protocol_version=protocol_version, noreply=noreply, noevict=noevict, notouch=notouch, + max_idle_time=max_idle_time, ) self.host = host self.port = port @@ -762,41 +580,26 @@ def __init__( self.socket_keepalive = socket_keepalive self.socket_keepalive_options: dict[int, int | bytes] = socket_keepalive_options or {} - async def _connect(self) -> None: - async with self._transport_lock: - if self._transport: - return + @override + async def _connect(self) -> ByteStream: + with fail_after(self._connect_timeout): if self.ssl_context: - connection = asyncio.get_running_loop().create_connection( - lambda: self, host=self.host, port=self.port, ssl=self.ssl_context + connection: ByteStream = await connect_tcp( + self.host, + self.port, + tls=True, + ssl_context=self.ssl_context, + tls_standard_compatible=False, ) else: - connection = asyncio.get_running_loop().create_connection( - lambda: self, host=self.host, port=self.port - ) - - try: - async with async_timeout.timeout(self._connect_timeout): - transport, _ = await connection - except asyncio.TimeoutError: - raise ConnectionError( - f"Unable to establish a connection within {self._connect_timeout} seconds" - ) - sock = transport.get_extra_info("socket") + connection = await connect_tcp(self.host, self.port) + sock = connection.extra(SocketAttribute.raw_socket, default=None) if sock is not None: - try: - # TCP_KEEPALIVE - if self.socket_keepalive: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) - - for k, v in self.socket_keepalive_options.items(): - sock.setsockopt(socket.SOL_TCP, k, v) - except (OSError, TypeError): - # `socket_keepalive_options` might contain invalid options - # causing an error - transport.close() - raise - await self.on_connect() + if self.socket_keepalive: # TCP_KEEPALIVE + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + for k, v in self.socket_keepalive_options.items(): + sock.setsockopt(socket.SOL_TCP, k, v) + return connection class UnixDomainSocketConnection(BaseConnection): @@ -816,7 +619,7 @@ def __init__( decode_responses: bool = False, *, client_name: str | None = None, - protocol_version: Literal[2, 3] = 3, + max_idle_time: int | None = None, **_: RedisValueT, ) -> None: super().__init__( @@ -824,7 +627,7 @@ def __init__( encoding, decode_responses, client_name=client_name, - protocol_version=protocol_version, + max_idle_time=max_idle_time, ) self.path = path self.db = db @@ -834,11 +637,10 @@ def __init__( self._connect_timeout = connect_timeout self._description_args = lambda: {"path": self.path, "db": self.db} - async def _connect(self) -> None: - async with async_timeout.timeout(self._connect_timeout): - await asyncio.get_running_loop().create_unix_connection(lambda: self, path=self.path) - - await self.on_connect() + @override + async def _connect(self) -> ByteStream: + with fail_after(self._connect_timeout): + return await connect_unix(self.path) class ClusterConnection(Connection): @@ -865,11 +667,11 @@ def __init__( socket_keepalive_options: dict[int, int | bytes] | None = None, *, client_name: str | None = None, - protocol_version: Literal[2, 3] = 3, read_from_replicas: bool = False, noreply: bool = False, noevict: bool = False, notouch: bool = False, + max_idle_time: int | None = None, ) -> None: self.read_from_replicas = read_from_replicas super().__init__( @@ -887,20 +689,19 @@ def __init__( socket_keepalive=socket_keepalive, socket_keepalive_options=socket_keepalive_options, client_name=client_name, - protocol_version=protocol_version, noreply=noreply, noevict=noevict, notouch=notouch, + max_idle_time=max_idle_time, ) - async def on_connect(self) -> None: - """ - Initialize the connection, authenticate and select a database and send - `READONLY` if `read_from_replicas` is set during initialization. + async def _on_connect(*args: Any) -> None: + """ + Initialize the connection, authenticate and select a database and send + `READONLY` if `read_from_replicas` is set during initialization. + """ - :meta private: - """ + if self.read_from_replicas: + assert (await (await self.create_request(b"READONLY", decode=False))) == b"OK" - await super().on_connect() - if self.read_from_replicas: - assert (await (await self.create_request(b"READONLY", decode=False))) == b"OK" + self.register_connect_callback(_on_connect) diff --git a/coredis/exceptions.py b/coredis/exceptions.py index 03012caaf..374f55009 100644 --- a/coredis/exceptions.py +++ b/coredis/exceptions.py @@ -1,6 +1,9 @@ from __future__ import annotations import re +from ssl import SSLError + +from anyio import BrokenResourceError, ConnectionFailed, EndOfStream from coredis.typing import RedisValueT @@ -358,3 +361,12 @@ class StreamConsumerInitializationError(StreamConsumerError): Raised when a stream consumer could not be initialized based on the configuration provided """ + + +RETRYABLE = ( + BrokenResourceError, + ConnectionError, + ConnectionFailed, + EndOfStream, + SSLError, +) diff --git a/coredis/modules/response/_callbacks/autocomplete.py b/coredis/modules/response/_callbacks/autocomplete.py index 8d978d34a..df8452ba2 100644 --- a/coredis/modules/response/_callbacks/autocomplete.py +++ b/coredis/modules/response/_callbacks/autocomplete.py @@ -7,7 +7,6 @@ class AutocompleteCallback( ResponseCallback[ - list[ResponseType], list[ResponseType], tuple[AutocompleteSuggestion[AnyStr], ...] | tuple[()], ] diff --git a/coredis/modules/response/_callbacks/graph.py b/coredis/modules/response/_callbacks/graph.py index a0f86cec2..82eb2b4aa 100644 --- a/coredis/modules/response/_callbacks/graph.py +++ b/coredis/modules/response/_callbacks/graph.py @@ -1,9 +1,9 @@ from __future__ import annotations -import asyncio import enum from typing import TYPE_CHECKING, Any +from coredis._concurrency import gather from coredis._utils import b, nativestr from coredis.modules.response.types import ( GraphNode, @@ -57,7 +57,7 @@ class RedisValueTypes(enum.IntEnum): class QueryCallback( - ResponseCallback[ResponseType, ResponseType, GraphQueryResult[AnyStr]], + ResponseCallback[ResponseType, GraphQueryResult[AnyStr]], Generic[AnyStr], ): properties: dict[int, StringT] @@ -91,7 +91,7 @@ async def pre_process( entity, max_label_id, max_relation_id, max_property_id ) if any(k != -1 for k in [max_label_id, max_relation_id, max_property_id]): - self.labels, self.relationships, self.properties = await asyncio.gather( + self.labels, self.relationships, self.properties = await gather( self.fetch_mapping(max_label_id, "labels", client), self.fetch_mapping(max_relation_id, "relationships", client), self.fetch_mapping(max_property_id, "properties", client), @@ -209,9 +209,7 @@ def parse_entity(self, entity): return GraphPath(nodes, relations) -class GraphSlowLogCallback( - ResponseCallback[ResponseType, ResponseType, tuple[GraphSlowLogInfo, ...]] -): +class GraphSlowLogCallback(ResponseCallback[ResponseType, tuple[GraphSlowLogInfo, ...]]): def transform( self, response: ResponseType, @@ -221,7 +219,6 @@ def transform( class ConfigGetCallback( ResponseCallback[ - ResponseType, ResponseType, ResponsePrimitive | dict[AnyStr, ResponsePrimitive], ] diff --git a/coredis/modules/response/_callbacks/json.py b/coredis/modules/response/_callbacks/json.py index fbe7c38f9..e7b02c852 100644 --- a/coredis/modules/response/_callbacks/json.py +++ b/coredis/modules/response/_callbacks/json.py @@ -7,7 +7,7 @@ from coredis.typing import JsonType, ResponseType -class JsonCallback(ResponseCallback[ResponseType, ResponseType, JsonType]): +class JsonCallback(ResponseCallback[ResponseType, JsonType]): def transform( self, response: ResponseType, diff --git a/coredis/modules/response/_callbacks/search.py b/coredis/modules/response/_callbacks/search.py index d1e0155e1..74a447c07 100644 --- a/coredis/modules/response/_callbacks/search.py +++ b/coredis/modules/response/_callbacks/search.py @@ -21,30 +21,23 @@ class SearchConfigCallback( ResponseCallback[ - list[list[ResponsePrimitive]], dict[AnyStr, ResponseType] | list[list[ResponsePrimitive]], dict[AnyStr, ResponsePrimitive], ] ): def transform( - self, - response: list[list[ResponsePrimitive]], - ) -> dict[AnyStr, ResponsePrimitive]: - command_arguments = [] - for item in response: - try: - v = (item[0], json.loads(item[1])) - except (ValueError, TypeError): - v = item - command_arguments.append(v) - return dict(command_arguments) - - def transform_3( self, response: dict[AnyStr, ResponseType] | list[list[ResponsePrimitive]], ) -> dict[AnyStr, ResponsePrimitive]: if isinstance(response, list): - return self.transform(response) + command_arguments = [] + for item in response: + try: + v = (item[0], json.loads(item[1])) + except (ValueError, TypeError): + v = item + command_arguments.append(v) + return dict(command_arguments) else: config = {} for item, value in response.items(): @@ -57,63 +50,56 @@ def transform_3( class SearchResultCallback( ResponseCallback[ - list[ResponseType], list[ResponseType] | dict[AnyStr, ResponseType], SearchResult[AnyStr], ] ): def transform( - self, - response: list[ResponseType], - ) -> SearchResult[AnyStr]: - if self.options.get("nocontent"): - return SearchResult[AnyStr]( - response[0], - tuple(SearchDocument(i, None, None, None, None, {}) for i in response[1:]), - ) - step = 2 - results = [] - score_idx = payload_idx = sort_key_idx = 0 - if self.options.get("withscores"): - score_idx = 1 - step += 1 - if self.options.get("withpayloads"): - payload_idx = score_idx + 1 - step += 1 - if self.options.get("withsortkeys"): - sort_key_idx = payload_idx + 1 - step += 1 - - for k in range(1, len(response) - 1, step): - section = response[k : k + step] - score_explain = None - if self.options.get("explainscore"): - score = section[score_idx][0] - score_explain = section[score_idx][1] - else: - score = section[score_idx] if score_idx else None - fields = EncodingInsensitiveDict(flat_pairs_to_dict(section[-1])) - if "$" in fields: - fields = json.loads(fields.pop("$")) - results.append( - SearchDocument( - section[0], - float(score) if score else None, - score_explain, - section[payload_idx] if payload_idx else None, - section[sort_key_idx] if sort_key_idx else None, - fields, - ) - ) - return SearchResult[AnyStr](response[0], tuple(results)) - - def transform_3( self, response: list[ResponseType] | dict[AnyStr, ResponseType], ) -> SearchResult[AnyStr]: results = [] if isinstance(response, list): - return self.transform(response) + if self.options.get("nocontent"): + return SearchResult[AnyStr]( + response[0], + tuple(SearchDocument(i, None, None, None, None, {}) for i in response[1:]), + ) + step = 2 + results = [] + score_idx = payload_idx = sort_key_idx = 0 + if self.options.get("withscores"): + score_idx = 1 + step += 1 + if self.options.get("withpayloads"): + payload_idx = score_idx + 1 + step += 1 + if self.options.get("withsortkeys"): + sort_key_idx = payload_idx + 1 + step += 1 + + for k in range(1, len(response) - 1, step): + section = response[k : k + step] + score_explain = None + if self.options.get("explainscore"): + score = section[score_idx][0] + score_explain = section[score_idx][1] + else: + score = section[score_idx] if score_idx else None + fields = EncodingInsensitiveDict(flat_pairs_to_dict(section[-1])) + if "$" in fields: + fields = json.loads(fields.pop("$")) + results.append( + SearchDocument( + section[0], + float(score) if score else None, + score_explain, + section[payload_idx] if payload_idx else None, + section[sort_key_idx] if sort_key_idx else None, + fields, + ) + ) + return SearchResult[AnyStr](response[0], tuple(results)) else: response = EncodingInsensitiveDict(response) for result in response["results"]: @@ -141,24 +127,11 @@ def transform_3( class AggregationResultCallback( ResponseCallback[ - list[ResponseType], dict[AnyStr, ResponseType] | list[ResponseType], SearchAggregationResult[AnyStr], ] ): def transform( - self, - response: list[ResponseType], - ) -> SearchAggregationResult: - return SearchAggregationResult[AnyStr]( - [ - flat_pairs_to_dict(k, partial(self.try_json, self.options)) - for k in (response[1:] if not self.options.get("with_cursor") else response[0][1:]) - ], - response[1] if self.options.get("with_cursor") else None, - ) - - def transform_3( self, response: dict[AnyStr, ResponseType] | list[ResponseType], ) -> SearchAggregationResult: @@ -180,7 +153,15 @@ def transform_3( cursor, ) else: - return self.transform(response) + return SearchAggregationResult[AnyStr]( + [ + flat_pairs_to_dict(k, partial(self.try_json, self.options)) + for k in ( + response[1:] if not self.options.get("with_cursor") else response[0][1:] + ) + ], + response[1] if self.options.get("with_cursor") else None, + ) @staticmethod def try_json(options, value): @@ -194,28 +175,21 @@ def try_json(options, value): class SpellCheckCallback( ResponseCallback[ - list[ResponseType], dict[AnyStr, ResponseType] | list[ResponseType], dict[AnyStr, OrderedDict[AnyStr, float]], ] ): def transform( - self, - response: list[ResponseType], - ) -> dict[AnyStr, OrderedDict[AnyStr, float]]: - return { - result[1]: OrderedDict( - (suggestion[1], float(suggestion[0])) for suggestion in result[2] - ) - for result in response - } - - def transform_3( self, response: dict[AnyStr, ResponseType] | list[ResponseType], ) -> dict[AnyStr, OrderedDict[AnyStr, float]]: # For older versions of redis search that didn't support RESP3 if isinstance(response, list): - return self.transform(response) + return { + result[1]: OrderedDict( + (suggestion[1], float(suggestion[0])) for suggestion in result[2] + ) + for result in response + } response = EncodingInsensitiveDict(response) return {key: OrderedDict(ChainMap(*result)) for key, result in response["results"].items()} diff --git a/coredis/modules/response/_callbacks/timeseries.py b/coredis/modules/response/_callbacks/timeseries.py index a92529f5d..0154b906c 100644 --- a/coredis/modules/response/_callbacks/timeseries.py +++ b/coredis/modules/response/_callbacks/timeseries.py @@ -20,7 +20,6 @@ class SampleCallback( ResponseCallback[ - list[RedisValueT], list[RedisValueT], tuple[int, float] | tuple[()], ] @@ -34,7 +33,6 @@ def transform( class SamplesCallback( ResponseCallback[ - list[list[RedisValueT]] | None, list[list[RedisValueT]] | None, tuple[tuple[int, float], ...] | tuple[()], ], @@ -66,7 +64,6 @@ def transform( class TimeSeriesCallback( ResponseCallback[ - ResponseType, ResponseType, dict[AnyStr, tuple[dict[AnyStr, AnyStr], tuple[int, float] | tuple[()]]], ] @@ -85,7 +82,6 @@ def transform( class TimeSeriesMultiCallback( ResponseCallback[ - ResponseType, ResponseType, dict[ AnyStr, @@ -99,30 +95,6 @@ def transform( ) -> dict[ AnyStr, tuple[dict[AnyStr, AnyStr], tuple[tuple[int, float], ...] | tuple[()]], - ]: - if self.options.get("grouped"): - return { - r[0]: ( - flat_pairs_to_dict(r[1][0]) if r[1] else {}, - tuple(SampleCallback().transform(t) for t in r[2]), - ) - for r in cast(Any, response) - } - else: - return { - r[0]: ( - dict(r[1]), - tuple(SampleCallback().transform(t) for t in r[2]), - ) - for r in cast(Any, response) - } - - def transform_3( - self, - response: ResponseType, - ) -> dict[ - AnyStr, - tuple[dict[AnyStr, AnyStr], tuple[tuple[int, float], ...] | tuple[()]], ]: if isinstance(response, dict): if self.options.get("grouped"): @@ -142,7 +114,22 @@ def transform_3( for k, r in response.items() } else: - return self.transform(response) + if self.options.get("grouped"): + return { + r[0]: ( + flat_pairs_to_dict(r[1][0]) if r[1] else {}, + tuple(SampleCallback().transform(t) for t in r[2]), + ) + for r in cast(Any, response) + } + else: + return { + r[0]: ( + dict(r[1]), + tuple(SampleCallback().transform(t) for t in r[2]), + ) + for r in cast(Any, response) + } class ClusterMergeTimeSeries(ClusterMergeMapping[AnyStr, tuple[Any, ...]]): diff --git a/coredis/parser.py b/coredis/parser.py index c545cd7c0..a7c396576 100644 --- a/coredis/parser.py +++ b/coredis/parser.py @@ -1,12 +1,14 @@ from __future__ import annotations -import asyncio +from abc import abstractmethod from collections.abc import Hashable from io import BytesIO from typing import cast -from coredis._protocols import ConnectionP -from coredis._utils import b +from anyio.streams.memory import MemoryObjectSendStream + +from coredis._enum import CaseAndEncodingInsensitiveEnum +from coredis._utils import b, logger from coredis.constants import SYM_CRLF, RESPDataType from coredis.exceptions import ( AskError, @@ -50,6 +52,40 @@ class NotEnoughData: NOT_ENOUGH_DATA: Final[NotEnoughData] = NotEnoughData() +class PubSubMessageTypes(CaseAndEncodingInsensitiveEnum): + MESSAGE = b"message" + PMESSAGE = b"pmessage" + SMESSAGE = b"smessage" + SUBSCRIBE = b"subscribe" + UNSUBSCRIBE = b"unsubscribe" + PSUBSCRIBE = b"psubscribe" + PUNSUBSCRIBE = b"punsubscribe" + SSUBSCRIBE = b"ssubscribe" + SUNSUBSCRIBE = b"sunsubscribe" + + +PUBLISH_MESSAGE_TYPES = { + PubSubMessageTypes.MESSAGE.value, + PubSubMessageTypes.PMESSAGE.value, + PubSubMessageTypes.SMESSAGE.value, +} +SUBUNSUB_MESSAGE_TYPES = { + PubSubMessageTypes.SUBSCRIBE.value, + PubSubMessageTypes.SSUBSCRIBE.value, + PubSubMessageTypes.PSUBSCRIBE.value, + PubSubMessageTypes.UNSUBSCRIBE.value, + PubSubMessageTypes.SUNSUBSCRIBE.value, + PubSubMessageTypes.PUNSUBSCRIBE.value, +} +UNSUBSCRIBE_MESSAGE_TYPES = { + PubSubMessageTypes.UNSUBSCRIBE.value, + PubSubMessageTypes.PUNSUBSCRIBE.value, + PubSubMessageTypes.SUNSUBSCRIBE.value, +} +INVALIDATION_TYPES = {b"invalidate"} +PUSH_MESSAGE_TYPES = PUBLISH_MESSAGE_TYPES | SUBUNSUB_MESSAGE_TYPES | INVALIDATION_TYPES + + class RESPNode: __slots__ = ("depth", "key", "node_type") depth: int @@ -66,8 +102,8 @@ def __init__( self.node_type = node_type self.key = key - def append(self, item: ResponseType) -> None: - raise NotImplementedError() + @abstractmethod + def append(self, item: ResponseType) -> None: ... def ensure_hashable(self, item: ResponseType) -> Hashable: if isinstance(item, (int, float, bool, str, bytes)): @@ -80,7 +116,7 @@ def ensure_hashable(self, item: ResponseType) -> Hashable: return tuple( (cast(ResponsePrimitive, k), self.ensure_hashable(v)) for k, v in item.items() ) - return item # noqa + return item class ListNode(RESPNode): @@ -164,8 +200,8 @@ class Parser: "WRONGTYPE": WrongTypeError, } - def __init__(self) -> None: - self.push_messages: asyncio.Queue[ResponseType] | None = None + def __init__(self, push_messages: MemoryObjectSendStream[list[ResponseType]]) -> None: + self.push_messages = push_messages self.localbuffer: BytesIO = BytesIO(b"") self.bytes_read: int = 0 self.bytes_written: int = 0 @@ -176,10 +212,6 @@ def feed(self, data: bytes) -> None: self.bytes_written += self.localbuffer.write(data) self.localbuffer.seek(self.bytes_read) - def on_connect(self, connection: ConnectionP) -> None: - """Called when the stream connects""" - self.push_messages = connection.push_messages - def on_disconnect(self) -> None: """Called when the stream disconnects""" if not self.localbuffer.closed: @@ -201,15 +233,13 @@ def get_response( self, decode: bool, encoding: str | None = None, - push_message_types: set[bytes] | None = None, ) -> NotEnoughData | ResponseType: """ :param decode: Whether to decode simple or bulk strings :param push_message_types: the push message types to return if they arrive. If a message arrives that does not match the filter, it will - be put on the :data:`~coredis.connection.BaseConnection.push_messages` - queue + be logged; otherwise, it will be put on the :data:`~coredis.connection.BaseConnection.push_messages` queue :return: The next available parsed response read from the connection. If there is not enough data on the wire a ``NotEnoughData`` instance will be returned. @@ -218,17 +248,14 @@ def get_response( response = self.parse(decode, encoding) if isinstance(response, NotEnoughData): return response - else: - if response and response.response_type == RESPDataType.PUSH: - assert isinstance(response.response, list) - assert self.push_messages - if not push_message_types or b(response.response[0]) not in push_message_types: - self.push_messages.put_nowait(response.response) - continue - else: - break + if response and response.response_type == RESPDataType.PUSH: + assert isinstance(response.response, list) + if b(response.response[0]) in PUSH_MESSAGE_TYPES: + self.push_messages.send_nowait(response.response) else: - break + logger.debug(f"Unhandled push message: {response.response}") + else: + break return response.response if response else None def parse( diff --git a/coredis/pipeline.py b/coredis/pipeline.py index 5bf3c785f..0af35dd3b 100644 --- a/coredis/pipeline.py +++ b/coredis/pipeline.py @@ -1,16 +1,14 @@ from __future__ import annotations -import asyncio import functools import inspect -import sys import textwrap -import warnings from abc import ABCMeta from concurrent.futures import CancelledError -from types import TracebackType -from typing import Any, cast +from contextlib import asynccontextmanager +from typing import Any, AsyncGenerator, cast +from anyio import sleep from deprecated.sphinx import deprecated from coredis._utils import b, hash_slot, nativestr @@ -20,7 +18,12 @@ from coredis.commands.constants import CommandName, NodeFlag from coredis.commands.request import TransformedResponse from coredis.commands.script import Script -from coredis.connection import BaseConnection, ClusterConnection, CommandInvocation, Connection +from coredis.connection import ( + BaseConnection, + ClusterConnection, + CommandInvocation, + Request, +) from coredis.exceptions import ( AskError, ClusterCrossSlotError, @@ -36,12 +39,11 @@ TryAgainError, WatchError, ) -from coredis.pool import ClusterConnectionPool, ConnectionPool +from coredis.pool import ClusterConnectionPool from coredis.pool.nodemanager import ManagedNode from coredis.response._callbacks import ( AnyStrCallback, AsyncPreProcessingCallback, - BoolCallback, BoolsCallback, NoopCallback, SimpleStringCallback, @@ -103,6 +105,25 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> Awaitable[R]: return wrapper +class Awaitablize(Awaitable[T]): + __slots__ = ("_result",) + + def __init__(self, result: T) -> None: + self._result = result + + def __await__(self) -> Generator[Any, None, T]: + async def _coro() -> T: + await sleep(0) # checkpoint + return self._result + + # create the coroutine when awaited to avoid Python warning on GC + return _coro().__await__() + + +def await_result(result: T) -> Awaitable[T]: + return Awaitablize(result) + + class PipelineCommandRequest(CommandRequest[CommandResponseT]): """ Command request used within a pipeline. Handles immediate execution for WATCH or @@ -110,7 +131,6 @@ class PipelineCommandRequest(CommandRequest[CommandResponseT]): """ client: Pipeline[Any] | ClusterPipeline[Any] - queued_response: Awaitable[bytes | str] def __init__( self, @@ -153,12 +173,6 @@ def transform( parent=self, ) - async def __backward_compatibility_return(self) -> Pipeline[Any] | ClusterPipeline[Any]: - """ - For backward compatibility: returns the pipeline instance when awaited before execute(). - """ - return self.client - def __await__(self) -> Generator[None, None, CommandResponseT]: if hasattr(self, "response"): return self.response.__await__() @@ -171,18 +185,16 @@ async def _transformed() -> CommandResponseT: return self.callback(r) return _transformed().__await__() - else: - warnings.warn( - """ -Awaiting a pipeline command response before calling `execute()` on the pipeline instance -has no effect and returns the pipeline instance itself for backward compatibility. - -To add commands to a pipeline simply call the methods synchronously. The awaitable response -can be awaited after calling `execute()` to retrieve a statically typed response if required. - """, - stacklevel=2, - ) - return self.__backward_compatibility_return().__await__() # type: ignore[return-value] + exc = ResponseError( + "Result not set! Either a transaction failed, or you're awaiting a pipeline command before calling execute." + ) + if self.client._raise_on_error: + raise exc + + async def _get_exc() -> ResponseError: + return exc + + return _get_exc().__await__() # type: ignore class ClusterPipelineCommandRequest(PipelineCommandRequest[CommandResponseT]): @@ -200,7 +212,7 @@ def __init__( parent: CommandRequest[Any] | None = None, ) -> None: self.position: int = 0 - self.result: Any | None = None + self.result: Any = None self.asking: bool = False super().__init__( client, @@ -229,8 +241,8 @@ def __init__( self.commands: list[ClusterPipelineCommandRequest[Any]] = [] self.in_transaction = in_transaction self.timeout = timeout - self.multi_cmd: asyncio.Future[ResponseType] | None = None - self.exec_cmd: asyncio.Future[ResponseType] | None = None + self.multi_cmd: Request | None = None + self.exec_cmd: Request | None = None def extend(self, c: list[ClusterPipelineCommandRequest[Any]]) -> None: self.commands.extend(c) @@ -279,7 +291,6 @@ async def write(self) -> None: c.result = e async def read(self) -> None: - connection = self.connection success = True multi_result = None if self.multi_cmd: @@ -312,8 +323,8 @@ async def read(self) -> None: await c.callback.pre_process(self.client, transaction_result[idx]) c.result = c.callback( transaction_result[idx], - version=connection.protocol_version, ) + c.response = await_result(c.result) elif isinstance(multi_result, BaseException): raise multi_result @@ -370,42 +381,44 @@ class Pipeline(Client[AnyStr], metaclass=PipelineMeta): and its instance is placed into the response list returned by :meth:`execute` """ - command_stack: list[PipelineCommandRequest[Any]] - connection_pool: ConnectionPool + QUEUED_RESPONSES = {b"QUEUED", "QUEUED"} def __init__( self, client: Client[AnyStr], transaction: bool | None, - watches: Parameters[KeyT] | None = None, + raise_on_error: bool = True, timeout: float | None = None, ) -> None: self.client: Client[AnyStr] = client - self.connection_pool = client.connection_pool - self.connection: Connection | None = None + self._connection: BaseConnection | None = None self._transaction = transaction + self._raise_on_error = raise_on_error self.watching = False - self.watches: Parameters[KeyT] | None = watches or None - self.command_stack = [] + self.command_stack: list[PipelineCommandRequest[Any]] = [] self.cache = None self.explicit_transaction = False self.scripts: set[Script[AnyStr]] = set() self.timeout = timeout self.type_adapter = client.type_adapter - async def __aenter__(self) -> Pipeline[AnyStr]: - return await self.get_instance() + def __repr__(self) -> str: + return f"{type(self).__name__}<{repr(self._connection)}>" - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_value: BaseException | None, - traceback: TracebackType | None, - ) -> None: - await self.clear() + @property + def connection(self) -> BaseConnection: + if not self._connection: + raise RedisError( + "Pipeline not initialized correctly! Make sure to use await or the async context manager." + ) + return self._connection - def __await__(self) -> Generator[Any, Any, Pipeline[AnyStr]]: - return self.get_instance().__await__() + @asynccontextmanager + async def __asynccontextmanager__(self) -> AsyncGenerator[Self]: + pool = self.client.connection_pool + async with pool.acquire() as self._connection: + yield self + await self._execute() def __len__(self) -> int: return len(self.command_stack) @@ -413,9 +426,6 @@ def __len__(self) -> int: def __bool__(self) -> bool: return True - async def get_instance(self) -> Pipeline[AnyStr]: - return self - def create_request( self, name: bytes, @@ -432,37 +442,18 @@ def create_request( async def clear(self) -> None: """ - Clear the pipeline, reset state, and release the connection back to the pool. + Clear the pipeline and reset state. """ self.command_stack.clear() self.scripts = set() # Reset connection state if we were watching something. if self.watching and self.connection: - try: - request = await self.connection.create_request(CommandName.UNWATCH, decode=False) - await request - except ConnectionError: - self.connection.disconnect() + await (await self.connection.create_request(CommandName.UNWATCH, decode=False)) + else: + await sleep(0) # checkpoint # Reset pipeline state and release connection if needed. self.watching = False - self.watches = [] self.explicit_transaction = False - if self.connection: - self.connection_pool.release(self.connection) - self.connection = None - - #: :meta private: - reset_pipeline = clear - - @deprecated( - "The reset method in pipelines clashes with the redis ``RESET`` command. Use :meth:`clear` instead", - "5.0.0", - ) - def reset(self) -> CommandRequest[None]: - """ - Deprecated. Use :meth:`clear` instead. - """ - return self.clear() # type: ignore def multi(self) -> None: """ @@ -470,11 +461,34 @@ def multi(self) -> None: """ if self.explicit_transaction: raise RedisError("Cannot issue nested calls to MULTI") - if self.command_stack: raise RedisError("Commands without an initial WATCH have already been issued") self.explicit_transaction = True + async def watch(self, *keys: KeyT) -> bool: + """ + Watch the given keys for changes. Switches to immediate execution mode + until :meth:`multi` is called. + """ + if self.explicit_transaction: + raise RedisError("Cannot issue a WATCH after a MULTI") + return await self.immediate_execute_command( + RedisCommand(name=CommandName.WATCH, arguments=keys), + callback=SimpleStringCallback(), + ) + + async def unwatch(self) -> bool: + """ + Remove all key watches and return to buffered mode. + """ + if not self.watching: + await sleep(0) # checkpoint + return False + return await self.immediate_execute_command( + RedisCommand(name=CommandName.UNWATCH, arguments=()), + callback=SimpleStringCallback(), + ) + def execute_command( self, command: RedisCommandP, @@ -497,33 +511,24 @@ async def immediate_execute_command( :meta private: """ - conn = self.connection - # if this is the first call, we need a connection - if not conn: - conn = await self.connection_pool.get_connection() - self.connection = conn try: - request = await conn.create_request( + request = await self.connection.create_request( command.name, *command.arguments, decode=kwargs.get("decode") ) return callback( await request, - version=conn.protocol_version, ) except (ConnectionError, TimeoutError): - conn.disconnect() - # if we're not already watching, we can safely retry the command try: if not self.watching: - request = await conn.create_request( + request = await self.connection.create_request( command.name, *command.arguments, decode=kwargs.get("decode") ) - return callback(await request, version=conn.protocol_version) + return callback(await request) raise except ConnectionError: # the retry failed so cleanup. - conn.disconnect() await self.clear() raise finally: @@ -547,11 +552,10 @@ async def _execute_transaction( self, connection: BaseConnection, commands: list[PipelineCommandRequest[Any]], - raise_on_error: bool, ) -> tuple[Any, ...]: - multi_cmd = await connection.create_request(CommandName.MULTI, timeout=self.timeout) requests = await connection.create_requests( - [ + [CommandInvocation(CommandName.MULTI, (), None, None)] + + [ CommandInvocation( cmd.name, cmd.arguments, @@ -563,67 +567,51 @@ async def _execute_transaction( None, ) for cmd in commands - ], + ] + + [CommandInvocation(CommandName.EXEC, (), None, None)], timeout=self.timeout, ) - exec_cmd = await connection.create_request(CommandName.EXEC, timeout=self.timeout) - for i, cmd in enumerate(commands): - cmd.queued_response = cast(Awaitable[StringT], requests[i]) errors: list[tuple[int, RedisError | None]] = [] - multi_failed = False - # parse off the response for MULTI # NOTE: we need to handle ResponseErrors here and continue # so that we read all the additional command messages from # the socket try: - await multi_cmd - except RedisError: - multi_failed = True - errors.append((0, cast(RedisError, sys.exc_info()[1]))) + await requests[0] + except RedisError as e: + errors.append((0, e)) # and all the other commands for i, cmd in enumerate(commands): try: - if cmd.queued_response: - assert (await cmd.queued_response) in {b"QUEUED", "QUEUED"} - except RedisError: - ex = cast(RedisError, sys.exc_info()[1]) - self.annotate_exception(ex, i + 1, cmd.name, cmd.arguments) - errors.append((i, ex)) - - response: list[ResponseType] - try: - response = cast( - list[ResponseType], - await exec_cmd if exec_cmd else None, - ) - except (ExecAbortError, ResponseError): - if self.explicit_transaction and not multi_failed: - await self.immediate_execute_command( - RedisCommand(name=CommandName.DISCARD, arguments=()), callback=BoolCallback() - ) + if (resp := await requests[i + 1]) not in self.QUEUED_RESPONSES: + raise Exception( + f"Abnormal response in pipeline for command {cmd.name!r}: {resp!r}" + ) + except RedisError as e: + self.annotate_exception(e, i + 1, cmd.name, cmd.arguments) + errors.append((i + 1, e)) + try: + response = cast(list[ResponseType] | None, await requests[-1]) + except (ExecAbortError, ResponseError) as e: if errors and errors[0][1]: - raise errors[0][1] + raise errors[0][1] from e raise if response is None: raise WatchError("Watched variable changed.") # put any parse errors into the response - - for i, e in errors: + for i, e in errors: # type: ignore response.insert(i, cast(ResponseType, e)) if len(response) != len(commands): - if self.connection: - self.connection.disconnect() raise ResponseError("Wrong number of response items from pipeline execution") # find any errors in the response and raise if necessary - if raise_on_error: + if self._raise_on_error: self.raise_first_error(commands, response) # We have to run response callbacks manually @@ -632,17 +620,13 @@ async def _execute_transaction( if not isinstance(r, Exception): if isinstance(cmd.callback, AsyncPreProcessingCallback): await cmd.callback.pre_process(self.client, r) - r = cmd.callback(r, version=connection.protocol_version, **cmd.execution_parameters) - cmd.response = asyncio.get_running_loop().create_future() - cmd.response.set_result(r) + r = cmd.callback(r, **cmd.execution_parameters) + cmd.response = await_result(r) data.append(r) return tuple(data) async def _execute_pipeline( - self, - connection: BaseConnection, - commands: list[PipelineCommandRequest[Any]], - raise_on_error: bool, + self, connection: BaseConnection, commands: list[PipelineCommandRequest[Any]] ) -> tuple[Any, ...]: # build up all commands into a single request to increase network perf requests = await connection.create_requests( @@ -672,17 +656,14 @@ async def _execute_pipeline( await cmd.callback.pre_process(self.client, res, **cmd.execution_parameters) resp = cmd.callback( res, - version=connection.protocol_version, **cmd.execution_parameters, ) - cmd.response = asyncio.get_event_loop().create_future() - cmd.response.set_result(resp) + cmd.response = await_result(resp) response.append(resp) except ResponseError as re: - cmd.response = asyncio.get_event_loop().create_future() - cmd.response.set_exception(re) - response.append(sys.exc_info()[1]) - if raise_on_error: + cmd.response = await_result(re) + response.append(re) + if self._raise_on_error: self.raise_first_error(commands, response) return tuple(response) @@ -712,30 +693,27 @@ def annotate_exception( async def load_scripts(self) -> None: # make sure all scripts that are about to be run on this pipeline exist scripts = list(self.scripts) - immediate = self.immediate_execute_command shas = [s.sha for s in scripts] # we can't use the normal script_* methods because they would just # get buffered in the pipeline. - exists = await immediate( + exists = await self.immediate_execute_command( RedisCommand(CommandName.SCRIPT_EXISTS, tuple(shas)), callback=BoolsCallback() ) if not all(exists): for s, exist in zip(scripts, exists): if not exist: - s.sha = await immediate( + s.sha = await self.immediate_execute_command( RedisCommand(CommandName.SCRIPT_LOAD, (s.script,)), callback=AnyStrCallback[AnyStr](), ) - async def execute(self, raise_on_error: bool = True) -> tuple[Any, ...]: + async def _execute(self) -> None: """ - Execute all queued commands in the pipeline. Returns a tuple of results. + Execute all queued commands in the pipeline. """ - stack = self.command_stack - - if not stack: - return () + if not self.command_stack: + return None if self.scripts: await self.load_scripts() @@ -745,50 +723,22 @@ async def execute(self, raise_on_error: bool = True) -> tuple[Any, ...]: else: exec = self._execute_pipeline - conn = self.connection - - if not conn: - conn = await self.connection_pool.get_connection() - # assign to self.connection so clear() releases the connection - # back to the pool after we're done - self.connection = conn - try: - return await exec(conn, stack, raise_on_error) - except (ConnectionError, TimeoutError, CancelledError): - conn.disconnect() - + await exec(self.connection, self.command_stack) + except (ConnectionError, TimeoutError, CancelledError) as e: # if we were watching a variable, the watch is no longer valid # since this connection has died. raise a WatchError, which # indicates the user should retry his transaction. If this is more # than a temporary failure, the WATCH that the user next issues # will fail, propegating the real ConnectionError - if self.watching: - raise WatchError("A ConnectionError occured on while watching one or more keys") - # otherwise, it's safe to retry since the transaction isn't - # predicated on any state - - return await exec(conn, stack, raise_on_error) + raise WatchError( + "A connection error occured while watching one or more keys" + ) from e + raise finally: await self.clear() - def watch(self, *keys: KeyT) -> CommandRequest[bool]: - """ - Watch the given keys for changes. Switches to immediate execution mode - until :meth:`multi` is called. - """ - if self.explicit_transaction: - raise RedisError("Cannot issue a WATCH after a MULTI") - - return self.create_request(CommandName.WATCH, *keys, callback=SimpleStringCallback()) - - def unwatch(self) -> CommandRequest[bool]: - """ - Remove all key watches and return to buffered mode. - """ - return self.create_request(CommandName.UNWATCH, callback=SimpleStringCallback()) - class ClusterPipeline(Client[AnyStr], metaclass=ClusterPipelineMeta): """ @@ -810,7 +760,8 @@ class ClusterPipeline(Client[AnyStr], metaclass=ClusterPipelineMeta): def __init__( self, client: RedisCluster[AnyStr], - transaction: bool | None = False, + raise_on_error: bool = True, + transaction: bool = False, watches: Parameters[KeyT] | None = None, timeout: float | None = None, ) -> None: @@ -819,6 +770,7 @@ def __init__( self.client = client self.connection_pool = client.connection_pool self.result_callbacks = client.result_callbacks + self._raise_on_error = raise_on_error self._transaction = transaction self._watched_node: ManagedNode | None = None self._watched_connection: ClusterConnection | None = None @@ -868,30 +820,16 @@ async def unwatch(self) -> bool: self._watched_connection = None return True - def __del__(self) -> None: - if self._watched_connection: - self.connection_pool.release(self._watched_connection) - def __len__(self) -> int: return len(self.command_stack) def __bool__(self) -> bool: return True - def __await__(self) -> Generator[None, None, Self]: - yield - return self - - async def __aenter__(self) -> ClusterPipeline[AnyStr]: - return self - - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_value: BaseException | None, - traceback: TracebackType | None, - ) -> None: - await self.clear() + @asynccontextmanager + async def __asynccontextmanager__(self) -> AsyncGenerator[Self]: + yield self + await self._execute() def execute_command( self, @@ -929,7 +867,7 @@ def annotate_exception( msg = f"Command # {number} ({cmd} {args}) of pipeline caused error: {exception.args[0]}" exception.args = (msg,) + exception.args[1:] - async def execute(self, raise_on_error: bool = True) -> tuple[object, ...]: + async def _execute(self) -> tuple[object, ...]: """ Execute all queued commands in the cluster pipeline. Returns a tuple of results. """ @@ -943,7 +881,7 @@ async def execute(self, raise_on_error: bool = True) -> tuple[object, ...]: else: execute = self.send_cluster_commands try: - return await execute(raise_on_error) + return await execute(self._raise_on_error) finally: await self.clear() @@ -1068,7 +1006,6 @@ async def send_cluster_commands( # Release all connections back to the pool only if safe (no unread buffer). # If an error occurred, do not release to avoid buffer mismatches. for n in nodes.values(): - protocol_version = n.connection.protocol_version self.connection_pool.release(n.connection) # Retry MOVED/ASK/connection errors one by one if allowed. @@ -1076,7 +1013,6 @@ async def send_cluster_commands( (c for c in attempt if isinstance(c.result, ERRORS_ALLOW_RETRY)), key=lambda x: x.position, ) - if attempt and allow_redirections: await self.connection_pool.nodes.increment_reinitialize_counter(len(attempt)) for c in attempt: @@ -1094,12 +1030,11 @@ async def send_cluster_commands( if not isinstance(c.result, RedisError): if isinstance(c.callback, AsyncPreProcessingCallback): await c.callback.pre_process(self.client, c.result) - r = c.callback(c.result, version=protocol_version) + r = c.callback(c.result) + c.response = await_result(r) response.append(r) - if raise_on_error: self.raise_first_error() - return tuple(response) def _determine_slot( @@ -1168,22 +1103,21 @@ async def immediate_execute_command( return callback( await request, - version=conn.protocol_version, ) except (ConnectionError, TimeoutError): - conn.disconnect() + # conn.disconnect() try: if not self.watching: request = await conn.create_request( command.name, *command.arguments, decode=kwargs.get("decode") ) - return callback(await request, version=conn.protocol_version) + return callback(await request) else: raise except ConnectionError: # the retry failed so cleanup. - conn.disconnect() + # conn.disconnect() await self.clear() raise finally: @@ -1213,7 +1147,6 @@ async def _watch(self, node: ManagedNode, conn: BaseConnection, keys: Parameters return SimpleStringCallback()( cast(StringT, await request), - version=conn.protocol_version, ) async def _unwatch(self, conn: BaseConnection) -> bool: diff --git a/coredis/pool/__init__.py b/coredis/pool/__init__.py index 0a64b6994..f7bba95b6 100644 --- a/coredis/pool/__init__.py +++ b/coredis/pool/__init__.py @@ -1,11 +1,6 @@ from __future__ import annotations -from .basic import BlockingConnectionPool, ConnectionPool -from .cluster import BlockingClusterConnectionPool, ClusterConnectionPool +from .basic import ConnectionPool +from .cluster import ClusterConnectionPool -__all__ = [ - "ConnectionPool", - "BlockingConnectionPool", - "ClusterConnectionPool", - "BlockingClusterConnectionPool", -] +__all__ = ["ConnectionPool", "ClusterConnectionPool"] diff --git a/coredis/pool/basic.py b/coredis/pool/basic.py index 008d6ac09..03a7cf825 100644 --- a/coredis/pool/basic.py +++ b/coredis/pool/basic.py @@ -1,16 +1,21 @@ from __future__ import annotations -import asyncio -import os -import threading -import time import warnings -from itertools import chain +from collections import deque +from contextlib import asynccontextmanager from ssl import SSLContext, VerifyMode -from typing import Any, cast +from typing import Any, AsyncGenerator, cast from urllib.parse import parse_qs, unquote, urlparse -import async_timeout +from anyio import ( + TASK_STATUS_IGNORED, + AsyncContextManagerMixin, + Semaphore, + create_task_group, + fail_after, +) +from anyio.abc import TaskStatus +from typing_extensions import Self from coredis._utils import query_param_to_bool from coredis.connection import ( @@ -19,14 +24,15 @@ RedisSSLContext, UnixDomainSocketConnection, ) -from coredis.exceptions import ConnectionError -from coredis.typing import Callable, ClassVar, RedisValueT, TypeVar +from coredis.typing import Callable, ClassVar, TypeVar _CPT = TypeVar("_CPT", bound="ConnectionPool") -class ConnectionPool: - """Generic connection pool""" +class ConnectionPool(AsyncContextManagerMixin): + """ + Generic connection pool + """ #: Mapping of querystring arguments to their parser functions URL_QUERY_ARGUMENT_PARSERS: ClassVar[ @@ -37,8 +43,6 @@ class ConnectionPool: "connect_timeout": float, "max_connections": int, "max_idle_time": int, - "protocol_version": int, - "idle_check_interval": int, "noreply": bool, "noevict": bool, "notouch": bool, @@ -180,14 +184,16 @@ def from_url( return cls(**kwargs) + def __repr__(self) -> str: + return f"{type(self).__name__}<{self.connection_class.describe(self.connection_kwargs)}>" + def __init__( self, *, - connection_class: type[Connection] | None = None, + connection_class: type[BaseConnection] | None = None, max_connections: int | None = None, - max_idle_time: int = 0, - idle_check_interval: int = 1, - **connection_kwargs: Any | None, + timeout: float | None = None, + **connection_kwargs: Any, ) -> None: """ Creates a connection pool. If :paramref:`max_connections` is set, then this @@ -201,253 +207,52 @@ def __init__( """ self.connection_class = connection_class or Connection self.connection_kwargs = connection_kwargs - self.max_connections = max_connections or 2**31 - self.max_idle_time = max_idle_time - self.idle_check_interval = idle_check_interval - self.initialized = False - self.reset() + self.max_connections = max_connections or 64 + self.timeout = timeout self.decode_responses = bool(self.connection_kwargs.get("decode_responses", False)) self.encoding = str(self.connection_kwargs.get("encoding", "utf-8")) - - async def initialize(self) -> None: - self.initialized = True - - def __repr__(self) -> str: - return f"{type(self).__name__}<{self.connection_class.describe(self.connection_kwargs)}>" - - def __del__(self) -> None: - self.disconnect() - - async def disconnect_on_idle_time_exceeded(self, connection: Connection) -> None: - while True: - if ( - time.time() - connection.last_active_at > self.max_idle_time - and not connection.requests_pending - ): - connection.disconnect() - if connection in self._available_connections: - self._available_connections.remove(connection) - self._created_connections -= 1 - break - await asyncio.sleep(self.idle_check_interval) - - def reset(self) -> None: - self.pid = os.getpid() - self._created_connections = 0 - self._available_connections: list[Connection] = [] - self._in_use_connections: set[Connection] = set() - self._check_lock = threading.Lock() - - def checkpid(self) -> None: # noqa - if self.pid != os.getpid(): - with self._check_lock: - # Double check - if self.pid == os.getpid(): - return - self.disconnect() - self.reset() - - def peek_available(self) -> BaseConnection | None: - return self._available_connections[0] if self._available_connections else None - - async def get_connection( - self, - command_name: bytes | None = None, - *args: RedisValueT, - acquire: bool = True, - **kwargs: RedisValueT | None, - ) -> Connection: - """Gets a connection from the pool""" - self.checkpid() + self._used_connections: set[BaseConnection] = set() + self._free_connections: deque[BaseConnection] = deque() + self._capacity = Semaphore(self.max_connections, max_value=self.max_connections) + + @asynccontextmanager + async def __asynccontextmanager__(self) -> AsyncGenerator[Self]: + async with create_task_group() as tg: + self._task_group = tg + yield self + self._task_group.cancel_scope.cancel() + self._free_connections.clear() + self._used_connections.clear() + + async def wrap_connection( + self, connection: BaseConnection, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED + ) -> None: try: - connection = self._available_connections.pop() - if connection.is_connected and connection.needs_handshake: - await connection.perform_handshake() - except IndexError: - if self._created_connections >= self.max_connections: - raise ConnectionError("Too many connections") - connection = self._make_connection(**kwargs) - - if acquire: - self._in_use_connections.add(connection) - else: - self._available_connections.append(connection) - - return connection - - def release(self, connection: Connection) -> None: + await connection.run(task_status=task_status) + finally: + if connection in self._used_connections: + self._used_connections.remove(connection) + elif connection in self._free_connections: + self._free_connections.remove(connection) + + @asynccontextmanager + async def acquire(self) -> AsyncGenerator[BaseConnection]: """ - Releases the :paramref:`connection` back to the pool + Gets a dedicated connection from the pool, or creates a new one if all are busy. """ - self.checkpid() - - if connection.pid == self.pid: - self._in_use_connections.remove(connection) - self._available_connections.append(connection) - - def disconnect(self) -> None: - """Closes all connections in the pool""" - all_conns = chain(self._available_connections, self._in_use_connections) - - for connection in all_conns: - connection.disconnect() - self._created_connections -= 1 - - def _make_connection(self, **options: RedisValueT | None) -> Connection: - """ - Creates a new connection - """ - - self._created_connections += 1 - connection = self.connection_class( - **self.connection_kwargs, # type: ignore - ) - - if self.max_idle_time > self.idle_check_interval > 0: - # do not await the future - asyncio.ensure_future(self.disconnect_on_idle_time_exceeded(connection)) - - return connection - - -class BlockingConnectionPool(ConnectionPool): - """ - Blocking connection pool:: - - >>> from coredis import Redis - >>> client = Redis(connection_pool=BlockingConnectionPool()) - - It performs the same function as the default - :class:`~coredis.ConnectionPool`, in that, it maintains a pool of reusable - connections that can be shared by multiple redis clients. - - The difference is that, in the event that a client tries to get a - connection from the pool when all of the connections are in use, rather than - raising a :exc:`~coredis.ConnectionError` (as the default - :class:`~coredis.ConnectionPool` implementation does), it - makes the client blocks for a specified number of seconds until - a connection becomes available. - - Use :paramref:`max_connections` to increase / decrease the pool size:: - - >>> pool = BlockingConnectionPool(max_connections=10) - - Use :paramref:`timeout` to tell it either how many seconds to wait for a - connection to become available, or to block forever:: - - >>> # Block forever. - >>> pool = BlockingConnectionPool(timeout=None) - >>> # Raise a ``ConnectionError`` after five seconds if a connection is - >>> # not available. - >>> pool = BlockingConnectionPool(timeout=5) - """ - - def __init__( - self, - connection_class: type[Connection] | None = None, - queue_class: type[asyncio.Queue[Connection | None]] = asyncio.LifoQueue, - max_connections: int | None = None, - timeout: int = 20, - max_idle_time: int = 0, - idle_check_interval: int = 1, - **connection_kwargs: RedisValueT | None, - ): - self.timeout = timeout - self.queue_class = queue_class - self.total_wait = 0 - self.total_allocated = 0 - max_connections = max_connections or 50 - - super().__init__( - connection_class=connection_class or Connection, - max_connections=max_connections, - max_idle_time=max_idle_time, - idle_check_interval=idle_check_interval, - **connection_kwargs, - ) - - async def disconnect_on_idle_time_exceeded(self, connection: Connection) -> None: - while True: - if time.time() - connection.last_active_at > self.max_idle_time: - # Unlike the non blocking pool, we don't free the connection object, - # but always reuse it - connection.disconnect() - - break - await asyncio.sleep(self.idle_check_interval) - - def reset(self) -> None: - self._pool: asyncio.Queue[Connection | None] = self.queue_class(self.max_connections) - - while True: - try: - self._pool.put_nowait(None) - except asyncio.QueueFull: - break - - super().reset() - - def peek_available(self) -> BaseConnection | None: - return ( - self._pool._queue[-1] # type: ignore - if (self._pool and not self._pool.empty()) - else None - ) - - async def get_connection( - self, - command_name: bytes | None = None, - *args: RedisValueT, - acquire: bool = True, - **kwargs: RedisValueT | None, - ) -> Connection: - """Gets a connection from the pool""" - self.checkpid() - - try: - async with async_timeout.timeout(self.timeout): - connection = await self._pool.get() - if connection and connection.is_connected and connection.needs_handshake: - await connection.perform_handshake() - except asyncio.TimeoutError: - raise ConnectionError("No connection available.") - if connection is None: - connection = self._make_connection() - - if acquire: - self._in_use_connections.add(connection) + with fail_after(self.timeout): + await self._capacity.acquire() + if self._free_connections: + connection = self._free_connections.pop() else: - self._pool.put_nowait(connection) - - return connection - - def release(self, connection: Connection) -> None: - """Releases the connection back to the pool""" - _connection: Connection | None = connection - - self.checkpid() - - if _connection and _connection.pid == self.pid: - self._in_use_connections.remove(_connection) - try: - self._pool.put_nowait(_connection) - except asyncio.QueueFull: - _connection.disconnect() - - def disconnect(self) -> None: - """Closes all connections in the pool""" - pooled_connections: list[Connection | None] = [] - - while True: - try: - pooled_connections.append(self._pool.get_nowait()) - except asyncio.QueueEmpty: - break - for conn in pooled_connections: - self._pool.put_nowait(conn) - - all_conns = chain(pooled_connections, self._in_use_connections) - - for connection in all_conns: - if connection is not None: - connection.disconnect() + connection = self.connection_class(**self.connection_kwargs) + await self._task_group.start(self.wrap_connection, connection) + self._used_connections.add(connection) + try: + yield connection + finally: + self._capacity.release() + # if we're here there wasn't an error + if connection in self._used_connections: + self._used_connections.remove(connection) + self._free_connections.append(connection) diff --git a/coredis/pool/cluster.py b/coredis/pool/cluster.py index 94e9c1905..e40354913 100644 --- a/coredis/pool/cluster.py +++ b/coredis/pool/cluster.py @@ -1,17 +1,18 @@ from __future__ import annotations -import asyncio import os import random import threading -import time import warnings -from typing import Any, cast +from contextlib import asynccontextmanager +from typing import Any, AsyncGenerator, cast -import async_timeout +from anyio import Lock, fail_after +from typing_extensions import Self +from coredis._concurrency import Queue, QueueEmpty, QueueFull from coredis._utils import b, hash_slot -from coredis.connection import ClusterConnection, Connection +from coredis.connection import BaseConnection, ClusterConnection, Connection from coredis.exceptions import ConnectionError, RedisClusterException from coredis.globals import READONLY_COMMANDS from coredis.pool.basic import ConnectionPool @@ -40,21 +41,18 @@ class ClusterConnectionPool(ConnectionPool): "reinitialize_steps": int, "skip_full_coverage_check": bool, "read_from_replicas": bool, - "blocking": bool, } nodes: NodeManager connection_class: type[ClusterConnection] _created_connections_per_node: dict[str, int] - _cluster_available_connections: dict[str, asyncio.Queue[Connection | None]] - _cluster_in_use_connections: dict[str, set[Connection]] + _cluster_available_connections: dict[str, Queue[Connection]] def __init__( self, startup_nodes: Iterable[Node] | None = None, connection_class: type[ClusterConnection] = ClusterConnection, - queue_class: type[asyncio.Queue[Connection | None]] = asyncio.LifoQueue, max_connections: int | None = None, max_connections_per_node: bool = False, reinitialize_steps: int | None = None, @@ -62,11 +60,8 @@ def __init__( nodemanager_follow_cluster: bool = True, readonly: bool = False, read_from_replicas: bool = False, - max_idle_time: int = 0, - idle_check_interval: int = 1, - blocking: bool = False, timeout: int = 20, - **connection_kwargs: Any | None, + **connection_kwargs: Any, ): """ @@ -82,13 +77,9 @@ def __init__( :param max_connections: Maximum number of connections to allow concurrently from this client. If the value is ``None`` it will default to 32. :param max_connections_per_node: Whether to use the value of :paramref:`max_connections` - on a per node basis or cluster wide. If ``False`` and :paramref:`blocking` is ``True`` - the per-node connection pools will have a maximum size of :paramref:`max_connections` - divided by the number of nodes in the cluster. - :param blocking: If ``True`` the client will block at most :paramref:`timeout` seconds - if :paramref:`max_connections` is reachd when trying to obtain a connection - :param timeout: Number of seconds to block if :paramref:`block` is ``True`` when trying to - obtain a connection. + on a per node basis or cluster wide. If ``False`` the per-node connection pools will have + a maximum size of :paramref:`max_connections` divided by the number of nodes in the cluster. + :param timeout: Number of seconds to block when trying to obtain a connection. :param skip_full_coverage_check: Skips the check of cluster-require-full-coverage config, useful for clusters without the :rediscommand:`CONFIG` command (For example with AWS Elasticache) @@ -100,24 +91,18 @@ def __init__( """ super().__init__( connection_class=connection_class, - max_connections=max_connections, - max_idle_time=max_idle_time, - idle_check_interval=idle_check_interval, **connection_kwargs, ) - self.queue_class = queue_class - # Special case to make from_url method compliant with cluster setting. - # from_url method will send in the ip and port through a different variable then the - # regular startup_nodes variable. + self.initialized = False if startup_nodes is None: host = connection_kwargs.pop("host", None) port = connection_kwargs.pop("port", None) + if host and port: startup_nodes = [Node(host=str(host), port=int(port))] - self.blocking = blocking - self.blocking_timeout = timeout - self.max_connections = max_connections or 2**31 + self.timeout = timeout + self.max_connections = max_connections or 64 self.max_connections_per_node = max_connections_per_node self.nodes = NodeManager( startup_nodes, @@ -125,18 +110,16 @@ def __init__( skip_full_coverage_check=skip_full_coverage_check, max_connections=self.max_connections, nodemanager_follow_cluster=nodemanager_follow_cluster, - **connection_kwargs, # type: ignore + **connection_kwargs, ) self.connection_kwargs = connection_kwargs self.connection_kwargs["read_from_replicas"] = read_from_replicas self.read_from_replicas = read_from_replicas or readonly - self.max_idle_time = max_idle_time - self.idle_check_interval = idle_check_interval self.reset() if "stream_timeout" not in self.connection_kwargs: self.connection_kwargs["stream_timeout"] = None - self._init_lock = asyncio.Lock() + self._init_lock = Lock() def __repr__(self) -> str: """ @@ -151,57 +134,42 @@ def __repr__(self) -> str: ), ) + @asynccontextmanager + async def __asynccontextmanager__(self) -> AsyncGenerator[Self]: + async with super().__asynccontextmanager__(): + await self.initialize() + try: + yield self + finally: + self.reset() + async def initialize(self) -> None: if not self.initialized: async with self._init_lock: - if not self.initialized: - await self.nodes.initialize() - if not self.max_connections_per_node and self.max_connections < len( - self.nodes.nodes - ): - warnings.warn( - f"The value of max_connections={self.max_connections} " - "should be atleast equal to the number of nodes " - f"({len(self.nodes.nodes)}) in the cluster and has been increased by " - f"{len(self.nodes.nodes) - self.max_connections} connections." - ) - self.max_connections = len(self.nodes.nodes) - await super().initialize() - - async def disconnect_on_idle_time_exceeded(self, connection: Connection) -> None: - assert isinstance(connection, ClusterConnection) - while True: - if ( - time.time() - connection.last_active_at > self.max_idle_time - and not connection.requests_pending - ): - connection.disconnect() - node = connection.node - if node.name in self._created_connections_per_node: - self._created_connections_per_node[node.name] -= 1 - break - await asyncio.sleep(self.idle_check_interval) + if self.initialized: + return + await self.nodes.initialize() + + if not self.max_connections_per_node and self.max_connections < len( + self.nodes.nodes + ): + warnings.warn( + f"The value of max_connections={self.max_connections} " + "should be atleast equal to the number of nodes " + f"({len(self.nodes.nodes)}) in the cluster and has been increased by " + f"{len(self.nodes.nodes) - self.max_connections} connections." + ) + self.max_connections = len(self.nodes.nodes) + self.initialized = True def reset(self) -> None: """Resets the connection pool back to a clean state""" self.pid = os.getpid() self._created_connections_per_node = {} self._cluster_available_connections = {} - self._cluster_in_use_connections = {} self._check_lock = threading.Lock() self.initialized = False - def checkpid(self) -> None: # noqa - if self.pid != os.getpid(): - with self._check_lock: - if self.pid == os.getpid(): - # another thread already did the work while we waited - # on the lockself. - - return - self.disconnect() - self.reset() - async def get_connection( self, command_name: bytes | None = None, @@ -221,30 +189,29 @@ async def get_connection( return await self.get_random_connection() slot = hash_slot(b(routing_key)) + if node_type == "replica": node = self.get_replica_node_by_slot(slot) else: node = self.get_primary_node_by_slot(slot) - self.checkpid() try: connection = self.__node_pool(node.name).get_nowait() - except asyncio.QueueEmpty: + except QueueEmpty: connection = None - if not connection: - connection = self._make_node_connection(node) + + if not connection or not connection.is_connected: + connection = await self._make_node_connection(node) else: if connection.is_connected and connection.needs_handshake: await connection.perform_handshake() - if acquire: - self._cluster_in_use_connections.setdefault(node.name, set()) - self._cluster_in_use_connections[node.name].add(connection) - else: + if not acquire: self.__node_pool(node.name).put_nowait(connection) + return connection - def _make_node_connection(self, node: ManagedNode) -> Connection: + async def _make_node_connection(self, node: ManagedNode) -> Connection: """Creates a new connection to a node""" if self.count_all_num_connections(node) >= self.max_connections: @@ -260,97 +227,41 @@ def _make_node_connection(self, node: ManagedNode) -> Connection: connection = self.connection_class( host=node.host, port=node.port, - **self.connection_kwargs, # type: ignore + **self.connection_kwargs, ) - - # Must store node in the connection to make it eaiser to track + await self._task_group.start(connection.run) + # Must store node in the connection to make it easier to track connection.node = node - if self.max_idle_time > self.idle_check_interval > 0: - # do not await the future - asyncio.ensure_future(self.disconnect_on_idle_time_exceeded(connection)) - return connection - def __node_pool(self, node: str) -> asyncio.Queue[Connection | None]: + def __node_pool(self, node: str) -> Queue[Connection]: if not self._cluster_available_connections.get(node): self._cluster_available_connections[node] = self.__default_node_queue() + return self._cluster_available_connections[node] def __default_node_queue( self, - ) -> asyncio.Queue[Connection | None]: + ) -> Queue[Connection]: q_size = max( 1, - int( - self.max_connections - if self.max_connections_per_node - else self.max_connections / len(self.nodes.nodes) - ), + self.max_connections + if self.max_connections_per_node + else self.max_connections // len(self.nodes.nodes), ) - q: asyncio.Queue[Connection | None] = self.queue_class(q_size) - - # If the queue is non-blocking, we don't need to pre-populate it - if not self.blocking: - return q - - if q_size > 2**16: # noqa - raise RuntimeError( - f"Requested unsupported value of max_connections: {q_size} in blocking mode" - ) + return Queue[Connection](q_size) - while True: - try: - q.put_nowait(None) - except asyncio.QueueFull: - break - return q - - def release(self, connection: Connection) -> None: + def release(self, connection: BaseConnection) -> None: """Releases the connection back to the pool""" assert isinstance(connection, ClusterConnection) - self.checkpid() - if connection.pid == self.pid: - # Remove the current connection from _in_use_connection and add it back to the available - # pool. There is cases where the connection is to be removed but it will not exist and - # there must be a safe way to remove - i_c = self._cluster_in_use_connections.get(connection.node.name, set()) - - if connection in i_c: - i_c.remove(connection) - else: - pass try: self.__node_pool(connection.node.name).put_nowait(connection) - except asyncio.QueueFull: - connection.disconnect() - # reduce node connection count in case of too many connection error raised - if connection.node.name in self._created_connections_per_node: - self._created_connections_per_node[connection.node.name] -= 1 - - def disconnect(self) -> None: - """Closes all connections in the pool""" - for node_connections in self._cluster_in_use_connections.values(): - for connection in node_connections: - connection.disconnect() - for node, available_connections in self._cluster_available_connections.items(): - removed = 0 - while True: - try: - _connection = available_connections.get_nowait() - if _connection: - _connection.disconnect() - if node in self._created_connections_per_node: - self._created_connections_per_node[node] -= 1 - removed += 1 - except asyncio.QueueEmpty: - break - # Refill queue with empty slots - for _ in range(removed): - available_connections.put_nowait(None) + except QueueFull: + pass def count_all_num_connections(self, node: ManagedNode) -> int: if self.max_connections_per_node: @@ -360,16 +271,15 @@ def count_all_num_connections(self, node: ManagedNode) -> int: async def get_random_connection(self, primary: bool = False) -> ClusterConnection: """Opens new connection to random redis server in the cluster""" + for node in self.nodes.random_startup_node_iter(primary): connection = await self.get_connection_by_node(node) + if connection: return connection raise RedisClusterException("Cant reach a single startup node.") async def get_connection_by_key(self, key: StringT) -> ClusterConnection: - if not key: - raise RedisClusterException("No way to dispatch this command to Redis Cluster.") - return await self.get_connection_by_slot(hash_slot(b(key))) async def get_connection_by_slot(self, slot: int) -> ClusterConnection: @@ -377,8 +287,6 @@ async def get_connection_by_slot(self, slot: int) -> ClusterConnection: Determines what server a specific slot belongs to and return a redis object that is connected """ - self.checkpid() - try: return await self.get_connection_by_node(self.get_node_by_slot(slot)) except KeyError: @@ -386,24 +294,12 @@ async def get_connection_by_slot(self, slot: int) -> ClusterConnection: async def get_connection_by_node(self, node: ManagedNode) -> ClusterConnection: """Gets a connection by node""" - self.checkpid() + with fail_after(self.timeout): + connection = await self.__node_pool(node.name).get() - if not self.blocking: - try: - connection = self.__node_pool(node.name).get_nowait() - except asyncio.QueueEmpty: - connection = None - else: - try: - async with async_timeout.timeout(self.blocking_timeout): - connection = await self.__node_pool(node.name).get() - except asyncio.TimeoutError: - raise ConnectionError("No connection available.") - - if not connection: - connection = self._make_node_connection(node) + if not connection or not connection.is_connected: + connection = await self._make_node_connection(node) - self._cluster_in_use_connections.setdefault(node.name, set()).add(connection) return cast(ClusterConnection, connection) def get_primary_node_by_slot(self, slot: int) -> ManagedNode: @@ -411,6 +307,7 @@ def get_primary_node_by_slot(self, slot: int) -> ManagedNode: def get_primary_node_by_slots(self, slots: list[int]) -> ManagedNode: nodes = {self.nodes.slots[slot][0].node_id for slot in slots} + if len(nodes) == 1: return self.nodes.slots[slots[0]][0] else: @@ -423,8 +320,10 @@ def get_replica_node_by_slots( self, slots: list[int], replica_only: bool = False ) -> ManagedNode: nodes = {self.nodes.slots[slot][0].node_id for slot in slots} + if len(nodes) == 1: slot = slots[0] + if replica_only: return random.choice( [node for node in self.nodes.slots[slot] if node.server_type != "primary"] @@ -437,81 +336,11 @@ def get_replica_node_by_slots( def get_node_by_slot(self, slot: int, command: bytes | None = None) -> ManagedNode: if self.read_from_replicas and command in READONLY_COMMANDS: return self.get_replica_node_by_slot(slot) + return self.get_primary_node_by_slot(slot) def get_node_by_slots(self, slots: list[int], command: bytes | None = None) -> ManagedNode: if self.read_from_replicas and command in READONLY_COMMANDS: return self.get_replica_node_by_slots(slots) - return self.get_primary_node_by_slots(slots) - - -class BlockingClusterConnectionPool(ClusterConnectionPool): - """ - .. versionadded:: 4.3.0 - Blocking connection pool for :class:`~coredis.RedisCluster` client - - .. note:: This is just a convenience subclass of :class:`~coredis.pool.ClusterConnectionPool` - that sets :paramref:`~coredis.pool.ClusterConnectionPool.blocking` to ``True`` - """ - - def __init__( - self, - startup_nodes: Iterable[Node] | None = None, - connection_class: type[ClusterConnection] = ClusterConnection, - queue_class: type[asyncio.Queue[Connection | None]] = asyncio.LifoQueue, - max_connections: int | None = None, - max_connections_per_node: bool = False, - reinitialize_steps: int | None = None, - skip_full_coverage_check: bool = False, - nodemanager_follow_cluster: bool = True, - readonly: bool = False, - read_from_replicas: bool = False, - max_idle_time: int = 0, - idle_check_interval: int = 1, - timeout: int = 20, - **connection_kwargs: Any | None, - ): - """ - - Changes - - .. versionchanged:: 4.4.0 - - - :paramref:`nodemanager_follow_cluster` now defaults to ``True`` - - - .. deprecated:: 4.4.0 - - - :paramref:`readonly` renamed to :paramref:`read_from_replicas` - - :param max_connections: Maximum number of connections to allow concurrently from this - client. - :param max_connections_per_node: Whether to use the value of :paramref:`max_connections` - on a per node basis or cluster wide. If ``False`` the per-node connection pools will have - a maximum size of :paramref:`max_connections` divided by the number of nodes in the - cluster. - :param timeout: Number of seconds to block when trying to obtain a connection. - :param skip_full_coverage_check: - Skips the check of cluster-require-full-coverage config, useful for clusters - without the CONFIG command (like aws) - :param nodemanager_follow_cluster: - The node manager will during initialization try the last set of nodes that - it was operating on. This will allow the client to drift along side the cluster - if the cluster nodes move around alot. - """ - super().__init__( - startup_nodes=startup_nodes, - connection_class=connection_class, - queue_class=queue_class, - max_connections=max_connections, - max_connections_per_node=max_connections_per_node, - reinitialize_steps=reinitialize_steps, - skip_full_coverage_check=skip_full_coverage_check, - nodemanager_follow_cluster=nodemanager_follow_cluster, - readonly=readonly, - read_from_replicas=read_from_replicas, - max_idle_time=max_idle_time, - idle_check_interval=idle_check_interval, - timeout=timeout, - blocking=True, - **connection_kwargs, - ) + return self.get_primary_node_by_slots(slots) diff --git a/coredis/pool/nodemanager.py b/coredis/pool/nodemanager.py index a6e3410d1..0b048040e 100644 --- a/coredis/pool/nodemanager.py +++ b/coredis/pool/nodemanager.py @@ -56,7 +56,7 @@ def __init__( skip_full_coverage_check: bool = False, nodemanager_follow_cluster: bool = True, decode_responses: bool = False, - **connection_kwargs: Any | None, + **connection_kwargs: Any, ) -> None: """ :skip_full_coverage_check: @@ -151,10 +151,9 @@ def get_redis_link(self, host: str, port: int) -> Redis[Any]: "ssl_context", "parser_class", "loop", - "protocol_version", ) connection_kwargs = {k: v for k, v in self.connection_kwargs.items() if k in allowed_keys} - return Redis(host=host, port=port, **connection_kwargs) # type: ignore + return Redis(host=host, port=port, **connection_kwargs) async def initialize(self) -> None: """ @@ -185,9 +184,9 @@ async def initialize(self) -> None: cluster_slots = {} try: if node: - r = self.get_redis_link(host=node.host, port=node.port) - cluster_slots = await r.cluster_slots() - self.startup_nodes_reachable = True + async with self.get_redis_link(host=node.host, port=node.port) as r: + cluster_slots = await r.cluster_slots() + self.startup_nodes_reachable = True except RedisError as err: startup_node_errors.setdefault(str(err), []).append(node.name) continue @@ -288,9 +287,9 @@ async def increment_reinitialize_counter(self, ct: int = 1) -> None: async def node_require_full_coverage(self, node: ManagedNode) -> bool: try: - r_node = self.get_redis_link(host=node.host, port=node.port) - node_config = await r_node.config_get(["cluster-require-full-coverage"]) - return "yes" in node_config.values() + async with self.get_redis_link(host=node.host, port=node.port) as r_node: + node_config = await r_node.config_get(["cluster-require-full-coverage"]) + return "yes" in node_config.values() except ResponseError as err: warnings.warn( "Unable to determine whether the cluster requires full coverage " @@ -335,6 +334,3 @@ def populate_startup_nodes(self) -> None: self.startup_nodes.clear() for n in self.nodes.values(): self.startup_nodes.append(n) - - async def reset(self) -> None: - await self.initialize() diff --git a/coredis/recipes/__init__.py b/coredis/recipes/__init__.py index e69de29bb..41ba78d4f 100644 --- a/coredis/recipes/__init__.py +++ b/coredis/recipes/__init__.py @@ -0,0 +1,6 @@ +from __future__ import annotations + +from .credentials import ElastiCacheIAMProvider +from .lock import Lock + +__all__ = ["ElastiCacheIAMProvider", "Lock"] diff --git a/coredis/recipes/credentials/iam_provider.py b/coredis/recipes/credentials.py similarity index 100% rename from coredis/recipes/credentials/iam_provider.py rename to coredis/recipes/credentials.py diff --git a/coredis/recipes/credentials/__init__.py b/coredis/recipes/credentials/__init__.py deleted file mode 100644 index 97fef59f2..000000000 --- a/coredis/recipes/credentials/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from __future__ import annotations - -from .iam_provider import ElastiCacheIAMProvider - -__all__ = ["ElastiCacheIAMProvider"] diff --git a/coredis/recipes/locks/lua_lock.py b/coredis/recipes/lock.py similarity index 90% rename from coredis/recipes/locks/lua_lock.py rename to coredis/recipes/lock.py index 6ede0e93a..3cb8706f5 100644 --- a/coredis/recipes/locks/lua_lock.py +++ b/coredis/recipes/lock.py @@ -1,15 +1,16 @@ from __future__ import annotations -import asyncio import contextvars -import importlib.resources import math import time import uuid import warnings +from pathlib import Path from types import TracebackType from typing import cast +from anyio import sleep + from coredis.client import Redis, RedisCluster from coredis.commands import Script from coredis.exceptions import ( @@ -22,13 +23,11 @@ from coredis.tokens import PureToken from coredis.typing import AnyStr, Generic, KeyT, StringT -with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - EXTEND_SCRIPT = Script(script=importlib.resources.read_text(__package__, "extend.lua")) - RELEASE_SCRIPT = Script(script=importlib.resources.read_text(__package__, "release.lua")) +EXTEND_SCRIPT = Script(script=(Path(__file__).parent / "lua/extend.lua").read_text()) +RELEASE_SCRIPT = Script(script=(Path(__file__).parent / "lua/release.lua").read_text()) -class LuaLock(Generic[AnyStr]): +class Lock(Generic[AnyStr]): """ A shared, distributed Lock using LUA scripts. @@ -40,31 +39,22 @@ class LuaLock(Generic[AnyStr]): import asyncio import coredis from coredis.exceptions import LockError - from coredis.recipes.locks import LuaLock - async def test(): - client = coredis.Redis() - async with LuaLock(client, "mylock", timeout=1.0): + client = coredis.Redis() + async with client: + async with client.lock("mylock", timeout=1.0): # do stuff await asyncio.sleep(0.5) # lock is implictly released when the context manager exits try: - async with LuaLock(client, "mylock", timeout=1.0): + async with client.lock("mylock", timeout=1.0): # do stuff that takes too long await asyncio.sleep(1) # lock will raise upon exiting the context manager except LockError as err: # roll back stuff print(f"Expected error: {err}") - lock = LuaLock(client, "mylock", timeout=1.0) - await lock.acquire() - # do stuff - await asyncio.sleep(0.5) - # do more stuff - await lock.extend(1.0) - await lock.release() - - asyncio.run(test()) + """ @classmethod @@ -129,7 +119,7 @@ def __init__( async def __aenter__( self, - ) -> LuaLock[AnyStr]: + ) -> Lock[AnyStr]: if await self.acquire(): return self raise LockAcquisitionError("Could not acquire lock") @@ -173,7 +163,7 @@ async def acquire( if stop_trying_at is not None and time.time() > stop_trying_at: return False - await asyncio.sleep(self.sleep) + await sleep(self.sleep) async def release(self) -> None: """ diff --git a/coredis/recipes/locks/__init__.py b/coredis/recipes/locks/__init__.py deleted file mode 100644 index f64ed2ffd..000000000 --- a/coredis/recipes/locks/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from __future__ import annotations - -from .lua_lock import LuaLock - -__all__ = ["LuaLock"] diff --git a/coredis/recipes/locks/extend.lua b/coredis/recipes/lua/extend.lua similarity index 99% rename from coredis/recipes/locks/extend.lua rename to coredis/recipes/lua/extend.lua index dac34baa2..9bd890db4 100644 --- a/coredis/recipes/locks/extend.lua +++ b/coredis/recipes/lua/extend.lua @@ -2,6 +2,7 @@ -- ARGS[1] - token -- ARGS[2] - additional milliseconds -- return 1 if the locks time was extended, otherwise 0 + local token = redis.call('get', KEYS[1]) if not token or token ~= ARGV[1] then return 0 diff --git a/coredis/recipes/locks/release.lua b/coredis/recipes/lua/release.lua similarity index 100% rename from coredis/recipes/locks/release.lua rename to coredis/recipes/lua/release.lua diff --git a/coredis/response/_callbacks/__init__.py b/coredis/response/_callbacks/__init__.py index 7bdcf4b76..3a1f2be21 100644 --- a/coredis/response/_callbacks/__init__.py +++ b/coredis/response/_callbacks/__init__.py @@ -18,7 +18,6 @@ Generic, Hashable, Iterable, - Literal, Mapping, ParamSpec, Protocol, @@ -39,7 +38,6 @@ CR_co = TypeVar("CR_co", covariant=True) CK_co = TypeVar("CK_co", covariant=True) -RESP = TypeVar("RESP") RESP3 = TypeVar("RESP3") if TYPE_CHECKING: @@ -52,7 +50,6 @@ def __new__( ) -> ResponseCallbackMeta: kls = super().__new__(cls, name, bases, namespace) setattr(kls, "transform", add_runtime_checks(getattr(kls, "transform"))) - setattr(kls, "transform_3", add_runtime_checks(getattr(kls, "transform_3"))) return kls @@ -62,37 +59,27 @@ def __new__( ) -> ClusterCallbackMeta: kls = super().__new__(cls, name, bases, namespace) setattr(kls, "combine", add_runtime_checks(getattr(kls, "combine"))) - setattr(kls, "combine_3", add_runtime_checks(getattr(kls, "combine_3"))) return kls -class ResponseCallback(ABC, Generic[RESP, RESP3, R], metaclass=ResponseCallbackMeta): - version: Literal[2, 3] - +class ResponseCallback(ABC, Generic[RESP3, R], metaclass=ResponseCallbackMeta): def __init__(self, **options: Any) -> None: self.options = options def __call__( self, - response: RESP | RESP3 | ResponseError, - version: Literal[2, 3] = 2, + response: RESP3 | ResponseError, ) -> R: - self.version = version if isinstance(response, ResponseError): exc_to_response = self.handle_exception(response) if exc_to_response: return exc_to_response - if version == 3: - return self.transform_3(cast(RESP3, response)) - return self.transform(cast(RESP, response)) + return self.transform(response) @abstractmethod - def transform(self, response: RESP) -> R: + def transform(self, response: RESP3) -> R: pass - def transform_3(self, response: RESP3) -> R: - return self.transform(cast(RESP, response)) - def handle_exception(self, exc: BaseException) -> R | None: return exc # type: ignore @@ -102,7 +89,7 @@ class AsyncPreProcessingCallback(Protocol): async def pre_process(self, client: Client[Any], response: ResponseType) -> None: ... -class NoopCallback(ResponseCallback[R, R, R]): +class NoopCallback(ResponseCallback[R, R]): def transform(self, response: R) -> R: return response @@ -111,10 +98,7 @@ class ClusterMultiNodeCallback(ABC, Generic[R], metaclass=ClusterCallbackMeta): def __call__( self, responses: Mapping[str, R | ResponseError], - version: int = 2, ) -> R: - if version == 3: - return self.combine_3(responses) return self.combine(responses) @property @@ -125,9 +109,6 @@ def response_policy(self) -> str: ... def combine(self, responses: Mapping[str, R], **options: Any) -> R: pass - def combine_3(self, responses: Mapping[str, R], **options: Any) -> R: - return self.combine(responses, **options) - @classmethod def raise_any(cls, values: Iterable[R]) -> None: for value in values: @@ -250,7 +231,7 @@ def response_policy(self) -> str: return "the concatenations of the results" -class SimpleStringCallback(ResponseCallback[StringT | None, StringT | None, bool]): +class SimpleStringCallback(ResponseCallback[StringT | None, bool]): def __init__( self, raise_on_error: type[Exception] | None = None, @@ -276,14 +257,14 @@ def transform(self, response: StringT | None, **options: Any) -> bool: return success -class IntCallback(ResponseCallback[int, int, int]): +class IntCallback(ResponseCallback[int, int]): def transform(self, response: ResponsePrimitive, **options: Any) -> int: if isinstance(response, int): return response raise ValueError(f"Unable to map {response!r} to int") -class AnyStrCallback(ResponseCallback[StringT, StringT, AnyStr]): +class AnyStrCallback(ResponseCallback[StringT, AnyStr]): def transform(self, response: StringT, **options: Any) -> AnyStr: if isinstance(response, (bytes, str)): return cast(AnyStr, response) @@ -291,7 +272,7 @@ def transform(self, response: StringT, **options: Any) -> AnyStr: raise ValueError(f"Unable to map {response!r} to AnyStr") -class FloatCallback(ResponseCallback[StringT | int | float, StringT | int | float, float]): +class FloatCallback(ResponseCallback[StringT | int | float, float]): def transform(self, response: ResponseType, **options: Any) -> float: if isinstance(response, float): return response @@ -301,14 +282,14 @@ def transform(self, response: ResponseType, **options: Any) -> float: raise ValueError(f"Unable to map {response} to float") -class BoolCallback(ResponseCallback[int | bool, int | bool, bool]): +class BoolCallback(ResponseCallback[int | bool, bool]): def transform(self, response: ResponseType, **options: Any) -> bool: if isinstance(response, bool): return response return bool(response) -class SimpleStringOrIntCallback(ResponseCallback[RedisValueT, RedisValueT, bool | int]): +class SimpleStringOrIntCallback(ResponseCallback[RedisValueT, bool | int]): def transform(self, response: RedisValueT, **options: Any) -> bool | int: if isinstance(response, (int, bool)): return response @@ -317,7 +298,7 @@ def transform(self, response: RedisValueT, **options: Any) -> bool | int: raise ValueError(f"Unable to map {response!r} to bool") -class TupleCallback(ResponseCallback[list[ResponseType], list[ResponseType], tuple[CR_co, ...]]): +class TupleCallback(ResponseCallback[list[ResponseType], tuple[CR_co, ...]]): def transform(self, response: ResponseType, **options: Any) -> tuple[CR_co, ...]: if isinstance(response, list): return cast(tuple[CR_co, ...], tuple(response)) @@ -326,7 +307,6 @@ def transform(self, response: ResponseType, **options: Any) -> tuple[CR_co, ...] class ItemOrTupleCallback( ResponseCallback[ - list[ResponseType] | ResponsePrimitive, list[ResponseType] | ResponsePrimitive, tuple[CR_co, ...] | CR_co, ] @@ -339,19 +319,19 @@ def transform( return cast(CR_co, response) -class MixedTupleCallback(ResponseCallback[list[ResponseType], list[ResponseType], tuple[R, S]]): +class MixedTupleCallback(ResponseCallback[list[ResponseType], tuple[R, S]]): def transform(self, response: ResponseType, **options: Any) -> tuple[R, S]: if isinstance(response, list): return cast(tuple[R, S], tuple(response)) raise ValueError(f"Unable to map {response!r} to tuple") -class ListCallback(ResponseCallback[list[ResponseType], list[ResponseType], list[CR_co]]): +class ListCallback(ResponseCallback[list[ResponseType], list[CR_co]]): def transform(self, response: list[ResponseType], **options: Any) -> list[CR_co]: return cast(list[CR_co], response) -class DateTimeCallback(ResponseCallback[int | float, int | float, datetime.datetime]): +class DateTimeCallback(ResponseCallback[int | float, datetime.datetime]): def transform( self, response: int | float, @@ -364,7 +344,6 @@ def transform( class DictCallback( ResponseCallback[ - Sequence[ResponseType] | dict[ResponsePrimitive, ResponseType], Sequence[ResponseType] | dict[ResponsePrimitive, ResponseType], dict[CK_co, CR_co], ] @@ -384,7 +363,9 @@ def transform( response: Sequence[ResponseType] | dict[ResponsePrimitive, ResponseType], **options: Any, ) -> dict[CK_co, CR_co]: - if isinstance(response, list): + if isinstance(response, dict): + return cast(dict[CK_co, CR_co], response) + elif isinstance(response, list): if self.flat: if self.recursive: return cast(dict[CK_co, CR_co], self.recursive_transformer(response)) @@ -395,15 +376,6 @@ def transform( return dict(r for r in response) raise ValueError(f"Unable to map {response!r} to mapping") - def transform_3( - self, - response: Sequence[ResponseType] | dict[ResponsePrimitive, ResponseType], - **options: Any, - ) -> dict[CK_co, CR_co]: - if isinstance(response, dict): - return cast(dict[CK_co, CR_co], response) - return self.transform(response, **options) - def recursive_transformer( self, item: Sequence[ResponseType] | dict[ResponsePrimitive, ResponseType] ) -> dict[CK_co, CR_co] | list[CK_co] | list[CR_co] | tuple[CK_co, ...] | tuple[CR_co, ...]: @@ -428,8 +400,7 @@ def recursive_transformer( class SetCallback( ResponseCallback[ - list[ResponsePrimitive], - set[ResponsePrimitive], + list[ResponsePrimitive] | set[ResponsePrimitive], set[CR_co], ] ): @@ -440,24 +411,15 @@ def transform( ) -> set[CR_co]: if isinstance(response, list): return cast(set[CR_co], set(response)) - raise ValueError(f"Unable to map {response} to set") - - def transform_3( - self, - response: list[ResponsePrimitive] | set[ResponsePrimitive], - **options: Any, - ) -> set[CR_co]: - if isinstance(response, set): + elif isinstance(response, set): return cast(set[CR_co], response) - else: - return self.transform(response) + raise ValueError(f"Unable to map {response} to set") class OneOrManyCallback( ResponseCallback[ CR_co | list[CR_co | None] | None, CR_co | list[CR_co | None] | None, - CR_co | list[CR_co | None] | None, ] ): def transform( @@ -468,14 +430,14 @@ def transform( return response -class BoolsCallback(ResponseCallback[ResponseType, ResponseType, tuple[bool, ...]]): +class BoolsCallback(ResponseCallback[ResponseType, tuple[bool, ...]]): def transform(self, response: ResponseType, **options: Any) -> tuple[bool, ...]: if isinstance(response, list): return tuple(BoolCallback()(r) for r in response) return () -class FloatsCallback(ResponseCallback[ResponseType, ResponseType, tuple[float, ...]]): +class FloatsCallback(ResponseCallback[ResponseType, tuple[float, ...]]): def transform(self, response: ResponseType, **options: Any) -> tuple[float, ...]: if isinstance(response, list): return tuple(FloatCallback()(r) for r in response) @@ -484,7 +446,6 @@ def transform(self, response: ResponseType, **options: Any) -> tuple[float, ...] class OptionalFloatCallback( ResponseCallback[ - StringT | int | float | None, StringT | int | float | None, float | None, ] @@ -499,7 +460,7 @@ def transform( return FloatCallback()(response) -class OptionalIntCallback(ResponseCallback[int | None, int | None, int | None]): +class OptionalIntCallback(ResponseCallback[int | None, int | None]): def transform(self, response: int | None, **options: Any) -> int | None: if response is None: return None @@ -512,7 +473,6 @@ class OptionalAnyStrCallback( ResponseCallback[ StringT | None, AnyStr | None, - AnyStr | None, ] ): def transform(self, response: StringT | None, **options: Any) -> AnyStr | None: @@ -523,14 +483,12 @@ def transform(self, response: StringT | None, **options: Any) -> AnyStr | None: raise ValueError(f"Unable to map {response} to AnyStr") -class OptionalListCallback( - ResponseCallback[list[ResponseType], list[ResponseType], list[CR_co] | None] -): +class OptionalListCallback(ResponseCallback[list[ResponseType], list[CR_co] | None]): def transform(self, response: ResponseType, **options: Any) -> list[CR_co] | None: return cast(list[CR_co], response) -class FirstValueCallback(ResponseCallback[list[CR_co], list[CR_co], CR_co]): +class FirstValueCallback(ResponseCallback[list[CR_co], CR_co]): def transform(self, response: list[CR_co], **options: Any) -> CR_co: if response: return response[0] diff --git a/coredis/response/_callbacks/acl.py b/coredis/response/_callbacks/acl.py index 3e7714e54..33204562d 100644 --- a/coredis/response/_callbacks/acl.py +++ b/coredis/response/_callbacks/acl.py @@ -1,6 +1,6 @@ from __future__ import annotations -from coredis.response._callbacks import DictCallback, ResponseCallback +from coredis.response._callbacks import ResponseCallback from coredis.typing import ( AnyStr, ResponsePrimitive, @@ -11,21 +11,10 @@ class ACLLogCallback( ResponseCallback[ list[Sequence[ResponsePrimitive] | None], - list[dict[AnyStr, ResponsePrimitive] | None], tuple[dict[AnyStr, ResponsePrimitive] | None, ...], ] ): def transform( - self, - response: list[Sequence[ResponsePrimitive] | None], - ) -> tuple[dict[AnyStr, ResponsePrimitive] | None, ...]: - return tuple( - DictCallback[AnyStr, ResponsePrimitive]()(r, version=self.version) - for r in response - if r - ) - - def transform_3( self, response: list[dict[AnyStr, ResponsePrimitive] | None], ) -> tuple[dict[AnyStr, ResponsePrimitive] | None, ...]: diff --git a/coredis/response/_callbacks/cluster.py b/coredis/response/_callbacks/cluster.py index 21d4a8bf5..9d32cffbc 100644 --- a/coredis/response/_callbacks/cluster.py +++ b/coredis/response/_callbacks/cluster.py @@ -1,8 +1,7 @@ from __future__ import annotations -from coredis._utils import EncodingInsensitiveDict, nativestr +from coredis._utils import nativestr from coredis.response._callbacks import ResponseCallback -from coredis.response._utils import flat_pairs_to_dict from coredis.response.types import ClusterNode, ClusterNodeDetail from coredis.typing import ( AnyStr, @@ -13,27 +12,15 @@ ) -class ClusterLinksCallback( - ResponseCallback[ResponseType, ResponseType, list[dict[AnyStr, ResponsePrimitive]]] -): +class ClusterLinksCallback(ResponseCallback[ResponseType, list[dict[AnyStr, ResponsePrimitive]]]): def transform( self, response: ResponseType, - ) -> list[dict[AnyStr, ResponsePrimitive]]: - transformed: list[dict[AnyStr, ResponsePrimitive]] = [] - - for item in response: - transformed.append(flat_pairs_to_dict(item)) - return transformed - - def transform_3( - self, - response: ResponseType, ) -> list[dict[AnyStr, ResponsePrimitive]]: return response -class ClusterInfoCallback(ResponseCallback[ResponseType, ResponseType, dict[str, str]]): +class ClusterInfoCallback(ResponseCallback[ResponseType, dict[str, str]]): def transform( self, response: ResponseType, @@ -43,7 +30,7 @@ def transform( class ClusterSlotsCallback( - ResponseCallback[ResponseType, ResponseType, dict[tuple[int, int], tuple[ClusterNode, ...]]] + ResponseCallback[ResponseType, dict[tuple[int, int], tuple[ClusterNode, ...]]] ): def transform( self, @@ -68,7 +55,7 @@ def parse_node(self, node: list[int | str]) -> ClusterNode: ) -class ClusterNodesCallback(ResponseCallback[ResponseType, ResponseType, list[ClusterNodeDetail]]): +class ClusterNodesCallback(ResponseCallback[ResponseType, list[ClusterNodeDetail]]): def transform( self, response: ResponseType, @@ -155,7 +142,6 @@ def parse_slots(s: str) -> tuple[list[int], list[dict[str, RedisValueT]]]: class ClusterShardsCallback( ResponseCallback[ - ResponseType, ResponseType, list[dict[AnyStr, list[RedisValueT] | Mapping[AnyStr, RedisValueT]]], ] @@ -163,21 +149,5 @@ class ClusterShardsCallback( def transform( self, response: ResponseType, - ) -> list[dict[AnyStr, list[RedisValueT] | Mapping[AnyStr, RedisValueT]]]: - shard_mapping: list[dict[AnyStr, list[RedisValueT] | Mapping[AnyStr, RedisValueT]]] = [] - - for shard in response: - transformed = EncodingInsensitiveDict(flat_pairs_to_dict(shard)) - node_mapping: list[dict[AnyStr, RedisValueT]] = [] - for node in transformed["nodes"]: - node_mapping.append(flat_pairs_to_dict(node)) - - transformed["nodes"] = node_mapping - shard_mapping.append(transformed.__wrapped__) # type: ignore - return shard_mapping - - def transform_3( - self, - response: ResponseType, ) -> list[dict[AnyStr, list[RedisValueT] | Mapping[AnyStr, RedisValueT]]]: return response diff --git a/coredis/response/_callbacks/command.py b/coredis/response/_callbacks/command.py index 8abc1824b..0d913081e 100644 --- a/coredis/response/_callbacks/command.py +++ b/coredis/response/_callbacks/command.py @@ -1,8 +1,7 @@ from __future__ import annotations -from coredis._utils import EncodingInsensitiveDict, nativestr +from coredis._utils import nativestr from coredis.response._callbacks import ResponseCallback -from coredis.response._utils import flat_pairs_to_dict from coredis.response.types import Command from coredis.typing import ( AnyStr, @@ -11,7 +10,7 @@ ) -class CommandCallback(ResponseCallback[list[ResponseType], list[ResponseType], dict[str, Command]]): +class CommandCallback(ResponseCallback[list[ResponseType], dict[str, Command]]): def transform( self, response: list[ResponseType], @@ -49,9 +48,7 @@ def transform( return commands -class CommandKeyFlagCallback( - ResponseCallback[list[ResponseType], list[ResponseType], dict[AnyStr, set[AnyStr]]] -): +class CommandKeyFlagCallback(ResponseCallback[list[ResponseType], dict[AnyStr, set[AnyStr]]]): def transform( self, response: list[ResponseType], @@ -61,25 +58,11 @@ def transform( class CommandDocCallback( ResponseCallback[ - list[ResponseType], dict[ResponsePrimitive, ResponseType], dict[AnyStr, dict[AnyStr, ResponseType]], ] ): def transform( - self, - response: list[ResponseType], - ) -> dict[AnyStr, dict[AnyStr, ResponseType]]: - cmd_mapping = flat_pairs_to_dict(response) - for cmd, doc in cmd_mapping.items(): - cmd_mapping[cmd] = EncodingInsensitiveDict(flat_pairs_to_dict(doc)) - cmd_mapping[cmd]["arguments"] = [ - flat_pairs_to_dict(arg) for arg in cmd_mapping[cmd].get("arguments", []) - ] - cmd_mapping[cmd] = dict(cmd_mapping[cmd]) - return dict(cmd_mapping) - - def transform_3( self, response: dict[ResponsePrimitive, ResponseType], ) -> dict[AnyStr, dict[AnyStr, ResponseType]]: diff --git a/coredis/response/_callbacks/connection.py b/coredis/response/_callbacks/connection.py index 85083c97f..cd2915aa2 100644 --- a/coredis/response/_callbacks/connection.py +++ b/coredis/response/_callbacks/connection.py @@ -1,8 +1,6 @@ from __future__ import annotations -from coredis._utils import EncodingInsensitiveDict from coredis.response._callbacks import ResponseCallback -from coredis.response._utils import flat_pairs_to_dict from coredis.typing import ( AnyStr, ResponseType, @@ -11,7 +9,6 @@ class ClientTrackingInfoCallback( ResponseCallback[ - ResponseType, ResponseType, dict[AnyStr, AnyStr | set[AnyStr] | list[AnyStr]], ] @@ -19,13 +16,5 @@ class ClientTrackingInfoCallback( def transform( self, response: ResponseType, - ) -> dict[AnyStr, AnyStr | set[AnyStr] | list[AnyStr]]: - response = EncodingInsensitiveDict(flat_pairs_to_dict(response)) - response["flags"] = set(response["flags"]) - return dict(response) - - def transform_3( - self, - response: ResponseType, ) -> dict[AnyStr, AnyStr | set[AnyStr] | list[AnyStr]]: return response diff --git a/coredis/response/_callbacks/geo.py b/coredis/response/_callbacks/geo.py index 5babeae62..0eddc6a8e 100644 --- a/coredis/response/_callbacks/geo.py +++ b/coredis/response/_callbacks/geo.py @@ -10,7 +10,6 @@ class GeoSearchCallback( Generic[AnyStr], ResponseCallback[ - ResponseType, ResponseType, tuple[AnyStr | GeoSearchResult, ...], ], @@ -44,9 +43,7 @@ def transform( return tuple(results) -class GeoCoordinatessCallback( - ResponseCallback[ResponseType, ResponseType, tuple[GeoCoordinates | None, ...]] -): +class GeoCoordinatessCallback(ResponseCallback[ResponseType, tuple[GeoCoordinates | None, ...]]): def transform( self, response: ResponseType, **options: Any ) -> tuple[GeoCoordinates | None, ...]: diff --git a/coredis/response/_callbacks/hash.py b/coredis/response/_callbacks/hash.py index d13814bf1..71b872e2a 100644 --- a/coredis/response/_callbacks/hash.py +++ b/coredis/response/_callbacks/hash.py @@ -14,7 +14,6 @@ class HScanCallback( ResponseCallback[ - list[ResponseType], list[ResponseType], tuple[int, dict[AnyStr, AnyStr] | tuple[AnyStr, ...]], ] @@ -36,27 +35,11 @@ def transform( class HRandFieldCallback( ResponseCallback[ - AnyStr | list[AnyStr] | None, AnyStr | list[AnyStr] | list[list[AnyStr]] | None, AnyStr | tuple[AnyStr, ...] | dict[AnyStr, AnyStr] | None, ] ): def transform( - self, - response: AnyStr | list[AnyStr] | None, - ) -> AnyStr | tuple[AnyStr, ...] | dict[AnyStr, AnyStr] | None: - if not response: - return None - if self.options.get("count"): - assert isinstance(response, list) - if self.options.get("withvalues"): - return flat_pairs_to_dict(response) - else: - return tuple(response) - assert isinstance(response, (str, bytes)) - return response - - def transform_3( self, response: AnyStr | list[AnyStr] | list[list[AnyStr]] | None, ) -> AnyStr | tuple[AnyStr, ...] | dict[AnyStr, AnyStr] | None: @@ -71,14 +54,8 @@ def transform_3( return response -class HGetAllCallback(ResponseCallback[list[AnyStr], dict[AnyStr, AnyStr], dict[AnyStr, AnyStr]]): +class HGetAllCallback(ResponseCallback[dict[AnyStr, AnyStr], dict[AnyStr, AnyStr]]): def transform( - self, - response: list[AnyStr], - ) -> dict[AnyStr, AnyStr]: - return flat_pairs_to_dict(response) if response else {} - - def transform_3( self, response: dict[AnyStr, AnyStr], ) -> dict[AnyStr, AnyStr]: diff --git a/coredis/response/_callbacks/keys.py b/coredis/response/_callbacks/keys.py index 094fcd1eb..5462c5b00 100644 --- a/coredis/response/_callbacks/keys.py +++ b/coredis/response/_callbacks/keys.py @@ -14,7 +14,6 @@ class SortCallback( ResponseCallback[ - int | list[AnyStr], int | list[AnyStr], int | tuple[AnyStr, ...], ] @@ -28,9 +27,7 @@ def transform( return response -class ScanCallback( - ResponseCallback[list[ResponseType], list[ResponseType], tuple[int, tuple[AnyStr, ...]]] -): +class ScanCallback(ResponseCallback[list[ResponseType], tuple[int, tuple[AnyStr, ...]]]): def guard(self, response: list[ResponseType]) -> TypeGuard[tuple[StringT, list[AnyStr]]]: return isinstance(response[0], (str, bytes)) and isinstance(response[1], list) diff --git a/coredis/response/_callbacks/module.py b/coredis/response/_callbacks/module.py index c67281b8d..699f41a00 100644 --- a/coredis/response/_callbacks/module.py +++ b/coredis/response/_callbacks/module.py @@ -1,32 +1,19 @@ from __future__ import annotations -from typing import cast - from coredis.response._callbacks import ResponseCallback -from coredis.response._utils import flat_pairs_to_dict from coredis.typing import ( AnyStr, ResponsePrimitive, - ResponseType, ) class ModuleInfoCallback( ResponseCallback[ - list[list[ResponseType]], list[dict[AnyStr, ResponsePrimitive]], tuple[dict[AnyStr, ResponsePrimitive], ...], ] ): def transform( - self, - response: list[list[ResponseType]], - ) -> tuple[dict[AnyStr, ResponsePrimitive], ...]: - return tuple( - cast(dict[AnyStr, ResponsePrimitive], flat_pairs_to_dict(mod)) for mod in response - ) - - def transform_3( self, response: list[dict[AnyStr, ResponsePrimitive]], ) -> tuple[dict[AnyStr, ResponsePrimitive], ...]: diff --git a/coredis/response/_callbacks/script.py b/coredis/response/_callbacks/script.py index 00cb9dea5..3d9a7cf85 100644 --- a/coredis/response/_callbacks/script.py +++ b/coredis/response/_callbacks/script.py @@ -16,7 +16,7 @@ class FunctionListCallback( - ResponseCallback[list[ResponseType], list[ResponseType], Mapping[AnyStr, LibraryDefinition]] + ResponseCallback[list[ResponseType], Mapping[AnyStr, LibraryDefinition]] ): def transform( self, @@ -48,7 +48,6 @@ def transform( class FunctionStatsCallback( ResponseCallback[ - list[ResponseType], dict[ AnyStr, AnyStr | dict[AnyStr, dict[AnyStr, ResponsePrimitive]] | None, @@ -60,22 +59,6 @@ class FunctionStatsCallback( ] ): def transform( - self, - response: list[ResponseType], - ) -> dict[AnyStr, AnyStr | dict[AnyStr, dict[AnyStr, ResponsePrimitive]] | None]: - transformed = flat_pairs_to_dict(response) - key = cast(AnyStr, b"engines" if b"engines" in transformed else "engines") - engines = flat_pairs_to_dict(cast(list[AnyStr], transformed.pop(key))) - engines_transformed = {} - for engine, stats in engines.items(): - engines_transformed[engine] = flat_pairs_to_dict(cast(list[AnyStr], stats)) - transformed[key] = engines_transformed # type: ignore - return cast( - dict[AnyStr, AnyStr | dict[AnyStr, dict[AnyStr, ResponsePrimitive]]], - transformed, - ) - - def transform_3( self, response: dict[ AnyStr, diff --git a/coredis/response/_callbacks/sentinel.py b/coredis/response/_callbacks/sentinel.py index a3c7507fd..da55c7de2 100644 --- a/coredis/response/_callbacks/sentinel.py +++ b/coredis/response/_callbacks/sentinel.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import cast - from coredis._utils import EncodingInsensitiveDict, nativestr from coredis.response._callbacks import ResponseCallback from coredis.response._callbacks.server import InfoCallback @@ -80,17 +78,10 @@ def parse_sentinel_state( class PrimaryCallback( ResponseCallback[ ResponseType, - dict[ResponsePrimitive, ResponsePrimitive], dict[str, ResponsePrimitive], ] ): def transform( - self, - response: ResponseType, - ) -> dict[str, ResponsePrimitive]: - return parse_sentinel_state(cast(list[ResponsePrimitive], response)).stringify_keys() - - def transform_3( self, response: dict[ResponsePrimitive, ResponsePrimitive], ) -> dict[str, ResponsePrimitive]: @@ -99,24 +90,11 @@ def transform_3( class PrimariesCallback( ResponseCallback[ - list[ResponseType], list[ResponseType], dict[str, dict[str, ResponsePrimitive]], ] ): def transform( - self, - response: list[ResponseType] | dict[ResponsePrimitive, ResponsePrimitive], - ) -> dict[str, dict[str, ResponsePrimitive]]: - result: dict[str, dict[str, ResponseType]] = {} - - for item in response: - state = PrimaryCallback()(item) - result[str(state["name"])] = state - - return result - - def transform_3( self, response: list[ResponseType], ) -> dict[str, dict[str, ResponsePrimitive]]: @@ -129,7 +107,6 @@ def transform_3( class SentinelsStateCallback( ResponseCallback[ - list[ResponseType], list[ResponseType], tuple[dict[str, ResponsePrimitive], ...], ] @@ -137,14 +114,6 @@ class SentinelsStateCallback( def transform( self, response: list[ResponseType], - ) -> tuple[dict[str, ResponsePrimitive], ...]: - return tuple( - parse_sentinel_state([nativestr(i) for i in item]).stringify_keys() for item in response - ) - - def transform_3( - self, - response: list[ResponseType], ) -> tuple[dict[str, ResponsePrimitive], ...]: return tuple( add_flags(EncodingInsensitiveDict(state)).stringify_keys() for state in response @@ -153,7 +122,6 @@ def transform_3( class GetPrimaryCallback( ResponseCallback[ - list[ResponsePrimitive], list[ResponsePrimitive], tuple[str, int] | None, ] @@ -167,7 +135,6 @@ def transform( class SentinelInfoCallback( ResponseCallback[ - list[ResponseType], list[ResponseType], dict[AnyStr, dict[int, dict[str, ResponsePrimitive]]], ] diff --git a/coredis/response/_callbacks/server.py b/coredis/response/_callbacks/server.py index e35ea0887..f23258b16 100644 --- a/coredis/response/_callbacks/server.py +++ b/coredis/response/_callbacks/server.py @@ -17,7 +17,7 @@ ) -class TimeCallback(ResponseCallback[list[AnyStr], list[AnyStr], datetime.datetime]): +class TimeCallback(ResponseCallback[list[AnyStr], datetime.datetime]): def transform( self, response: list[AnyStr], @@ -27,7 +27,7 @@ def transform( ) -class SlowlogCallback(ResponseCallback[ResponseType, ResponseType, tuple[SlowLogInfo, ...]]): +class SlowlogCallback(ResponseCallback[ResponseType, tuple[SlowLogInfo, ...]]): def transform( self, response: ResponseType, @@ -45,7 +45,7 @@ def transform( ) -class ClientInfoCallback(ResponseCallback[ResponseType, ResponseType, ClientInfo]): +class ClientInfoCallback(ResponseCallback[ResponseType, ClientInfo]): INT_FIELDS: ClassVar = { "id", "fd", @@ -81,7 +81,7 @@ def transform( return info -class ClientListCallback(ResponseCallback[ResponseType, ResponseType, tuple[ClientInfo, ...]]): +class ClientListCallback(ResponseCallback[ResponseType, tuple[ClientInfo, ...]]): def transform( self, response: ResponseType, @@ -89,7 +89,7 @@ def transform( return tuple(ClientInfoCallback()(c) for c in response.splitlines()) -class DebugCallback(ResponseCallback[ResponseType, ResponseType, dict[str, str | int]]): +class DebugCallback(ResponseCallback[ResponseType, dict[str, str | int]]): INT_FIELDS: ClassVar = {"refcount", "serializedlength", "lru", "lru_seconds_idle"} def transform( @@ -116,7 +116,6 @@ def transform( class InfoCallback( ResponseCallback[ - StringT, StringT, dict[str, ResponseType], ] @@ -179,7 +178,7 @@ def get_value(value: str) -> ResponseType: return info -class RoleCallback(ResponseCallback[ResponseType, ResponseType, RoleInfo]): +class RoleCallback(ResponseCallback[ResponseType, RoleInfo]): def transform( self, response: ResponseType, @@ -217,7 +216,7 @@ def _parse_sentinel(response: Any) -> Any: class LatencyHistogramCallback( - ResponseCallback[ResponseType, ResponseType, dict[AnyStr, dict[AnyStr, RedisValueT]]] + ResponseCallback[ResponseType, dict[AnyStr, dict[AnyStr, RedisValueT]]] ): def transform( self, @@ -231,9 +230,7 @@ def transform( return histogram -class LatencyCallback( - ResponseCallback[ResponseType, ResponseType, dict[AnyStr, tuple[int, int, int]]] -): +class LatencyCallback(ResponseCallback[ResponseType, dict[AnyStr, tuple[int, int, int]]]): def transform( self, response: ResponseType, diff --git a/coredis/response/_callbacks/sets.py b/coredis/response/_callbacks/sets.py index 300e42774..a5b7259fe 100644 --- a/coredis/response/_callbacks/sets.py +++ b/coredis/response/_callbacks/sets.py @@ -11,9 +11,7 @@ ) -class SScanCallback( - ResponseCallback[list[ResponseType], list[ResponseType], tuple[int, set[AnyStr]]] -): +class SScanCallback(ResponseCallback[list[ResponseType], tuple[int, set[AnyStr]]]): def transform( self, response: list[ResponseType], @@ -26,7 +24,6 @@ def transform( class ItemOrSetCallback( ResponseCallback[ AnyStr | list[ResponsePrimitive] | set[ResponsePrimitive], - AnyStr | set[ResponsePrimitive], AnyStr | set[AnyStr], ] ): diff --git a/coredis/response/_callbacks/sorted_set.py b/coredis/response/_callbacks/sorted_set.py index 4196072e4..da5a51018 100644 --- a/coredis/response/_callbacks/sorted_set.py +++ b/coredis/response/_callbacks/sorted_set.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import SupportsFloat, cast +from typing import cast from coredis.response._callbacks import ResponseCallback from coredis.response.types import ScoredMember, ScoredMembers @@ -14,7 +14,6 @@ class ZRankCallback( ResponseCallback[ - int | list[ResponsePrimitive] | None, int | list[ResponsePrimitive] | None, int | tuple[int, float] | None, ], @@ -22,15 +21,6 @@ class ZRankCallback( def transform( self, response: int | list[ResponsePrimitive] | None, - ) -> int | tuple[int, float] | None: - if self.options.get("withscore"): - return (response[0], float(response[1])) if response else None - else: - return cast(int | None, response) - - def transform_3( - self, - response: int | list[ResponsePrimitive] | None, ) -> int | tuple[int, float] | None: if self.options.get("withscore"): return (response[0], response[1]) if response else None @@ -40,7 +30,6 @@ def transform_3( class ZMembersOrScoredMembers( ResponseCallback[ - list[AnyStr | list[ResponsePrimitive]], list[AnyStr | list[ResponsePrimitive]], tuple[AnyStr | ScoredMember, ...], ], @@ -48,18 +37,6 @@ class ZMembersOrScoredMembers( def transform( self, response: list[AnyStr | list[ResponsePrimitive]], - ) -> tuple[AnyStr | ScoredMember, ...]: - if not response: - return () - elif self.options.get("withscores"): - it = iter(cast(list[AnyStr], response)) - return tuple(ScoredMember(*v) for v in zip(it, map(float, it))) - else: - return cast(tuple[AnyStr, ...], tuple(response)) - - def transform_3( - self, - response: list[AnyStr | list[ResponsePrimitive]], ) -> tuple[AnyStr | ScoredMember, ...]: if self.options.get("withscores"): return tuple(ScoredMember(*v) for v in cast(list[tuple[AnyStr, float]], response)) @@ -69,26 +46,12 @@ def transform_3( class ZSetScorePairCallback( ResponseCallback[ - list[ResponsePrimitive] | None, list[ResponsePrimitive | list[ResponsePrimitive]] | None, ScoredMember | ScoredMembers | None, ], Generic[AnyStr], ): def transform( - self, - response: list[ResponsePrimitive] | None, - ) -> ScoredMember | ScoredMembers | None: - if not response: - return None - - if not (self.options.get("withscores") or self.options.get("count")): - return ScoredMember(cast(AnyStr, response[0]), float(cast(SupportsFloat, response[1]))) - - it = iter(response) - return tuple(ScoredMember(*v) for v in zip(it, map(float, it))) - - def transform_3( self, response: list[ResponsePrimitive | list[ResponsePrimitive]] | None, ) -> ScoredMember | ScoredMembers | None: @@ -103,7 +66,6 @@ def transform_3( class ZMPopCallback( ResponseCallback[ - list[ResponseType] | None, list[ResponseType] | None, tuple[AnyStr, ScoredMembers] | None, ], @@ -120,9 +82,7 @@ def transform( return None -class ZMScoreCallback( - ResponseCallback[list[ResponsePrimitive], list[ResponsePrimitive], tuple[float | None, ...]] -): +class ZMScoreCallback(ResponseCallback[list[ResponsePrimitive], tuple[float | None, ...]]): def transform( self, response: list[ResponsePrimitive], @@ -131,7 +91,7 @@ def transform( class ZScanCallback( - ResponseCallback[list[ResponseType], list[ResponseType], tuple[int, ScoredMembers]], + ResponseCallback[list[ResponseType], tuple[int, ScoredMembers]], Generic[AnyStr], ): def transform( @@ -147,22 +107,11 @@ def transform( class ZRandMemberCallback( ResponseCallback[ - AnyStr | list[ResponsePrimitive] | None, AnyStr | list[list[ResponsePrimitive]] | list[ResponsePrimitive] | None, AnyStr | tuple[AnyStr, ...] | ScoredMembers | None, ] ): def transform( - self, - response: AnyStr | list[ResponsePrimitive] | None, - ) -> AnyStr | tuple[AnyStr, ...] | ScoredMembers | None: - if not (response and self.options.get("withscores")): - return tuple(response) if isinstance(response, list) else response - - it = iter(response) - return tuple(ScoredMember(*v) for v in zip(it, map(float, it))) - - def transform_3( self, response: AnyStr | list[list[ResponsePrimitive]] | list[ResponsePrimitive] | None, ) -> AnyStr | tuple[AnyStr, ...] | ScoredMembers | None: @@ -174,7 +123,6 @@ def transform_3( class BZPopCallback( ResponseCallback[ - list[ResponsePrimitive] | None, list[ResponsePrimitive] | None, tuple[AnyStr, AnyStr, float] | None, ] @@ -188,16 +136,8 @@ def transform( return None -class ZAddCallback(ResponseCallback[ResponsePrimitive, int | float, int | float]): +class ZAddCallback(ResponseCallback[int | float, int | float]): def transform( - self, - response: ResponsePrimitive, - ) -> int | float: - if self.options.get("condition"): - return float(response) - return int(response) - - def transform_3( self, response: int | float, ) -> int | float: diff --git a/coredis/response/_callbacks/streams.py b/coredis/response/_callbacks/streams.py index 6666496ba..915d6de9a 100644 --- a/coredis/response/_callbacks/streams.py +++ b/coredis/response/_callbacks/streams.py @@ -19,7 +19,7 @@ ) -class StreamRangeCallback(ResponseCallback[ResponseType, ResponseType, tuple[StreamEntry, ...]]): +class StreamRangeCallback(ResponseCallback[ResponseType, tuple[StreamEntry, ...]]): def transform( self, response: ResponseType, @@ -27,9 +27,7 @@ def transform( return tuple(StreamEntry(r[0], flat_pairs_to_ordered_dict(r[1])) for r in response) -class ClaimCallback( - ResponseCallback[ResponseType, ResponseType, tuple[AnyStr, ...] | tuple[StreamEntry, ...]] -): +class ClaimCallback(ResponseCallback[ResponseType, tuple[AnyStr, ...] | tuple[StreamEntry, ...]]): def transform( self, response: ResponseType, @@ -42,7 +40,6 @@ def transform( class AutoClaimCallback( ResponseCallback[ - ResponseType, ResponseType, tuple[AnyStr, tuple[AnyStr, ...]] | tuple[AnyStr, tuple[StreamEntry, ...], tuple[AnyStr, ...]], @@ -66,23 +63,8 @@ def transform( class MultiStreamRangeCallback( - ResponseCallback[ResponseType, ResponseType, dict[AnyStr, tuple[StreamEntry, ...]] | None] + ResponseCallback[ResponseType, dict[AnyStr, tuple[StreamEntry, ...]] | None] ): - def transform_3( - self, - response: ResponseType, - ) -> dict[AnyStr, tuple[StreamEntry, ...]] | None: - if response: - mapping: dict[AnyStr, tuple[StreamEntry, ...]] = {} - - for stream_id, entries in response.items(): - mapping[stream_id] = tuple( - StreamEntry(r[0], flat_pairs_to_ordered_dict(r[1])) for r in entries - ) - - return mapping - return None - def transform( self, response: ResponseType, @@ -90,7 +72,7 @@ def transform( if response: mapping: dict[AnyStr, tuple[StreamEntry, ...]] = {} - for stream_id, entries in response: + for stream_id, entries in response.items(): mapping[stream_id] = tuple( StreamEntry(r[0], flat_pairs_to_ordered_dict(r[1])) for r in entries ) @@ -99,9 +81,7 @@ def transform( return None -class PendingCallback( - ResponseCallback[ResponseType, ResponseType, StreamPending | tuple[StreamPendingExt, ...]] -): +class PendingCallback(ResponseCallback[ResponseType, StreamPending | tuple[StreamPendingExt, ...]]): def transform( self, response: ResponseType, @@ -117,7 +97,7 @@ def transform( return tuple(StreamPendingExt(sub[0], sub[1], sub[2], sub[3]) for sub in response) -class XInfoCallback(ResponseCallback[ResponseType, ResponseType, tuple[dict[AnyStr, AnyStr], ...]]): +class XInfoCallback(ResponseCallback[ResponseType, tuple[dict[AnyStr, AnyStr], ...]]): def transform( self, response: ResponseType, @@ -125,7 +105,7 @@ def transform( return tuple(flat_pairs_to_dict(row) for row in response) -class StreamInfoCallback(ResponseCallback[ResponseType, ResponseType, StreamInfo]): +class StreamInfoCallback(ResponseCallback[ResponseType, StreamInfo]): def transform( self, response: ResponseType, diff --git a/coredis/response/_callbacks/strings.py b/coredis/response/_callbacks/strings.py index 6289bfbbf..ed057752d 100644 --- a/coredis/response/_callbacks/strings.py +++ b/coredis/response/_callbacks/strings.py @@ -12,7 +12,7 @@ ) -class StringSetCallback(ResponseCallback[AnyStr | None, AnyStr | None, AnyStr | bool | None]): +class StringSetCallback(ResponseCallback[AnyStr | None, AnyStr | bool | None]): def transform(self, response: AnyStr | None, **options: Any) -> AnyStr | bool | None: if self.options.get("get"): return response @@ -22,35 +22,11 @@ def transform(self, response: AnyStr | None, **options: Any) -> AnyStr | bool | class LCSCallback( ResponseCallback[ - list[ResponseType], dict[ResponsePrimitive, ResponseType], LCSResult, ] ): def transform( - self, - response: (list[ResponseType] | dict[ResponsePrimitive, ResponseType]), - **options: Any, - ) -> LCSResult: - assert ( - isinstance(response, list) - and isinstance(response[-1], int) - and isinstance(response[1], list) - ) - - return LCSResult( - tuple( - LCSMatch( - (int(k[0][0]), int(k[0][1])), - (int(k[1][0]), int(k[1][1])), - k[2] if len(k) > 2 else None, - ) - for k in response[1] - ), - response[-1], - ) - - def transform_3( self, response: dict[ResponsePrimitive, ResponseType], **options: Any, diff --git a/coredis/response/_callbacks/vector_sets.py b/coredis/response/_callbacks/vector_sets.py index 91fe29a56..6ac7a0d65 100644 --- a/coredis/response/_callbacks/vector_sets.py +++ b/coredis/response/_callbacks/vector_sets.py @@ -3,14 +3,12 @@ from coredis._json import json from coredis._utils import nativestr from coredis.response._callbacks import ResponseCallback -from coredis.response._utils import flat_pairs_to_dict from coredis.response.types import VectorData -from coredis.typing import AnyStr, JsonType, ResponsePrimitive, StringT +from coredis.typing import AnyStr, JsonType, ResponsePrimitive class VSimCallback( ResponseCallback[ - list[AnyStr], list[AnyStr] | dict[AnyStr, float | list[float | JsonType]], tuple[AnyStr, ...] | dict[AnyStr, float] @@ -19,30 +17,6 @@ class VSimCallback( ], ): def transform( - self, - response: list[AnyStr], - ) -> ( - tuple[AnyStr, ...] - | dict[AnyStr, float] - | dict[AnyStr, JsonType] - | dict[AnyStr, tuple[float, JsonType]] - ): - withscores, withattribs = self.options.get("withscores"), self.options.get("withattribs") - if withscores or withattribs: - it = iter(response) - match withscores, withattribs: - case True, None | False: - return dict(list(zip(it, map(float, it)))) - case None | False, True: - return dict(list(zip(it, map(json.loads, it)))) - case True, True: - return dict( - list(zip(it, map(lambda x: (float(x[0]), json.loads(x[1])), zip(it, it)))) - ) - else: - return self.transform_3(response) - - def transform_3( self, response: list[AnyStr] | dict[AnyStr, float] @@ -70,23 +44,11 @@ def transform_3( class VLinksCallback( ResponseCallback[ - list[list[AnyStr]] | None, list[list[AnyStr] | dict[AnyStr, float]] | None, tuple[tuple[AnyStr, ...] | dict[AnyStr, float], ...] | None, ], ): def transform( - self, - response: list[list[AnyStr]] | None, - ) -> tuple[tuple[AnyStr, ...] | dict[AnyStr, float], ...] | None: - if response: - if self.options.get("withscores"): - return tuple(dict(zip(it := iter(layer), map(float, it))) for layer in response) - else: - return tuple(tuple(layer) for layer in response) - return None - - def transform_3( self, response: list[list[AnyStr] | dict[AnyStr, float]] | None, ) -> tuple[tuple[AnyStr, ...] | dict[AnyStr, float], ...] | None: @@ -101,28 +63,11 @@ def transform_3( class VEmbCallback( ResponseCallback[ - list[StringT] | list[ResponsePrimitive], list[float] | list[ResponsePrimitive], tuple[float, ...] | VectorData | None, ] ): def transform( - self, - response: list[StringT] | list[ResponsePrimitive] | None, - ) -> tuple[float, ...] | VectorData | None: - if response: - if self.options.get("raw"): - return VectorData( - quantization=nativestr(response[0]), - blob=response[1], - l2_norm=float(response[2]), - quantization_range=float(response[3]) if len(response) == 4 else None, - ) - else: - return tuple(map(float, response)) - return None - - def transform_3( self, response: list[float] | list[ResponsePrimitive] | None, ) -> tuple[float, ...] | VectorData | None: @@ -141,18 +86,11 @@ def transform_3( class VInfoCallback( ResponseCallback[ - list[AnyStr | int] | None, dict[AnyStr, AnyStr | int] | None, dict[AnyStr, AnyStr | int] | None, ] ): def transform( - self, - response: list[AnyStr | int] | None, - ) -> dict[AnyStr, AnyStr | int] | None: - return flat_pairs_to_dict(response) if response else None - - def transform_3( self, response: dict[AnyStr, AnyStr | int] | None, ) -> dict[AnyStr, AnyStr | int] | None: diff --git a/coredis/retry.py b/coredis/retry.py index d62157eac..2b0214f0b 100644 --- a/coredis/retry.py +++ b/coredis/retry.py @@ -1,14 +1,13 @@ from __future__ import annotations -import asyncio -import logging from abc import ABC, abstractmethod from functools import wraps from typing import Any -from coredis.typing import Awaitable, Callable, P, R +from anyio import sleep -logger = logging.getLogger(__name__) +from coredis._utils import logger +from coredis.typing import Awaitable, Callable, P, R class RetryPolicy(ABC): @@ -110,7 +109,7 @@ def __init__( async def delay(self, attempt_number: int) -> None: if attempt_number > 0: - await asyncio.sleep(self.__delay) + await sleep(self.__delay) class ExponentialBackoffRetryPolicy(RetryPolicy): @@ -134,7 +133,7 @@ def __init__( async def delay(self, attempt_number: int) -> None: if attempt_number > 0: - await asyncio.sleep(pow(2, attempt_number) * self.__initial_delay) + await sleep(pow(2, attempt_number) * self.__initial_delay) class CompositeRetryPolicy(RetryPolicy): diff --git a/coredis/sentinel.py b/coredis/sentinel.py index 87928ff27..70b59f155 100644 --- a/coredis/sentinel.py +++ b/coredis/sentinel.py @@ -1,15 +1,17 @@ from __future__ import annotations import random -import ssl -import weakref -from typing import Any, cast, overload +from contextlib import AsyncExitStack, asynccontextmanager +from typing import Any, AsyncGenerator, AsyncIterator, overload + +from anyio import AsyncContextManagerMixin, ConnectionFailed +from anyio.abc import ByteStream +from typing_extensions import Self, override from coredis import Redis from coredis._utils import nativestr from coredis.cache import AbstractCache from coredis.connection import Connection -from coredis.credentials import AbstractCredentialProvider from coredis.exceptions import ( ConnectionError, PrimaryNotFoundError, @@ -30,73 +32,31 @@ class SentinelManagedConnection(Connection, Generic[AnyStr]): - def __init__( - self, - connection_pool: SentinelConnectionPool, - host: str = "127.0.0.1", - port: int = 6379, - username: str | None = None, - password: str | None = None, - credential_provider: AbstractCredentialProvider | None = None, - db: int = 0, - stream_timeout: float | None = None, - connect_timeout: float | None = None, - ssl_context: ssl.SSLContext | None = None, - encoding: str = "utf-8", - decode_responses: bool = False, - socket_keepalive: bool | None = None, - socket_keepalive_options: dict[int, int | bytes] | None = None, - *, - client_name: str | None = None, - protocol_version: Literal[2, 3] = 3, - ): - self.connection_pool: SentinelConnectionPool = weakref.proxy(connection_pool) - super().__init__( - host=host, - port=port, - username=username, - password=password, - credential_provider=credential_provider, - db=db, - stream_timeout=stream_timeout, - connect_timeout=connect_timeout, - ssl_context=ssl_context, - encoding=encoding, - decode_responses=decode_responses, - socket_keepalive=socket_keepalive, - socket_keepalive_options=socket_keepalive_options, - client_name=client_name, - protocol_version=protocol_version, - ) + def __init__(self, connection_pool: SentinelConnectionPool, **kwargs: Any): + self.connection_pool: SentinelConnectionPool = connection_pool + super().__init__(**kwargs) def __repr__(self) -> str: pool = self.connection_pool - if self.host: host_info = f",host={self.host},port={self.port}" else: host_info = "" - s = f"{type(self).__name__}" - - return s + return f"{type(self).__name__}" - async def connect_to(self, address: tuple[str, int]) -> None: - self.host, self.port = address - await super().connect() - - async def connect(self) -> None: - if not self.is_connected: - if self.connection_pool.is_primary: - await self.connect_to(await self.connection_pool.get_primary_address()) - else: - for replica in await self.connection_pool.rotate_replicas(): - try: - return await self.connect_to(replica) - except ConnectionError: - continue - raise ReplicaNotFoundError # Never be here - - return None + @override + async def _connect(self) -> ByteStream: + if self.connection_pool.is_primary: + self.host, self.port = await self.connection_pool.get_primary_address() + return await super()._connect() + else: + async for replica in self.connection_pool.rotate_replicas(): + try: + self.host, self.port = replica + return await super()._connect() + except ConnectionFailed: + continue + raise ReplicaNotFoundError # Never be here class SentinelConnectionPool(ConnectionPool): @@ -104,78 +64,59 @@ class SentinelConnectionPool(ConnectionPool): Sentinel backed connection pool. """ - primary_address: tuple[str, int] | None - replica_counter: int | None - def __init__( self, service_name: StringT, sentinel_manager: Sentinel[Any], is_primary: bool = True, - check_connection: bool = True, + check_connection: bool = False, **kwargs: Any, ): self.is_primary = is_primary - kwargs["connection_class"] = cast( - type[Connection], - kwargs.get( - "connection_class", - SentinelManagedConnection[AnyStr], # type: ignore - ), - ) + kwargs["connection_class"] = SentinelManagedConnection super().__init__(**kwargs) self.connection_kwargs["connection_pool"] = self self.service_name = nativestr(service_name) self.sentinel_manager = sentinel_manager self.check_connection = check_connection + self.primary_address: tuple[str, int] | None = None + self.replica_counter: int | None = None def __repr__(self) -> str: return ( f"{type(self).__name__}" f"" ) - def reset(self) -> None: - super().reset() - self.primary_address = None - self.replica_counter = None - async def get_primary_address(self) -> tuple[str, int]: primary_address = await self.sentinel_manager.discover_primary(self.service_name) - if self.is_primary: - if self.primary_address is None: - self.primary_address = primary_address - elif primary_address != self.primary_address: + if self.primary_address != primary_address and self.primary_address is not None: # Primary address changed, disconnect all clients in this pool - self.disconnect() + self._task_group.cancel_scope.cancel() + self.primary_address = primary_address return primary_address - async def rotate_replicas(self) -> list[tuple[str, int]]: + async def rotate_replicas(self) -> AsyncIterator[tuple[str, int]]: """Round-robin replicas balancer""" replicas = await self.sentinel_manager.discover_replicas(self.service_name) - replica_addresses: list[tuple[str, int]] = [] - if replicas: if self.replica_counter is None: self.replica_counter = random.randint(0, len(replicas) - 1) - for _ in range(len(replicas)): self.replica_counter = (self.replica_counter + 1) % len(replicas) - replica_addresses.append(replicas[self.replica_counter]) - - return replica_addresses - # Fallback to primary - try: - return [await self.get_primary_address()] - except PrimaryNotFoundError: - pass - raise ReplicaNotFoundError(f"No replica found for {self.service_name!r}") + yield replicas[self.replica_counter] + else: + try: + yield await self.get_primary_address() + except PrimaryNotFoundError: + pass + raise ReplicaNotFoundError(f"No replica found for {self.service_name!r}") -class Sentinel(Generic[AnyStr]): +class Sentinel(AsyncContextManagerMixin, Generic[AnyStr]): """ Example use:: @@ -237,8 +178,8 @@ def __init__( :param sentinel_kwargs: is a dictionary of connection arguments used when connecting to sentinel instances. Any argument that can be passed to a normal Redis connection can be specified here. If :paramref:`sentinel_kwargs` is - not specified, ``stream_timeout``, ``socket_keepalive``, ``decode_responses`` - and ``protocol_version`` options specified in :paramref:`connection_kwargs` will be used. + not specified, ``stream_timeout``, ``socket_keepalive`` and ``decode_responses`` + options specified in :paramref:`connection_kwargs` will be used. :param cache: If provided the cache will be shared between both primaries and replicas returned by this sentinel. :param type_adapter: The adapter to use for serializing / deserializing customs types @@ -250,20 +191,18 @@ def __init__( """ # if sentinel_kwargs isn't defined, use the socket_* options from # connection_kwargs - - if not sentinel_kwargs: + if sentinel_kwargs is None: sentinel_kwargs = { k: v - for k, v in iter(connection_kwargs.items()) + for k, v in connection_kwargs.items() if k in { + "connect_timeout", "socket_timeout", "socket_keepalive", "encoding", - "protocol_version", } } - self.sentinel_kwargs = sentinel_kwargs self.min_other_sentinels = min_other_sentinels self.connection_kwargs = connection_kwargs @@ -272,14 +211,19 @@ def __init__( self.connection_kwargs["decode_responses"] = self.sentinel_kwargs["decode_responses"] = ( decode_responses ) - self.sentinels = [ Redis(hostname, port, **self.sentinel_kwargs) for hostname, port in sentinels ] + @asynccontextmanager + async def __asynccontextmanager__(self) -> AsyncGenerator[Self]: + async with AsyncExitStack() as stack: + for s in self.sentinels: + await stack.enter_async_context(s.__asynccontextmanager__()) + yield self + def __repr__(self) -> str: sentinel_addresses: list[str] = [] - for sentinel in self.sentinels: sentinel_addresses.append( "{}:{}".format( @@ -287,7 +231,6 @@ def __repr__(self) -> str: sentinel.connection_pool.connection_kwargs["port"], ) ) - return "{}".format(type(self).__name__, ",".join(sentinel_addresses)) def __check_primary_state( @@ -296,10 +239,8 @@ def __check_primary_state( ) -> bool: if not state["is_master"] or state["is_sdown"] or state["is_odown"]: return False - if int(state["num-other-sentinels"] or 0) < self.min_other_sentinels: return False - return True def __filter_replicas( @@ -307,14 +248,12 @@ def __filter_replicas( ) -> list[tuple[str, int]]: """Removes replicas that are in an ODOWN or SDOWN state""" replicas_alive: list[tuple[str, int]] = [] - for replica in replicas: if replica["is_odown"] or replica["is_sdown"]: continue ip, port = replica["ip"], replica["port"] assert ip and port replicas_alive.append((nativestr(ip), int(port))) - return replicas_alive async def discover_primary(self, service_name: str) -> tuple[str, int]: @@ -325,7 +264,6 @@ async def discover_primary(self, service_name: str) -> tuple[str, int]: :return: A pair (address, port) or raises :exc:`~coredis.exceptions.PrimaryNotFoundError` if no primary is found. """ - for sentinel_no, sentinel in enumerate(self.sentinels): try: primaries = await sentinel.sentinel_masters() @@ -335,27 +273,19 @@ async def discover_primary(self, service_name: str) -> tuple[str, int]: if state and self.__check_primary_state(state): # Put this sentinel at the top of the list - self.sentinels[0], self.sentinels[sentinel_no] = ( - sentinel, - self.sentinels[0], - ) - + self.sentinels[0] = sentinel + self.sentinels[sentinel_no] = self.sentinels[0] return nativestr(state["ip"]), int(state["port"] or -1) raise PrimaryNotFoundError(f"No primary found for {service_name!r}") async def discover_replicas(self, service_name: str) -> list[tuple[str, int]]: """Returns a list of alive replicas for service :paramref:`service_name`""" - for sentinel in self.sentinels: try: replicas = await sentinel.sentinel_replicas(service_name) except (ConnectionError, ResponseError, TimeoutError): continue - filtered_replicas = self.__filter_replicas(replicas) - - if filtered_replicas: - return filtered_replicas - + return self.__filter_replicas(replicas) return [] @overload diff --git a/coredis/typing.py b/coredis/typing.py index a2492cb42..aab640128 100644 --- a/coredis/typing.py +++ b/coredis/typing.py @@ -25,6 +25,7 @@ from types import GenericAlias, ModuleType, UnionType from typing import ( TYPE_CHECKING, + Annotated, Any, AnyStr, ClassVar, @@ -137,7 +138,7 @@ class ExecutionParameters(TypedDict): #: Represents the acceptable types of a redis key -KeyT = str | bytes +KeyT = Annotated[str | bytes, "KeyT"] class Serializable(Generic[R]): diff --git a/docs/source/api/bitfield.rst b/docs/source/api/bitfield.rst index eb0538584..fdc27dafa 100644 --- a/docs/source/api/bitfield.rst +++ b/docs/source/api/bitfield.rst @@ -4,4 +4,3 @@ Bitfield operations .. autoclass:: coredis.commands.BitFieldOperation :no-inherited-members: :class-doc-from: both - diff --git a/docs/source/api/caching.rst b/docs/source/api/caching.rst index fcd7e6839..926010edc 100644 --- a/docs/source/api/caching.rst +++ b/docs/source/api/caching.rst @@ -5,13 +5,7 @@ Caching Built in caches ^^^^^^^^^^^^^^^ -.. autoclass:: coredis.cache.TrackingCache - :class-doc-from: both - -.. autoclass:: coredis.cache.NodeTrackingCache - :class-doc-from: both - -.. autoclass:: coredis.cache.ClusterTrackingCache +.. autoclass:: coredis.cache.LRUCache :class-doc-from: both Implementing a custom cache @@ -22,3 +16,11 @@ must implement :class:`~coredis.cache.AbstractCache` .. autoclass:: coredis.cache.AbstractCache .. autoclass:: coredis.cache.CacheStats +Internal cache wrappers +^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: coredis.cache.NodeTrackingCache + :class-doc-from: both + +.. autoclass:: coredis.cache.ClusterTrackingCache + :class-doc-from: both diff --git a/docs/source/api/clients.rst b/docs/source/api/clients.rst index ace71fc57..b575a2f8e 100644 --- a/docs/source/api/clients.rst +++ b/docs/source/api/clients.rst @@ -3,11 +3,9 @@ Clients .. autoclass:: coredis.Redis :class-doc-from: both - .. autoclass:: coredis.RedisCluster :class-doc-from: both - :mod:`coredis.sentinel` .. autoclass:: coredis.sentinel.Sentinel @@ -16,7 +14,7 @@ Clients Redis Command related types ^^^^^^^^^^^^^^^^^^^^^^^^^^^ The following classes and types are used in the internals of coredis -to wire arguments to python command functions representing redis commands +to wire arguments to Python command functions representing Redis commands to the expected RESP syntax and eventually send it to a connection and back to the client with a pythonic response mapped from the RESP response @@ -29,4 +27,4 @@ to the client with a pythonic response mapped from the RESP response .. autoclass:: coredis.typing.ExecutionParameters :class-doc-from: both :show-inheritance: - :no-inherited-members: \ No newline at end of file + :no-inherited-members: diff --git a/docs/source/api/connections.rst b/docs/source/api/connections.rst index 04cfeecf4..5a3183005 100644 --- a/docs/source/api/connections.rst +++ b/docs/source/api/connections.rst @@ -7,28 +7,17 @@ Connection Pools :mod:`coredis` - .. autoclass:: coredis.ConnectionPool :class-doc-from: both -.. autoclass:: coredis.BlockingConnectionPool - :class-doc-from: both - :show-inheritance: - .. autoclass:: coredis.ClusterConnectionPool :class-doc-from: both :show-inheritance: -.. autoclass:: coredis.BlockingClusterConnectionPool - :class-doc-from: both - :show-inheritance: - .. autoclass:: coredis.sentinel.SentinelConnectionPool :class-doc-from: both :show-inheritance: - - Connection Classes ^^^^^^^^^^^^^^^^^^ :mod:`coredis` @@ -54,4 +43,3 @@ All connection classes derive from the same base-class: .. autoclass:: coredis.BaseConnection :show-inheritance: :class-doc-from: both - diff --git a/docs/source/api/credentials.rst b/docs/source/api/credentials.rst index 5d481de8c..6378c447a 100644 --- a/docs/source/api/credentials.rst +++ b/docs/source/api/credentials.rst @@ -7,7 +7,6 @@ Credential Providers ~coredis.credentials.UserPassCredentialProvider ~coredis.credentials.UserPass - .. autoclass:: coredis.credentials.AbstractCredentialProvider :class-doc-from: both @@ -17,6 +16,5 @@ Credential Providers .. autoclass:: coredis.credentials.UserPass :no-inherited-members: - .. autoclass:: coredis.recipes.credentials.ElastiCacheIAMProvider - :class-doc-from: both \ No newline at end of file + :class-doc-from: both diff --git a/docs/source/api/index.rst b/docs/source/api/index.rst index 583252586..bc56ecdc8 100644 --- a/docs/source/api/index.rst +++ b/docs/source/api/index.rst @@ -19,4 +19,3 @@ API Documentation utilities errors credentials - diff --git a/docs/source/api/modules.rst b/docs/source/api/modules.rst index cac7993a8..ec786a626 100644 --- a/docs/source/api/modules.rst +++ b/docs/source/api/modules.rst @@ -24,8 +24,6 @@ To access the :class:`~coredis.modules.Json` command group from the :class:`core json = coredis.modules.Json(client) await json.get("key", "$") - - RedisJSON ^^^^^^^^^ .. autoclass:: coredis.modules.Json @@ -81,5 +79,3 @@ Autocomplete TimeSeries ^^^^^^^^^^ .. autoclass:: coredis.modules.TimeSeries - - diff --git a/docs/source/api/pubsub.rst b/docs/source/api/pubsub.rst index 2179e4bc9..90758c17f 100644 --- a/docs/source/api/pubsub.rst +++ b/docs/source/api/pubsub.rst @@ -17,5 +17,3 @@ PubSub :class-doc-from: both .. autodata:: coredis.commands.pubsub.SubscriptionCallback - - diff --git a/docs/source/api/scripting.rst b/docs/source/api/scripting.rst index 8426f2828..3479a419b 100644 --- a/docs/source/api/scripting.rst +++ b/docs/source/api/scripting.rst @@ -9,13 +9,13 @@ LUA Scripts :class-doc-from: both :special-members: __call__ - Redis Functions ^^^^^^^^^^^^^^^ .. autoclass:: coredis.commands.Library :class-doc-from: both +.. autofunction:: coredis.commands.wraps + .. autoclass:: coredis.commands.Function :class-doc-from: both :special-members: __call__ - diff --git a/docs/source/api/streams.rst b/docs/source/api/streams.rst index e51280a4f..9a6a6774f 100644 --- a/docs/source/api/streams.rst +++ b/docs/source/api/streams.rst @@ -15,4 +15,3 @@ Stream Consumers .. autoclass:: coredis.stream.StreamParameters :show-inheritance: :no-inherited-members: - diff --git a/docs/source/api/typing.rst b/docs/source/api/typing.rst index 779108578..98cc30203 100644 --- a/docs/source/api/typing.rst +++ b/docs/source/api/typing.rst @@ -31,7 +31,6 @@ Custom types .. autoclass:: coredis.typing.TypeAdapter :class-doc-from: both - Redis Response (RESP) descriptions ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -55,7 +54,6 @@ to the returns documented in the client API at :ref:`api/clients:clients`. The total structure of any response for any redis command. - Response Types ^^^^^^^^^^^^^^ In most cases the API returns native python types mapped as closely as possible @@ -85,4 +83,3 @@ returned by redis - to avoid errors in indexing. .. automodule:: coredis.modules.response.types :no-inherited-members: :show-inheritance: - diff --git a/docs/source/api/utilities.rst b/docs/source/api/utilities.rst index 61237c7f3..ff3c116de 100644 --- a/docs/source/api/utilities.rst +++ b/docs/source/api/utilities.rst @@ -8,11 +8,6 @@ Enums :no-inherited-members: :show-inheritance: -Monitor -^^^^^^^ -.. autoclass:: coredis.commands.Monitor - :class-doc-from: both - Retries ^^^^^^^ :mod:`coredis.retry` @@ -27,4 +22,3 @@ Utilities for managing errors that can be recovered from by providing retry poli All retry policies need to derive from :class:`coredis.retry.RetryPolicy` .. autoclass:: coredis.retry.RetryPolicy - diff --git a/docs/source/handbook/.cluster.rst.swp b/docs/source/handbook/.cluster.rst.swp deleted file mode 100644 index f8027e56e..000000000 Binary files a/docs/source/handbook/.cluster.rst.swp and /dev/null differ diff --git a/docs/source/handbook/caching.rst b/docs/source/handbook/caching.rst index 1a968d4fe..8918da656 100644 --- a/docs/source/handbook/caching.rst +++ b/docs/source/handbook/caching.rst @@ -11,10 +11,8 @@ implementing the :class:`~coredis.cache.AbstractCache` interface and will: 1. Cache responses for readonly commands acting on single keys (the docstring for the method will indicate whether it supports caching, for example :meth:`~coredis.Redis.get`). -2. Return cached responses when they are found if the cache is returning healthy via - the :data:`~coredis.cache.AbstractCache.healthy` property -3. Invalidate the entire cache for a key if a non readonly command is called -4. If the cache returns a :data:`~coredis.cache.AbstractCache.confidence` value lower +2. Invalidate a key if a non readonly command is called on it +3. If the cache returns a :data:`~coredis.cache.AbstractCache.confidence` value lower than ``100`` the client will distrust the cached response ``(100-$confidence)%`` of the time and validate the cached response against the actual response from the server. The result of the comparison will be provided to the cache through a call to :meth:`~coredis.cache.AbstractCache.feedback` and @@ -31,58 +29,50 @@ sends a notification that the cache subscribes to to invalidate the cache. Specifically :class:`~coredis.cache.NodeTrackingCache` contains the implementation for a single node and :class:`~coredis.cache.ClusterTrackingCache` tracks all the nodes in a redis cluster. -For convenience a proxy class that automatically picks the right implementation based on the -client is available as :class:`~coredis.cache.TrackingCache`. - +Users don't need to worry about how these implementations work, and instead can focus on implementing +a :class:`~coredis.cache.AbstractCache` instance or using the provided implementation, :class:`~coredis.cache.LRUCache`. For example:: - - import asyncio + import trio import coredis - from coredis.cache import TrackingCache + from coredis.cache import LRUCache - cached_client = coredis.Redis(cache=TrackingCache()) + cached_client = coredis.Redis(cache=LRUCache()) regular_client = coredis.Redis() # or in cluster mode - # cached_client = coredis.RedisCluster("localhost", 7000, cache=TrackingCache()) + # cached_client = coredis.RedisCluster("localhost", 7000, cache=LRUCache()) # regular_client = coredis.RedisCluster("localhost", 7000) async def test(): - assert not await cached_client.get("fubar") # None response cached - await regular_client.set("fubar", "bar") # <- triggers a push message to cached_client - await asyncio.sleep(0.01) - assert b"bar" == await cached_client.get("fubar") # Cache should be invalidated - assert b"bar" == await cached_client.get("fubar") # Fetched from local cache - await cached_client.delete(["fubar"]) # Invalidates local cache immediately - assert not await cached_client.get("fubar") - - asyncio.run(test()) - - -:class:`~coredis.cache.TrackingCache` exposes a few configuration options to fine tune + async with cached_client, regular_client: + assert not await cached_client.get("fubar") # None response cached + await regular_client.set("fubar", "bar") # <- triggers a push message to cached_client + await trio.sleep(0.01) + assert b"bar" == await cached_client.get("fubar") # Cache should be invalidated + assert b"bar" == await cached_client.get("fubar") # Fetched from local cache + await cached_client.delete(["fubar"]) # Invalidates local cache immediately + assert not await cached_client.get("fubar") + + trio.run(test()) + +:class:`~coredis.cache.LRUCache` exposes a few configuration options to fine tune the cache. Specifically the following constructor arguments might be of interest: -:paramref:`~coredis.cache.TrackingCache.max_size_bytes` +:paramref:`~coredis.cache.LRUCache.max_size_bytes` Maximum size in bytes that the cache should be allowed to grow to. The cache will periodically shrink the cache in an LRU manner until it is below the threshold. -:paramref:`~coredis.cache.TrackingCache.max_keys` +:paramref:`~coredis.cache.LRUCache.max_keys` Maximum number of redis keys to track. This does not map directly to the number of cached entries as the cache maintains a per key, per command, per argument cache. -:paramref:`~coredis.cache.TrackingCache.max_idle_time_seconds` - Maximum time to tolerate no repsonse from the server. The cache instance will - use the ``PING`` command to verify if the server is responsive even if no invalidation - notifications have been received and if the threshold is breached the in memory cache - will be reset and the cache marked unhealthy. - -:paramref:`~coredis.cache.TrackingCache.confidence` +:paramref:`~coredis.cache.LRUCache.confidence` Confidence % in the cache. The client will sample cached values based on the confidence and if the cached value is not the same as the actual response from the server the actual value will be returned and the tainted key invalidated. -:paramref:`~coredis.cache.TrackingCache.dynamic_confidence` +:paramref:`~coredis.cache.LRUCache.dynamic_confidence` If set to ``True`` the cache will adjust it's confidence based on sampled (sampling depends on the initial confidence value itself) validations. diff --git a/docs/source/handbook/cluster.rst b/docs/source/handbook/cluster.rst index 412617d28..0874a9cdc 100644 --- a/docs/source/handbook/cluster.rst +++ b/docs/source/handbook/cluster.rst @@ -39,12 +39,12 @@ value of :paramref:`~coredis.RedisCluster.ensure_replication.timeout_ms`), else raise a :exc:`~coredis.exceptions.ReplicationError`:: import asyncio - import coredis + from coredis import RedisCluster async def test(): - client = coredis.RedisCluster("localhost", 7000) + client = RedisCluster("localhost", 7000, startup_nodes=[...]) with client.ensure_replication(replicas=2): await client.set("fubar", 1) - asyncio.run(test()) \ No newline at end of file + asyncio.run(test()) diff --git a/docs/source/handbook/connections.rst b/docs/source/handbook/connections.rst index 02726eb2b..f0929cbe5 100644 --- a/docs/source/handbook/connections.rst +++ b/docs/source/handbook/connections.rst @@ -5,9 +5,9 @@ Connection Pools ---------------- Both :class:`~coredis.Redis` and :class:`~coredis.RedisCluster` are backed by a connection -pool that manages the underlying connections to the redis server(s). **coredis** supports -both blocking and non-blocking connection pools. The default pool that is allocated is a -non-blocking connection pool. +pool that manages the underlying connections to the Redis server(s). **coredis** connection +pools are blocking and multiplex most kinds of commands over a few connections, while +allocating dedicated connections to blocking commands, pubsub instances, and pipelines. To explicitly select the type of connection pool used pass in the appropriate class as :paramref:`coredis.Redis.connection_pool_cls` or :paramref:`coredis.RedisCluster.connection_pool_cls`. @@ -15,9 +15,9 @@ To explicitly select the type of connection pool used pass in the appropriate cl Connection pools can also be shared between multiple clients through the :paramref:`coredis.Redis.connection_pool` or :paramref:`coredis.RedisCluster.connection_pool` parameter. -============================ -Non-Blocking Connection Pool -============================ +=============== +Connection Pool +=============== Standalone :class:`~coredis.pool.ConnectionPool` @@ -25,31 +25,33 @@ Standalone Cluster :class:`~coredis.pool.ClusterConnectionPool` -The default non-blocking connection pools that are allocated to clients will only allow -upto ``max_connections`` connections to be acquired concurrently, and if more are requested -they will raise an exception. +Connection pools will only allow up to ``max_connections`` connections to be running +concurrently, and if more are requested the command will block until one becomes +available. Since most commands can be multiplexed over a few connections this is rare +in practice unless you're using many pipelines/blocking commands/pubsubs simultaneously. -In the following example, a client is created with ``max_connections`` set to ``2``, however ``10`` -blocking requests are concurrently started. This means ~ ``8`` requests will fail:: +In the following example, a client is created with ``max_connections`` set to ``8``, +however ``10`` blocking requests are concurrently started. This means ``2`` requests will +block:: import coredis import asyncio + from anyio import fail_after async def test(): - client = coredis.Redis(max_connections=2) + client = coredis.Redis(max_connections=8) # or with cluster # client = coredis.RedisCluster( # "localhost", 7000, - # max_connections=2, max_connections_per_node=True + # max_connections=8, max_connections_per_node=True # ) - await client.set("fubar", 1) - results = await asyncio.gather( - *[client.get("fubar") for _ in range(10)], - return_exceptions=True - ) - print(len([r for r in results if isinstance(r, Exception)])) - assert len([r for r in results if isinstance(r, Exception)]) == 8 + async with client: + with fail_after(4): + results = await asyncio.gather( + *[client.blpop(["fubar"], 3) for _ in range(10)], + return_exceptions=True + ) asyncio.run(test()) @@ -58,78 +60,33 @@ Changing ``max_connections`` to ``10`` will result in all requests succeeding:: import coredis import asyncio + from anyio import fail_after async def test(): client = coredis.Redis(max_connections=10) # or with cluster # client = coredis.RedisCluster( # "localhost", 7000, - # max_connections=2, max_connections_per_node=True + # max_connections=10, max_connections_per_node=True # ) - await client.set("fubar", 1) - results = await asyncio.gather( - *[client.get("fubar") for _ in range(10)], - return_exceptions=True - ) - assert len([r for r in results if isinstance(r, Exception)]) == 0 + async with client: + with fail_after(4): + results = await asyncio.gather( + *[client.blpop(["fubar"], 3) for _ in range(10)], + return_exceptions=True + ) asyncio.run(test()) -======================== -Blocking Connection Pool -======================== - -Standalone - :class:`~coredis.pool.BlockingConnectionPool` - -Cluster - :class:`~coredis.pool.BlockingClusterConnectionPool` - -Re-using the example from the :ref:`handbook/connections:non-blocking connection pool` section above, -but using the blocking variants of the connection pools for parameters :paramref:`coredis.Redis.connection_pool_cls` or :paramref:`coredis.RedisCluster.connection_pool_cls` -and setting ``max_connections`` to ``2`` will not result in any requests failing but instead blocking to re-use -the ``2`` connections in the pool:: - - - import coredis - import asyncio - - async def test(): - client = coredis.Redis( - connection_pool_cls=coredis.BlockingConnectionPool, - max_connections=2 - ) - # or with cluster - # client = coredis.RedisCluster( - # "localhost", 7000, - # connection_pool_cls=coredis.BlockingClusterConnectionPool, - # max_connections=2, - # max_connections_per_node=True - # ) - - await client.set("fubar", 1) - results = await asyncio.gather( - *[client.get("fubar") for _ in range(10)], - return_exceptions=True - ) - assert len([r for r in results if isinstance(r, Exception)]) == 0 - - asyncio.run(test()) - -.. note:: For :class:`~coredis.pool.BlockingClusterConnectionPool` the - :paramref:`~coredis.pool.BlockingClusterConnectionPool.max_connections_per_node` - controls whether the value of :paramref:`~coredis.pool.BlockingClusterConnectionPool.max_connections` - is used cluster wide or per node. - Connection types ---------------- coredis ships with three types of connections. -- The default, :class:`coredis.connection.Connection`, is a normal TCP socket based connection. +- The default, :class:`coredis.connection.Connection`, is a normal TCP socket-based connection. - :class:`~coredis.connection.UnixDomainSocketConnection` allows - for clients running on the same device as the server to connect via a unix domain socket. + for clients running on the same device as the server to connect via a Unix domain socket. To use a :class:`~coredis.connection.UnixDomainSocketConnection` connection, simply pass the :paramref:`~coredis.Redis.unix_socket_path` argument, which is a string to the unix domain socket file. @@ -159,7 +116,4 @@ specified during initialization. .. code-block:: python - pool = coredis.ConnectionPool(connection_class=YourConnectionClass, - your_arg='...', ...) - - + pool = coredis.ConnectionPool(connection_class=YourConnectionClass, ...) diff --git a/docs/source/handbook/development.rst b/docs/source/handbook/development.rst index 7f9ab0f73..e887f7495 100644 --- a/docs/source/handbook/development.rst +++ b/docs/source/handbook/development.rst @@ -17,6 +17,14 @@ The unit tests will lazily initialize the containers required per test using the $ uv run pytest tests -To reduce unnecessary setup and tear down the containers are left running after the tests complete. To cleanup:: +To reduce unnecessary setup and tear down the containers are left running after the tests complete. To clean up: - docker-compose down --remove-orphans +.. code-block:: bash + + $ docker-compose down --remove-orphans + +You can run single tests or filter out certain client types like this: + +.. code-block:: bash + + $ pytest -m 'basic and not (raw or resp2 or cached)' tests/commands/test_string.py diff --git a/docs/source/handbook/encoding.rst b/docs/source/handbook/encoding.rst index 107520465..75a60875d 100644 --- a/docs/source/handbook/encoding.rst +++ b/docs/source/handbook/encoding.rst @@ -14,11 +14,12 @@ The behavior of the client can also be temporarily changed by using the :meth:`~ context manager. For example:: client = coredis.Redis(decoding=True, encoding='utf-8') - await client.set("fubar", "baz") - with client.decoding(False): - assert await client.get("fubar") == b"baz" - with client.decoding(True): - assert await client.get("fubar") == "baz" + async with client: + await client.set("fubar", "baz") + with client.decoding(False): + assert await client.get("fubar") == b"baz" + with client.decoding(True): + assert await client.get("fubar") == "baz" .. note:: In certain cases (exclusively for utility commands such as :meth:`coredis.Redis.info`) diff --git a/docs/source/handbook/index.rst b/docs/source/handbook/index.rst index 9a103c4dc..f0a2a6afa 100644 --- a/docs/source/handbook/index.rst +++ b/docs/source/handbook/index.rst @@ -17,7 +17,6 @@ Handbook modules connections encoding - response optimization typing development diff --git a/docs/source/handbook/modules.rst b/docs/source/handbook/modules.rst index 2ee3c36aa..b7e003bad 100644 --- a/docs/source/handbook/modules.rst +++ b/docs/source/handbook/modules.rst @@ -12,16 +12,17 @@ of :class:`Redis` or :class:`RedisCluster`. For example:: client = coredis.Redis() - # RedisJSON - await client.json.get("key") - # RediSearch - await client.search.search("index", "*") - # RedisBloom:BloomFilter - await client.bf.reserve("bf", 0.001, 1000) - # RedisBloom:CuckooFilter - await client.cf.reserve("cf", 1000) - # RedisTimeSeries - await client.timeseries.add("ts", 1, 1) + async with client: + # RedisJSON + await client.json.get("key") + # RediSearch + await client.search.search("index", "*") + # RedisBloom:BloomFilter + await client.bf.reserve("bf", 0.001, 1000) + # RedisBloom:CuckooFilter + await client.cf.reserve("cf", 1000) + # RedisTimeSeries + await client.timeseries.add("ts", 1, 1) Module commands can also be used in :ref:`handbook/pipelines:pipelines` (and transactions) @@ -29,17 +30,14 @@ by accessing them via the command group property in the same way as described ab For example:: - pipeline = await client.pipeline() - - await pipeline.json.get("key") - await pipeline.json.get("key") - await pipeline.search.search("index", "*") - await pipeline.bf.reserve("bf", 0.001, 1000) - await pipeline.cf.reserve("cf", 1000) - await pipeline.timeseries.add("ts", 1, 1) - await pipeline.graph.query("graph", "CREATE (:Node {name: 'Node'})") - - await pipeline.execute() + async with client.pipeline(transaction=True) as pipe: + pipe.json.get("key") + pipe.json.get("key") + pipe.search.search("index", "*") + pipe.bf.reserve("bf", 0.001, 1000) + pipe.cf.reserve("cf", 1000) + pipe.timeseries.add("ts", 1, 1) + pipe.graph.query("graph", "CREATE (:Node {name: 'Node'})") RedisJSON @@ -55,17 +53,18 @@ Get/set operations:: import coredis client = coredis.Redis() - await client.json.set( - "key1", ".", {"a": 1, "b": [1, 2, 3], "c": "str"} - ) - assert 1 == await client.json.get("key1", ".a") - assert [1,2,3] == await client.json.get("key1", ".b") - assert "str" == await client.json.get("key1", ".c") + async with client: + await client.json.set( + "key1", ".", {"a": 1, "b": [1, 2, 3], "c": "str"} + ) + assert 1 == await client.json.get("key1", ".a") + assert [1,2,3] == await client.json.get("key1", ".b") + assert "str" == await client.json.get("key1", ".c") - await client.json.set("key2", ".", {"a": 2, "b": [4,5,6], "c": ["str"]}) + await client.json.set("key2", ".", {"a": 2, "b": [4,5,6], "c": ["str"]}) - # multi get - assert ["str", ["str"]] == await client.json.mget(["key1", "key2"], ".c") + # multi get + assert ["str", ["str"]] == await client.json.mget(["key1", "key2"], ".c") Clear versus Delete:: @@ -144,49 +143,47 @@ some common field definitions:: import coredis import coredis.modules client = coredis.Redis(decode_responses=True) - - # Create an index on json documents - await client.search.create("json_index", on=coredis.PureToken.JSON, schema = [ - coredis.modules.search.Field('$.name', coredis.PureToken.TEXT, alias='name'), - coredis.modules.search.Field('$.country', coredis.PureToken.TEXT, alias='country'), - coredis.modules.search.Field('$.population', coredis.PureToken.NUMERIC, alias='population'), - coredis.modules.search.Field("$.location", coredis.PureToken.GEO, alias='location'), - coredis.modules.search.Field('$.iso_tags', coredis.PureToken.TAG, alias='iso_tags'), - coredis.modules.search.Field('$.summary_vector', coredis.PureToken.VECTOR, alias='summary_vector', - algorithm="FLAT", - attributes={ - "DIM": 768, - "DISTANCE_METRIC": "COSINE", - "TYPE": "FLOAT32", - } - ) - - ], prefixes=['json:city:']) - - # or on all hashes that start with a prefix ``city:`` - await client.search.create("hash_index", on=coredis.PureToken.HASH, schema = [ - coredis.modules.search.Field('name', coredis.PureToken.TEXT), - coredis.modules.search.Field('country', coredis.PureToken.TEXT), - coredis.modules.search.Field('population', coredis.PureToken.NUMERIC), - coredis.modules.search.Field("location", coredis.PureToken.GEO), - coredis.modules.search.Field('iso_tags', coredis.PureToken.TAG, separator=","), - coredis.modules.search.Field('summary_vector', coredis.PureToken.VECTOR, - algorithm="FLAT", - attributes={ - "DIM": 768, - "DISTANCE_METRIC": "COSINE", - "TYPE": "FLOAT32", - } - ) - ], prefixes=['city:']) + async with client: + # Create an index on json documents + await client.search.create("json_index", on=coredis.PureToken.JSON, schema = [ + coredis.modules.search.Field('$.name', coredis.PureToken.TEXT, alias='name'), + coredis.modules.search.Field('$.country', coredis.PureToken.TEXT, alias='country'), + coredis.modules.search.Field('$.population', coredis.PureToken.NUMERIC, alias='population'), + coredis.modules.search.Field("$.location", coredis.PureToken.GEO, alias='location'), + coredis.modules.search.Field('$.iso_tags', coredis.PureToken.TAG, alias='iso_tags'), + coredis.modules.search.Field('$.summary_vector', coredis.PureToken.VECTOR, alias='summary_vector', + algorithm="FLAT", + attributes={ + "DIM": 768, + "DISTANCE_METRIC": "COSINE", + "TYPE": "FLOAT32", + } + ) + + ], prefixes=['json:city:']) + + # or on all hashes that start with a prefix ``city:`` + await client.search.create("hash_index", on=coredis.PureToken.HASH, schema = [ + coredis.modules.search.Field('name', coredis.PureToken.TEXT), + coredis.modules.search.Field('country', coredis.PureToken.TEXT), + coredis.modules.search.Field('population', coredis.PureToken.NUMERIC), + coredis.modules.search.Field("location", coredis.PureToken.GEO), + coredis.modules.search.Field('iso_tags', coredis.PureToken.TAG, separator=","), + coredis.modules.search.Field('summary_vector', coredis.PureToken.VECTOR, + algorithm="FLAT", + attributes={ + "DIM": 768, + "DISTANCE_METRIC": "COSINE", + "TYPE": "FLOAT32", + } + ) + ], prefixes=['city:']) To populate the indices we can add some sample city data (a sample that can be used for the above index definition can be found `in the coredis repository `__) using a pipeline for performance:: - pipeline = await client.pipeline() - import requests import numpy @@ -194,26 +191,25 @@ using a pipeline for performance:: "https://raw.githubusercontent.com/alisaifee/coredis/master/tests/modules/data/city_index.json" ).json() - for name, fields in cities.items(): - await pipeline.json.set(f"json:city:{name}", f".", { - "name": name, - "country": fields["country"], - "population": int(fields["population"]), - "location": f"{fields['lng']},{fields['lat']}", - "iso_tags": fields["iso_tags"], - "summary_vector": fields["summary_vector"], - }) - - await pipeline.hset(f"city:{name}", { - "name": name, - "country": fields["country"], - "population": fields["population"], - "location": f"{fields['lng']},{fields['lat']}", - "iso_tags": ",".join(fields["iso_tags"]), - "summary_vector": numpy.asarray(fields["summary_vector"]).astype(numpy.float32).tobytes(), - }) - - await pipeline.execute() + async with client.pipeline(transaction=False) as pipe: + for name, fields in cities.items(): + pipe.json.set(f"json:city:{name}", f".", { + "name": name, + "country": fields["country"], + "population": int(fields["population"]), + "location": f"{fields['lng']},{fields['lat']}", + "iso_tags": fields["iso_tags"], + "summary_vector": fields["summary_vector"], + }) + + pipe.hset(f"city:{name}", { + "name": name, + "country": fields["country"], + "population": fields["population"], + "location": f"{fields['lng']},{fields['lat']}", + "iso_tags": ",".join(fields["iso_tags"]), + "summary_vector": numpy.asarray(fields["summary_vector"]).astype(numpy.float32).tobytes(), + }) .. note:: Take special note of how the ``population`` (numeric field), ``iso_tags`` (tag field) & ``summary_vector`` (vector field) fields are handled differently in the case of hashes vs json documents. @@ -412,23 +408,23 @@ BloomFilter .. code-block:: - import coredis - client = coredis.Redis() - - # create filter - await client.bf.reserve("filter", 0.1, 1000) + import coredis + client = coredis.Redis() + async with client: + # create filter + await client.bf.reserve("filter", 0.1, 1000) - # add items - await client.bf.add("filter", 1) - await client.bf.madd("filter", [2,3,4]) + # add items + await client.bf.add("filter", 1) + await client.bf.madd("filter", [2,3,4]) - # test for inclusion - assert await client.bf.exists("filter", 1) - assert (True, False) == await client.bf.mexists("filter", [2,5]) + # test for inclusion + assert await client.bf.exists("filter", 1) + assert (True, False) == await client.bf.mexists("filter", [2,5]) - # or - assert await coredis.modules.BloomFilter(client).exists("filter", 1) - ... + # or + assert await coredis.modules.BloomFilter(client).exists("filter", 1) + ... For more details refer to the API documentation for :class:`~coredis.modules.BloomFilter` @@ -439,9 +435,6 @@ CuckooFilter .. code-block:: - import coredis - client = coredis.Redis() - # create filter await client.cf.reserve("filter", 1000) @@ -469,9 +462,6 @@ CountMinSketch .. code-block:: - import coredis - client = coredis.Redis() - # create a sketch await client.cms.initbydim("sketch", 2, 50) @@ -490,13 +480,10 @@ TopK .. code-block:: - import coredis import string import itertools import random - client = coredis.Redis() - # create a top-3 await client.topk.reserve("top3", 3) @@ -517,10 +504,6 @@ TDigest .. code-block:: - import coredis - - client = coredis.Redis() - # create a digest await client.tdigest.create("digest") @@ -559,8 +542,9 @@ Create a few timeseries with different labels (:meth:`~modules.TimeSeries.create rooms = {"bedroom", "lounge", "bathroom"} client = coredis.Redis(port=9379) - for room in rooms: - assert await client.timeseries.create(f"temp:{room}", labels={"room": room}) + async with client: + for room in rooms: + assert await client.timeseries.create(f"temp:{room}", labels={"room": room}) Create compaction rules for hourly and daily averages (:meth:`~modules.TimeSeries.createrule`):: @@ -585,13 +569,11 @@ Populate a year of random sample data (:meth:`~modules.TimeSeries.add`):: import random cur = datetime.fromtimestamp(0) - pipeline = await client.pipeline() - while cur < datetime(1971, 1, 1, 0, 0, 0): - cur += timedelta(minutes=random.randint(1, 60)) - for room in rooms: - await pipeline.timeseries.add(f"temp:{room}", cur, random.randint(15, 30)) - - await pipeline.execute() + async with client.pipeline(transaction=True) as pipe: + while cur < datetime(1971, 1, 1, 0, 0, 0): + cur += timedelta(minutes=random.randint(1, 60)) + for room in rooms: + pipe.timeseries.add(f"temp:{room}", cur, random.randint(15, 30)) Query for the latest temperature in each room (:meth:`~modules.TimeSeries.get`):: diff --git a/docs/source/handbook/noreply.rst b/docs/source/handbook/noreply.rst index 5a137498a..568d190ab 100644 --- a/docs/source/handbook/noreply.rst +++ b/docs/source/handbook/noreply.rst @@ -13,12 +13,14 @@ For example:: import coredis client = coredis.Redis(noreply=True) - assert await client.set("fubar", 1) is None - assert await client.hset("hash_fubar", {"a": 1, "b": 2}) is None + async with client: + assert await client.set("fubar", 1) is None + assert await client.hset("hash_fubar", {"a": 1, "b": 2}) is None other_client = coredis.Redis() - assert await other_client.get("fubar") == b"1" - assert await other_client.hgetall("hash_fubar") == {b"a": b"1", b"b": b"2"} + async with other_client: + assert await other_client.get("fubar") == b"1" + assert await other_client.hgetall("hash_fubar") == {b"a": b"1", b"b": b"2"} The mode can also be enabled temporarily through the :meth:`~coredis.Redis.ignore_replies` context manager:: @@ -26,10 +28,10 @@ The mode can also be enabled temporarily through the :meth:`~coredis.Redis.ignor import coredis client = coredis.Redis() - - with client.ignore_replies(): - assert await client.set("fubar", 1) is None - assert await client.get("fubar") == b"1" + async with client: + with client.ignore_replies(): + assert await client.set("fubar", 1) is None + assert await client.get("fubar") == b"1" .. danger:: When the client is used with the the ``noreply`` option there are no guarantees @@ -38,4 +40,4 @@ The mode can also be enabled temporarily through the :meth:`~coredis.Redis.ignor to the socket. .. warning:: Using the ``noreply`` option effectively ignores return annotations - and will (**probably**) therefore fail any type checkers (static or runtime). \ No newline at end of file + and will (**probably**) therefore fail any type checkers (static or runtime). diff --git a/docs/source/handbook/optimization.rst b/docs/source/handbook/optimization.rst index f1a8ceecf..d2b6c70d5 100644 --- a/docs/source/handbook/optimization.rst +++ b/docs/source/handbook/optimization.rst @@ -8,10 +8,10 @@ Optimized mode - Runtime validation of parameter combinations for redis commands that can take various combinations of inputs (examples: :meth:`~coredis.Redis.set` or :meth:`~coredis.Redis.xadd`) - Validation of correct use of iterables as parameters -- Compatibility checks by redis server version +- Compatibility checks by Redis server version Optimized mode can be enabled in any of the following ways: - Set the environment variable :envvar:`COREDIS_OPTIMIZED` to ``true`` - Run Python in optimized mode with :option:`-O` or setting :envvar:`PYTHONOPTIMIZE` -- Explicitly with ``coredis.Config.optimized=True`` \ No newline at end of file +- Explicitly with ``coredis.Config.optimized=True`` diff --git a/docs/source/handbook/pipelines.rst b/docs/source/handbook/pipelines.rst index dd44326f1..d9834ee67 100644 --- a/docs/source/handbook/pipelines.rst +++ b/docs/source/handbook/pipelines.rst @@ -3,15 +3,15 @@ Pipelines Pipelines expose an identical API to :class:`~coredis.Redis`, however the awaitable returned by calling a pipeline method can only be awaited -after the entire pipeline has successfully executed by calling -:meth:`~coredis.pipeline.Pipeline.execute` +after the entire pipeline has successfully executed, that is, after +exiting the pipeline's async context manager: For example: .. code-block:: python async def example(client): - async with await client.pipeline(transaction=True) as pipe: + async with client.pipeline(transaction=True) as pipe: # commands is a tuple of awaitables commands = ( pipe.flushdb(), @@ -19,11 +19,9 @@ For example: pipe.set("bar", "foo"), pipe.keys("*"), ) - results = await pipe.execute() - # results are in order corresponding to your command - assert results == (True, True, True, set([b"bar", b"foo"])) - # results can also be retrieved from the returns of each command - assert await asyncio.gather(*commands) == (True, True, True, set[b"bar", b"foo"]) + # results can be retrieved from the returns of each command + # notice this is OUTSIDE of the pipeline block + assert await asyncio.gather(*commands) == (True, True, True, {b"bar", b"foo"}) Atomicity & Transactions @@ -54,9 +52,9 @@ could do something like this: .. code-block:: python async def example(): - async with await r.pipeline() as pipe: - while True: - try: + while True: + try: + async with r.pipeline(transaction=False) as pipe: # put a WATCH on the key that holds our sequence value await pipe.watch("OUR-SEQUENCE-KEY") # after WATCHing, the pipeline is put into immediate execution @@ -68,16 +66,15 @@ could do something like this: pipe.multi() # This call doesn't need to be awaited as it is part of the pipeline pipe.set("OUR-SEQUENCE-KEY", next_value) - # and finally, execute the pipeline (the set command) - await pipe.execute() - # if a WatchError wasn"t raised during execution, everything - # we just did happened atomically. - break - except WatchError: - # another client must have changed "OUR-SEQUENCE-KEY" between - # the time we started WATCHing it and the pipeline"s execution. - # our best bet is to just retry. - continue + except WatchError: + # another client must have changed "OUR-SEQUENCE-KEY" between + # the time we started WATCHing it and the pipeline"s execution. + # our best bet is to just retry. + continue + else: + # if a WatchError wasn"t raised during execution, everything + # we just did happened atomically. + break Note that, because the Pipeline must bind to a single connection for the duration of a :rediscommand:`WATCH`, care must be taken to ensure that the connection is @@ -89,36 +86,15 @@ explicitly calling :meth:`~coredis.pipeline.Pipeline.clear`: .. code-block:: python async def example(): - async with await r.pipeline() as pipe: - while 1: - try: + while 1: + try: + async with r.pipeline() as pipe: await pipe.watch("OUR-SEQUENCE-KEY") ... - await pipe.execute() - break - except WatchError: - continue - finally: - await pipe.clear() - -A convenience method :meth:`~coredis.Redis.transaction` exists for handling all the -boilerplate of handling and retrying watch errors. It takes a callable that -should expect a single parameter, a pipeline object, and any number of keys to -be watched. Our client-side :rediscommand:`INCR` command above can be written like this, -which is much easier to read: - -.. code-block:: python - - async def client_side_incr(pipe) -> int: - current_value = await pipe.get("OUR-SEQUENCE-KEY") or 0 - next_value = int(current_value) + 1 - pipe.multi() - await pipe.set("OUR-SEQUENCE-KEY", next_value) - return next_value - - await r.transaction(client_side_incr, "OUR-SEQUENCE-KEY") - # (True,) - await r.transaction(client_side_incr, "OUR-SEQUENCE-KEY", value_from_callable=True) - # 2 - - + pipe.multi() + pipe.set(...) + ... + except WatchError: + continue + else: + break diff --git a/docs/source/handbook/pubsub.rst b/docs/source/handbook/pubsub.rst index b30a2ad0a..097aeae13 100644 --- a/docs/source/handbook/pubsub.rst +++ b/docs/source/handbook/pubsub.rst @@ -19,40 +19,18 @@ or :meth:`~coredis.commands.PubSub.psubscribe` methods. Upon instantiation:: - consumer = await client.pubsub( + async with client.pubsub( channels=["my-first-channel", "my-second-channel"], patterns=["my-*"] - ) - -.. note:: If the newly created pubsub instance can't be awaited because - it is done in a synchronous context, the initial subscriptions will occur - on the first async call to the instance. If explicit initialization is preferred - the instance can be awaited when the async context is available or through a call - to :meth:`~coredis.commands.PubSub.initialize`. - - For example:: - - consumer = client.pubsub( - channels=["my-first-channel", "my-second-channel"], - patterns=["my-*"] - ) - assert not consumer.subscribed - # later in an async context - await consumer - # or - await consumer.initialize() - # or simply use the instance - await consumer.get_message() - + ) as consumer: + ... or explicitly:: - consumer = client.pubsub() - await consumer.subscribe("my-first-channel", "my-second-channel", ...) - await consumer.psubscribe("my-*") - + async with client.pubsub() as consumer: + await consumer.subscribe("my-first-channel", "my-second-channel", ...) + await consumer.psubscribe("my-*") -The recommended way of using a pubsub instance is with the async context manager -which automatically manages unsubscribing and connection cleanup on exit:: +The async context manager automatically manages unsubscribing and cleanup on exit:: async with client.pubsub( channels=["my-first-channel", "my-second-channel"], patterns=["my-*"] @@ -62,8 +40,6 @@ which automatically manages unsubscribing and connection cleanup on exit:: # remaining subscriptions are unsubscribed and connection is released # back to the connection pool when the context manager exits. - - If desired unsubscription can also be done explicitly by calling :meth:`~coredis.commands.PubSub.unsubscribe` for channels and :meth:`~coredis.commands.PubSub.punsubscribe` for patterns. @@ -82,8 +58,6 @@ exit when the consumer has no subscriptions):: else: print(message["data"]) - - Consuming Messages ^^^^^^^^^^^^^^^^^^ @@ -101,12 +75,10 @@ will be a typed dictionary defined as: With the iterator:: - consumer.subscribe("my-channel") - async for message in consumer.messages: + await consumer.subscribe("my-channel") + async for message in consumer: # do something with the message - - .. note:: Unsubscribing from all subscribed channels will result in the iterator ending (i.e. raising :exc:`StopAsyncIteration`) @@ -165,18 +137,7 @@ PubSub instances remember what channels and patterns they are subscribed to. In the event of a disconnection such as a network error or timeout, the PubSub instance will re-subscribe to all prior channels and patterns when reconnecting. Messages that were published while the client was disconnected -cannot be delivered. When you're finished with a PubSub object, call the -:meth:`~coredis.commands.PubSub.aclose` method to shutdown the connection and unsubscribe. - -.. note:: This isn't necessary if using the pubsub instance with the async context manager - since that automatically calls :meth:`~coredis.commands.PubSub.aclose` when the context - manager exits. - -.. code-block:: python - - consumer = client.pubsub() - ... - await consumer.aclose() +cannot be delivered. The Pub/Sub support commands :rediscommand:`PUBSUB-CHANNELS`, :rediscommand:`PUBSUB-NUMSUB` and :rediscommand:`PUBSUB-NUMPAT` are also supported: @@ -232,4 +193,4 @@ can use a dedicated connection per node to drain messages). Additionally, the :paramref:`~coredis.RedisCluster.sharded_pubsub.read_from_replicas` parameter can be set to ``True`` when constructing a :class:`~coredis.commands.pubsub.ShardedPubSub` instance -to further increase throughput by letting the consumer use read replicas. \ No newline at end of file +to further increase throughput by letting the consumer use read replicas. diff --git a/docs/source/handbook/response.rst b/docs/source/handbook/response.rst deleted file mode 100644 index a985a038a..000000000 --- a/docs/source/handbook/response.rst +++ /dev/null @@ -1,18 +0,0 @@ -Redis Response --------------- - -As of redis `6.0.0` clients can use the -:term:`RESP3` protocol which provides support for a much larger set of types (which reduces the need for clients -to "guess" what the type of a command's response should be). -**coredis** provides backward compatibility for ``RESP`` -and the structure of responses from coredis is consistent -between :term:`RESP` (``protocol_version=2``) and :term:`RESP3` (``protocol_version=3``) protocols. - -To fallback to ``RESP`` the :paramref:`~coredis.Redis.protocol_version` constructor parameter -can be set to ``2``. - -.. code-block:: python - - r = coredis.Redis(protocol_version=2) - - diff --git a/docs/source/handbook/scripting.rst b/docs/source/handbook/scripting.rst index fe588cdc4..c420063c2 100644 --- a/docs/source/handbook/scripting.rst +++ b/docs/source/handbook/scripting.rst @@ -218,13 +218,14 @@ the key/argument mapping behavior. This can now be used as you would expect:: client = coredis.Redis() - lib = await MyLib(client, replace=True) - await lib.ping() - # b"pong" - await lib.echo("hello world") - # b"hello world" - await client.hset("k1", {"a": 10, "b": 20}) - await client.hset("k2", {"c": 30, "d": 40}) - - await lib.hmmget("k1", "k2", a=1, b=2, c=3, d=4, e=5, f=6) - # [b"10", b"20", b"30", b"40", b"5", b"6"] + async with client: + lib = await MyLib(client, replace=True) + await lib.ping() + # b"pong" + await lib.echo("hello world") + # b"hello world" + await client.hset("k1", {"a": 10, "b": 20}) + await client.hset("k2", {"c": 30, "d": 40}) + + await lib.hmmget("k1", "k2", a=1, b=2, c=3, d=4, e=5, f=6) + # [b"10", b"20", b"30", b"40", b"5", b"6"] diff --git a/docs/source/handbook/sentinel.rst b/docs/source/handbook/sentinel.rst index 71eb6799a..3c9810f27 100644 --- a/docs/source/handbook/sentinel.rst +++ b/docs/source/handbook/sentinel.rst @@ -12,10 +12,11 @@ Sentinel connection to discover the primary and replicas network addresses: from coredis.sentinel import Sentinel sentinel = Sentinel([('localhost', 26379)], stream_timeout=0.1) - await sentinel.discover_primary('myredis') - # ('127.0.0.1', 6379) - await sentinel.discover_replicas('myredis') - # [('127.0.0.1', 6380)] + async with sentinel: + await sentinel.discover_primary('myredis') + # ('127.0.0.1', 6379) + await sentinel.discover_replicas('myredis') + # [('127.0.0.1', 6380)] You can also create Redis client connections from a Sentinel instance. You can connect to either the primary (for write operations) or a replica (for read-only @@ -25,9 +26,10 @@ operations). primary = sentinel.primary_for('myredis', stream_timeout=0.1) replica = sentinel.replica_for('myredis', stream_timeout=0.1) - primary.set('foo', 'bar') - replica.get('foo') - # 'bar' + async with primary, replica: + await primary.set('foo', 'bar') + await replica.get('foo') + # 'bar' The primary and replica objects are normal :class:`~coredis.Redis` instances with their connection pool bound to the Sentinel instance via :class:`~coredis.sentinel.SentinelConnectionPool`. diff --git a/docs/source/handbook/typing.rst b/docs/source/handbook/typing.rst index 7bab68c66..68c1ef7b3 100644 --- a/docs/source/handbook/typing.rst +++ b/docs/source/handbook/typing.rst @@ -88,5 +88,3 @@ As an example: Traceback (most recent call last): File "<@beartype(coredis.commands.core.CoreCommands.set) at 0x10c403130>", line 33, in set beartype.roar.BeartypeCallHintParamViolation: @beartyped coroutine CoreCommands.set() parameter key=1 violates type hint typing.Union[str, bytes], as 1 not str or bytes. - - diff --git a/docs/source/history.rst b/docs/source/history.rst index 2d6b2cbe0..c39610d3e 100644 --- a/docs/source/history.rst +++ b/docs/source/history.rst @@ -9,6 +9,8 @@ performing async python clients. Since it had become unmaintained as of October The initial intention of the fork was add python 3.10 compatibility and `coredis 2.x `__ is drop-in backward compatible with **aredis** and adds support up to python 3.10. +In August 2025, `Graeme Holliday `_ opened a PR that +would eventually restructure coredis to use structured concurrency and add Trio support. Divergence from aredis & redis-py --------------------------------- @@ -26,6 +28,8 @@ client, this inherently means that **coredis** diverges from both, most notable .. automethod:: coredis.Redis.expire :noindex: +- Type hints are significantly better than redis-py's, which are terrible for the async client, + and maintainers have indicated they don't care to address the problem Default RESP3 ------------- @@ -34,7 +38,6 @@ Default RESP3 from the redis server and defaulted to the legacy ``RESP`` protocol. Since **coredis** has dropped support for redis server versions below ``6.0`` the default protocol version is now :term:`RESP3`. - Parsers ------- **coredis** versions ``2.x`` and ``3.x`` would default to a :pypi:`hiredis` based parser if the diff --git a/docs/source/index.rst b/docs/source/index.rst index 66cf330a0..f437fe506 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -26,13 +26,14 @@ coredis :alt: Code coverage :class: header-badge -coredis is an async redis client with support for redis server, cluster & sentinel. +Fast, async, fully-typed Redis client with support for cluster and sentinel + The client API uses the specifications in the Redis command documentation to define the API by using the following conventions: The coredis :ref:`api/clients:clients` use the specifications in the `Redis command documentation `__ to define the API by using the following conventions: -- Arguments retain naming from redis as much as possible +- Arguments retain naming from Redis as much as possible - **Only** optional variadic arguments are mapped to position or keyword variadic arguments. When the variable length arguments are not optional the expected argument is an iterable of type :class:`~coredis.typing.Parameters` or :class:`~typing.Mapping`. @@ -50,7 +51,7 @@ Feature Summary * :class:`~coredis.Redis` * :class:`~coredis.RedisCluster` - * :class:`~coredis.sentinel.Sentinel` + * :class:`~coredis.Sentinel` * Application patterns @@ -106,21 +107,23 @@ Single Node or Cluster client client = Redis(host='127.0.0.1', port=6379, db=0) # or with redis cluster # client = RedisCluster(startup_nodes=[{"host": "127.0.01", "port": 7001}]) - await client.flushdb() - await client.set('foo', 1) - assert await client.exists(['foo']) == 1 - assert await client.incr('foo') == 2 - assert await client.incrby('foo', increment=100) == 102 - assert int(await client.get('foo')) == 102 - - assert await client.expire('foo', 1) - await asyncio.sleep(0.1) - assert await client.ttl('foo') == 1 - assert await client.pttl('foo') < 1000 - await asyncio.sleep(1) - assert not await client.exists(['foo']) + async with client: + await client.flushdb() + await client.set('foo', 1) + assert await client.exists(['foo']) == 1 + assert await client.incr('foo') == 2 + assert await client.incrby('foo', increment=100) == 102 + assert int(await client.get('foo')) == 102 + + assert await client.expire('foo', 1) + await asyncio.sleep(0.1) + assert await client.ttl('foo') == 1 + assert await client.pttl('foo') < 1000 + await asyncio.sleep(1) + assert not await client.exists(['foo']) asyncio.run(example()) + # OR trio.run(example()) Sentinel -------- @@ -132,11 +135,13 @@ Sentinel async def example(): sentinel = Sentinel(sentinels=[("localhost", 26379)]) - primary = sentinel.primary_for("myservice") - replica = sentinel.replica_for("myservice") + async with sentinel: + primary = sentinel.primary_for("myservice") + replica = sentinel.replica_for("myservice") - assert await primary.set("fubar", 1) - assert int(await replica.get("fubar")) == 1 + async with primary, replica: + assert await primary.set("fubar", 1) + assert int(await replica.get("fubar")) == 1 asyncio.run(example()) @@ -147,24 +152,10 @@ Compatibility **coredis** is tested against redis versions >= ``7.0`` The test matrix status can be reviewed `here `__ -.. note:: Though **coredis** officially only supports :redis-version:`6.0.0` and above it is known to work with lower - versions. - - A known compatibility issue with older redis versions is the lack of support for :term:`RESP3` and - the :rediscommand:`HELLO` command. The default :class:`~coredis.Redis` and :class:`~coredis.RedisCluster` clients - do not work in this scenario as the :rediscommand:`HELLO` command is used for initial handshaking to confirm that - the default ``RESP3`` protocol version can be used and to perform authentication if necessary. - - This can be worked around by passing ``2`` to :paramref:`coredis.Redis.protocol_version` to downgrade to :term:`RESP` - (see :ref:`handbook/response:redis response`). - - When using :term:`RESP` **coredis** will also fall back to the legacy :rediscommand:`AUTH` command if the - :rediscommand:`HELLO` is not supported. - - coredis is additionally tested against: -- :pypi:`uvloop` >= `0.15.0`. +- :pypi:`uvloop` >= `0.15.0` +- :pypi:`trio` Supported python versions ------------------------- diff --git a/docs/source/recipes/credentials.rst b/docs/source/recipes/credentials.rst index e9bbb2164..249f60c21 100644 --- a/docs/source/recipes/credentials.rst +++ b/docs/source/recipes/credentials.rst @@ -7,7 +7,7 @@ Elasticache IAM Credential Provider The implementation is based on `the Elasticache IAM provider described in redis docs `__ -The :class:`~coredis.recipes.credentials.ElastiCacheIAMProvider` implements the +The :class:`~coredis.recipes.ElastiCacheIAMProvider` implements the :class:`~coredis.credentials.AbstractCredentialProvider` interface. It uses :pypi:`aiobotocore` to generate a short-lived authentication token which can be used to authenticate with an IAM enabled Elasticache cluster. @@ -17,7 +17,6 @@ of unnecessary requests. See https://docs.aws.amazon.com/AmazonElastiCache/latest/dg/auth-iam.html for more details on using IAM to authenticate with Elasticache. -.. autoclass:: coredis.recipes.credentials.ElastiCacheIAMProvider +.. autoclass:: coredis.recipes.ElastiCacheIAMProvider :class-doc-from: both :no-index: - diff --git a/docs/source/recipes/locks.rst b/docs/source/recipes/locks.rst index b19051897..92d17df25 100644 --- a/docs/source/recipes/locks.rst +++ b/docs/source/recipes/locks.rst @@ -1,6 +1,6 @@ Locks ----- -:mod:`coredis.recipes.locks` +:mod:`coredis.recipes.lock` Distributed lock with LUA Scripts ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -8,7 +8,7 @@ Distributed lock with LUA Scripts The implementation is based on `the distributed locking pattern described in redis docs `__ When used with a :class:`~coredis.RedisCluster` instance, acquiring the lock includes -ensuring that the token set by the :meth:`~coredis.recipes.locks.LuaLock.acquire` method +ensuring that the token set by the :meth:`~coredis.recipes.Lock.acquire` method is replicated to atleast ``n/2`` replicas using the :meth:`~coredis.RedisCluster.ensure_replication` context manager. @@ -16,11 +16,11 @@ The implementation uses the following LUA scripts: #. Release the lock - .. literalinclude:: ../../../coredis/recipes/locks/release.lua + .. literalinclude:: ../../../coredis/recipes/lua/release.lua + #. Extend the lock - .. literalinclude:: ../../../coredis/recipes/locks/extend.lua + .. literalinclude:: ../../../coredis/recipes/lua/extend.lua -.. autoclass:: coredis.recipes.locks.LuaLock +.. autoclass:: coredis.recipes.Lock :class-doc-from: both - diff --git a/pyproject.toml b/pyproject.toml index 9ad3b34e7..9c63ff54b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,20 +1,19 @@ [build-system] build-backend = "hatchling.build" requires = [ - "async_timeout>4,<6", + "anyio>=4.11.0", "hatchling>=1.14.0", "hatch-mypyc>=0.1.0", "hatch-vcs>=0.4.0", "mypy==1.18.1", "types-deprecated", - "pympler>1,<2", "beartype>=0.20", ] [project] name = "coredis" dynamic = ["version"] -description = "Python async client for Redis key-value store" +description = "Fast, async, fully-typed Redis client with support for cluster and sentinel" readme = "README.md" license = "MIT" license-files = ["LICENSE"] @@ -22,29 +21,40 @@ authors = [ {name = "Ali-Akber Saifee", email = "ali@indydevs.org"} ] maintainers = [ - {name = "Ali-Akber Saifee", email = "ali@indydevs.org"} + {name = "Ali-Akber Saifee", email = "ali@indydevs.org"}, + {name = "Graeme Holliday", email = "graeme@tastyware.dev"} ] -keywords = ["Redis", "key-value store", "asyncio"] +keywords = ["Redis", "key-value store", "asyncio", "Trio", "anyio"] classifiers = [ "Development Status :: 5 - Production/Stable", + "Framework :: AsyncIO", + "Framework :: AnyIO", + "Framework :: Trio", "Intended Audience :: Developers", + "Intended Audience :: Information Technology", + "Intended Audience :: System Administrators", "Operating System :: OS Independent", "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", "Programming Language :: Python :: 3.14", "Programming Language :: Python :: Implementation :: PyPy", + "Topic :: Software Development :: Libraries :: Python Modules", + "Topic :: System :: Distributed Computing", + "Typing :: Typed", ] requires-python = ">=3.10" dependencies = [ - "async_timeout>4,<6", + "anyio>=4.11.0", "beartype>=0.20", "deprecated>=1.2", "typing_extensions>=4.13", "packaging>=21,<26", - "pympler>1,<2", + "exceptiongroup>=1.3.0", ] [project.optional-dependencies] @@ -71,7 +81,9 @@ test = [ "redis", "aiobotocore>=2.15.2", "asyncache>=0.3.1", - "moto" + "moto", + "trio>=0.31.0", + "uvloop; platform.python_implementation != 'PyPy'", ] dev = [ @@ -90,6 +102,7 @@ dev = [ ci = [ "pytest-rerunfailures", "pytest-sentry", + "pytest-timeout", {include-group = "dev"}, ] @@ -109,6 +122,10 @@ docs = [ {include-group = "dev"}, ] +orjson = [ + "orjson" +] + [project.urls] Homepage = "https://github.com/alisaifee/coredis" Source = "https://github.com/alisaifee/coredis" @@ -178,14 +195,12 @@ exclude = ["coredis/_py_312_typing.py"] [[tool.mypy.overrides]] module = [ - "async_timeout", "beartype", "asyncache", "aiobotocore.*", "botocore.*", "cachetools", "deprecated", - "pympler", ] ignore_errors = true ignore_missing_imports = true @@ -220,4 +235,3 @@ MYPYC_OPT_LEVEL = "3" [tool.cibuildwheel.linux.environment] HATCH_BUILD_HOOKS_ENABLE = "1" MYPYC_OPT_LEVEL = "3" - diff --git a/pytest.ini b/pytest.ini index ec51d02df..2e0413861 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,4 +1,5 @@ [pytest] +anyio_mode = auto testpaths = tests addopts = --verbose @@ -6,7 +7,6 @@ addopts = --capture=no -rfE -K -asyncio_mode = auto markers = auth basic @@ -26,8 +26,6 @@ markers = nocluster noredict noreplica - noresp3 - resp2 clusteronly replicated_clusteronly cached diff --git a/scripts/code_gen.py b/scripts/code_gen.py index 8e5a63b35..a3dc265c6 100644 --- a/scripts/code_gen.py +++ b/scripts/code_gen.py @@ -20,14 +20,13 @@ import coredis import coredis.client import coredis.pipeline +from coredis._utils import b from coredis.commands.constants import * # noqa -from coredis.commands.monitor import Monitor from coredis.globals import CACHEABLE_COMMANDS from coredis.pool import ClusterConnectionPool, ConnectionPool # noqa from coredis.response.types import * # noqa from coredis.tokens import PureToken # noqa from coredis.typing import * # noqa -from coredis._utils import b MAX_SUPPORTED_VERSION = version.parse("8.999.999") MIN_SUPPORTED_VERSION = version.parse("5.999.999") @@ -263,7 +262,6 @@ "MEMORY STATS": dict[AnyStr, AnyStr | int | float], "MGET": tuple[AnyStr | None, ...], "MODULE LIST": tuple[dict, ...], - "MONITOR": Monitor, "PING": AnyStr, "PFADD": bool, "PSETEX": bool, @@ -1608,7 +1606,7 @@ async def {{method["name"]}}{{render_signature(method["rec_signature"])}}: debug=debug, sanitized=sanitized, getattr=getattr, - b=b + b=b, ) section_template = env.from_string(section_template_str) methods_by_group = {} diff --git a/tests/cluster/conftest.py b/tests/cluster/conftest.py deleted file mode 100644 index f9b50063a..000000000 --- a/tests/cluster/conftest.py +++ /dev/null @@ -1,47 +0,0 @@ -from __future__ import annotations - -import pytest - -import coredis - - -@pytest.fixture -def s(redis_cluster_server): - cluster = coredis.RedisCluster( - startup_nodes=[{"host": "localhost", "port": 7000}], decode_responses=True - ) - assert cluster.connection_pool.nodes.slots == {} - assert cluster.connection_pool.nodes.nodes == {} - - yield cluster - - cluster.connection_pool.disconnect() - - -@pytest.fixture -def sr(redis_cluster_server): - cluster = coredis.RedisCluster( - startup_nodes=[{"host": "localhost", "port": 7000}], - reinitialize_steps=1, - decode_responses=True, - ) - yield cluster - - cluster.connection_pool.disconnect() - - -@pytest.fixture -def ro(redis_cluster_server): - cluster = coredis.RedisCluster( - startup_nodes=[{"host": "localhost", "port": 7000}], - read_from_replicas=True, - decode_responses=True, - ) - yield cluster - - cluster.connection_pool.disconnect() - - -@pytest.fixture(autouse=True) -def cluster(redis_cluster_server): - pass diff --git a/tests/cluster/test_cluster_connection_pool.py b/tests/cluster/test_cluster_connection_pool.py deleted file mode 100644 index 0339c190a..000000000 --- a/tests/cluster/test_cluster_connection_pool.py +++ /dev/null @@ -1,532 +0,0 @@ -from __future__ import annotations - -import asyncio -import os -from collections import deque -from unittest.mock import Mock, patch - -import pytest - -from coredis import Redis -from coredis.connection import ClusterConnection, Connection, UnixDomainSocketConnection -from coredis.exceptions import ConnectionError, RedisClusterException -from coredis.parser import Parser -from coredis.pool import ClusterConnectionPool, ConnectionPool -from coredis.pool.nodemanager import ManagedNode -from tests.conftest import targets - - -class DummyConnection(ClusterConnection): - description_format = "DummyConnection<>" - - def __init__(self, host="localhost", port=7000, socket_timeout=None, **kwargs): - self.kwargs = kwargs - self.pid = os.getpid() - self.host = host - self.port = port - self.socket_timeout = socket_timeout - self.awaiting_response = False - self._parser = Parser() - self._last_error = None - self._transport = None - self._read_flag = asyncio.Event() - self._read_waiters = set() - self._description_args = lambda: {} - self._parse_task = None - self._requests = deque() - self.average_response_time = 0 - self.requests_processed = 0 - - -class TestConnectionPool: - async def get_pool( - self, - connection_kwargs=None, - max_connections=None, - max_connections_per_node=None, - connection_class=DummyConnection, - blocking=False, - timeout=0, - ): - connection_kwargs = connection_kwargs or {} - pool = ClusterConnectionPool( - connection_class=connection_class, - max_connections=max_connections, - max_connections_per_node=max_connections_per_node, - startup_nodes=[{"host": "127.0.0.1", "port": 7000}], - blocking=blocking, - timeout=timeout, - **connection_kwargs, - ) - await pool.initialize() - - return pool - - async def test_no_available_startup_nodes(self, redis_cluster): - pool = ClusterConnectionPool( - startup_nodes=[{"host": "foo", "port": 6379}, {"host": "bar", "port": 6379}] - ) - with pytest.raises(RedisClusterException, match="Redis Cluster cannot be connected"): - await pool.initialize() - with pytest.raises(RedisClusterException, match="Cant reach a single startup node"): - await pool.get_connection_by_slot(1) - with pytest.raises(RedisClusterException, match="Cant reach a single startup node"): - await pool.get_random_connection() - - async def test_in_use_not_exists(self, redis_cluster): - """ - Test that if for some reason, the node that it tries to get the connectino for - do not exists in the _in_use_connection variable. - """ - pool = await self.get_pool() - pool._in_use_connections = {} - await pool.get_connection(b"pubsub", channel="foobar") - - async def test_connection_creation(self, redis_cluster): - connection_kwargs = {"foo": "bar", "biz": "baz"} - pool = await self.get_pool(connection_kwargs=connection_kwargs) - connection = await pool.get_connection_by_node( - ManagedNode(**{"host": "127.0.0.1", "port": 7000}) - ) - assert isinstance(connection, DummyConnection) - - for key in connection_kwargs: - assert connection.kwargs[key] == connection_kwargs[key] - - async def test_multiple_connections(self, redis_cluster): - pool = await self.get_pool() - c1 = await pool.get_connection_by_node(ManagedNode(**{"host": "127.0.0.1", "port": 7000})) - c2 = await pool.get_connection_by_node(ManagedNode(**{"host": "127.0.0.1", "port": 7001})) - assert c1 != c2 - - async def test_max_connections_too_low(self, redis_cluster): - with pytest.warns(UserWarning, match="increased by 4 connections"): - pool = await self.get_pool(max_connections=2) - assert pool.max_connections == 6 - - async def test_max_connections(self, redis_cluster): - pool = await self.get_pool(max_connections=6) - for port in range(7000, 7006): - await pool.get_connection_by_node(ManagedNode(**{"host": "127.0.0.1", "port": port})) - with pytest.raises(ConnectionError): - await pool.get_connection_by_node(ManagedNode(**{"host": "127.0.0.1", "port": 7000})) - - async def test_max_connections_blocking(self, redis_cluster): - pool = await self.get_pool(max_connections=6, blocking=True, timeout=1) - connections = [] - for port in range(7000, 7006): - connections.append( - await pool.get_connection_by_node( - ManagedNode(**{"host": "127.0.0.1", "port": port}) - ) - ) - with pytest.raises(ConnectionError): - await pool.get_connection_by_node(ManagedNode(**{"host": "127.0.0.1", "port": 7000})) - pool.release(connections[0]) - assert connections[0] == await pool.get_connection_by_node( - ManagedNode(**{"host": "127.0.0.1", "port": 7000}) - ) - - async def test_max_connections_per_node(self, redis_cluster): - pool = await self.get_pool(max_connections=2, max_connections_per_node=True) - await pool.get_connection_by_node(ManagedNode(**{"host": "127.0.0.1", "port": 7000})) - await pool.get_connection_by_node(ManagedNode(**{"host": "127.0.0.1", "port": 7001})) - await pool.get_connection_by_node(ManagedNode(**{"host": "127.0.0.1", "port": 7000})) - await pool.get_connection_by_node(ManagedNode(**{"host": "127.0.0.1", "port": 7001})) - with pytest.raises(ConnectionError): - await pool.get_connection_by_node(ManagedNode(**{"host": "127.0.0.1", "port": 7000})) - - async def test_max_connections_per_node_blocking(self, redis_cluster): - pool = await self.get_pool( - max_connections=2, max_connections_per_node=True, blocking=True, timeout=1 - ) - await pool.get_connection_by_node(ManagedNode(**{"host": "127.0.0.1", "port": 7000})) - await pool.get_connection_by_node(ManagedNode(**{"host": "127.0.0.1", "port": 7001})) - await pool.get_connection_by_node(ManagedNode(**{"host": "127.0.0.1", "port": 7000})) - await pool.get_connection_by_node(ManagedNode(**{"host": "127.0.0.1", "port": 7001})) - with pytest.raises(ConnectionError): - await pool.get_connection_by_node(ManagedNode(**{"host": "127.0.0.1", "port": 7000})) - - async def test_max_connections_default_setting(self): - pool = await self.get_pool(max_connections=None) - assert pool.max_connections == 2**31 - - async def test_pool_disconnect(self): - pool = await self.get_pool() - c1 = await pool.get_connection_by_node(ManagedNode(**{"host": "127.0.0.1", "port": 7000})) - c2 = await pool.get_connection_by_node(ManagedNode(**{"host": "127.0.0.1", "port": 7001})) - c3 = await pool.get_connection_by_node(ManagedNode(**{"host": "127.0.0.1", "port": 7000})) - pool.release(c3) - pool.disconnect() - assert not c1.is_connected - assert not c2.is_connected - assert not c3.is_connected - - async def test_reuse_previously_released_connection(self): - pool = await self.get_pool() - c1 = await pool.get_connection_by_node(ManagedNode(**{"host": "127.0.0.1", "port": 7000})) - pool.release(c1) - c2 = await pool.get_connection_by_node(ManagedNode(**{"host": "127.0.0.1", "port": 7000})) - assert c1 == c2 - - async def test_repr_contains_db_info_tcp(self, host_ip): - """ - Note: init_slot_cache muts be set to false otherwise it will try to - query the test server for data and then it can't be predicted reliably - """ - connection_kwargs = {"host": "127.0.0.1", "port": 7000} - pool = await self.get_pool( - connection_kwargs=connection_kwargs, connection_class=ClusterConnection - ) - expected = f"ClusterConnection" - assert expected in repr(pool) - - async def test_get_connection_by_key(self): - """ - This test assumes that when hashing key 'foo' will be sent to server with port 7002 - """ - pool = await self.get_pool(connection_kwargs={}) - - # Patch the call that is made inside the method to allow control of the returned - # connection object - with patch.object( - ClusterConnectionPool, "get_connection_by_slot", autospec=True - ) as pool_mock: - - async def side_effect(self, *args, **kwargs): - return DummyConnection(port=1337) - - pool_mock.side_effect = side_effect - - connection = await pool.get_connection_by_key("foo") - assert connection.port == 1337 - - with pytest.raises(RedisClusterException) as ex: - await pool.get_connection_by_key(None) - assert str(ex.value).startswith("No way to dispatch this command to Redis Cluster."), True - - async def test_get_connection_by_slot(self): - """ - This test assumes that when doing keyslot operation on "foo" it will return 12182 - """ - pool = await self.get_pool(connection_kwargs={}) - - # Patch the call that is made inside the method to allow control of the returned - # connection object - with patch.object( - ClusterConnectionPool, "get_connection_by_node", autospec=True - ) as pool_mock: - - async def side_effect(self, *args, **kwargs): - return DummyConnection(port=1337) - - pool_mock.side_effect = side_effect - - connection = await pool.get_connection_by_slot(12182) - assert connection.port == 1337 - - class AsyncMock(Mock): - def __await__(self): - future = asyncio.Future(loop=asyncio.get_event_loop()) - future.set_result(self) - result = yield from future - - return result - - m = AsyncMock() - pool.get_random_connection = m - - # If None value is provided then a random node should be tried/returned - await pool.get_connection_by_slot(None) - m.assert_called_once_with() - - async def test_get_connection_blocked(self): - """ - Currently get_connection() should only be used by pubsub command. - All other commands should be blocked and exception raised. - """ - pool = await self.get_pool() - - with pytest.raises(RedisClusterException) as ex: - await pool.get_connection("GET") - assert str(ex.value).startswith("Only 'pubsub' commands can use get_connection()") - - async def test_master_node_by_slot(self): - pool = await self.get_pool(connection_kwargs={}) - node = pool.get_primary_node_by_slot(0) - node.port = 7000 - node = pool.get_primary_node_by_slot(12182) - node.port = 7002 - - async def test_connection_idle_check(self): - pool = ClusterConnectionPool( - startup_nodes=[dict(host="127.0.0.1", port=7000)], - max_idle_time=0.2, - idle_check_interval=0.1, - ) - await pool.initialize() - conn = await pool.get_connection_by_node( - ManagedNode( - **{ - "host": "127.0.0.1", - "port": 7000, - "server_type": "primary", - } - ) - ) - name = conn.node.name - assert len(pool._cluster_in_use_connections[name]) == 1 - pool.release(conn) - assert len(pool._cluster_in_use_connections[name]) == 0 - assert pool._cluster_available_connections[name].qsize() == 1 - await asyncio.sleep(0.3) - assert len(pool._cluster_in_use_connections[name]) == 0 - last_active_at = conn.last_active_at - assert last_active_at == conn.last_active_at - assert conn._transport is None - - @targets( - "redis_cluster", - ) - async def test_coverage_check_fail(self, client, user_client, _s): - with pytest.warns( - UserWarning, - match="Unable to determine whether the cluster requires full coverage", - ): - no_perm_client = await user_client("testuser", "on", "+@all", "-CONFIG") - assert _s("PONG") == await no_perm_client.ping() - - -class TestReadOnlyConnectionPool: - async def get_pool(self, connection_kwargs=None, max_connections=None, startup_nodes=None): - startup_nodes = startup_nodes or [{"host": "127.0.0.1", "port": 7000}] - connection_kwargs = connection_kwargs or {} - pool = ClusterConnectionPool( - max_connections=max_connections, - startup_nodes=startup_nodes, - read_from_replicas=True, - **connection_kwargs, - ) - await pool.initialize() - - return pool - - async def test_repr_contains_db_info_readonly(self, host_ip): - """ - Note: init_slot_cache must be set to false otherwise it will try to - query the test server for data and then it can't be predicted reliably - """ - pool = await self.get_pool( - startup_nodes=[ - {"host": "127.0.0.1", "port": 7000}, - {"host": "127.0.0.2", "port": 7001}, - ], - ) - assert f"ClusterConnection" in repr(pool) - assert f"ClusterConnection" in repr(pool) - - async def test_max_connections(self): - pool = await self.get_pool(max_connections=6) - for port in range(7000, 7006): - await pool.get_connection_by_node(ManagedNode(**{"host": "127.0.0.1", "port": port})) - with pytest.raises(ConnectionError): - await pool.get_connection_by_node(ManagedNode(**{"host": "127.0.0.1", "port": 7000})) - - -class TestConnectionPoolURLParsing: - def test_defaults(self): - pool = ConnectionPool.from_url("redis://localhost") - assert pool.connection_class == Connection - assert pool.connection_kwargs == { - "host": "localhost", - "port": 6379, - "db": 0, - "username": None, - "password": None, - } - - def test_hostname(self): - pool = ConnectionPool.from_url("redis://myhost") - assert pool.connection_class == Connection - assert pool.connection_kwargs == { - "host": "myhost", - "port": 6379, - "db": 0, - "username": None, - "password": None, - } - - def test_quoted_hostname(self): - pool = ConnectionPool.from_url("redis://my %2F host %2B%3D+", decode_components=True) - assert pool.connection_class == Connection - assert pool.connection_kwargs == { - "host": "my / host +=+", - "port": 6379, - "db": 0, - "username": None, - "password": None, - } - - def test_port(self): - pool = ConnectionPool.from_url("redis://localhost:6380") - assert pool.connection_class == Connection - assert pool.connection_kwargs == { - "host": "localhost", - "port": 6380, - "db": 0, - "username": None, - "password": None, - } - - def test_password(self): - pool = ConnectionPool.from_url("redis://:mypassword@localhost") - assert pool.connection_class == Connection - assert pool.connection_kwargs == { - "host": "localhost", - "port": 6379, - "db": 0, - "username": "", - "password": "mypassword", - } - - def test_quoted_password(self): - pool = ConnectionPool.from_url( - "redis://:%2Fmypass%2F%2B word%3D%24+@localhost", decode_components=True - ) - assert pool.connection_class == Connection - assert pool.connection_kwargs == { - "host": "localhost", - "port": 6379, - "db": 0, - "username": None, - "password": "/mypass/+ word=$+", - } - - def test_quoted_path(self): - pool = ConnectionPool.from_url( - "unix://:mypassword@/my%2Fpath%2Fto%2F..%2F+_%2B%3D%24ocket", - decode_components=True, - ) - assert pool.connection_class == UnixDomainSocketConnection - assert pool.connection_kwargs == { - "path": "/my/path/to/../+_+=$ocket", - "db": 0, - "username": None, - "password": "mypassword", - } - - def test_db_as_argument(self): - pool = ConnectionPool.from_url("redis://localhost", db="1") - assert pool.connection_class == Connection - assert pool.connection_kwargs == { - "host": "localhost", - "port": 6379, - "db": 1, - "username": None, - "password": None, - } - - def test_db_in_path(self): - pool = ConnectionPool.from_url("redis://localhost/2", db="1") - assert pool.connection_class == Connection - assert pool.connection_kwargs == { - "host": "localhost", - "port": 6379, - "db": 2, - "username": None, - "password": None, - } - - def test_db_in_querystring(self): - pool = ConnectionPool.from_url("redis://localhost/2?db=3", db="1") - assert pool.connection_class == Connection - assert pool.connection_kwargs == { - "host": "localhost", - "port": 6379, - "db": 3, - "username": None, - "password": None, - } - - def test_extra_querystring_options(self): - pool = ConnectionPool.from_url("redis://localhost?a=1&b=2") - assert pool.connection_class == Connection - assert pool.connection_kwargs == { - "host": "localhost", - "port": 6379, - "db": 0, - "username": None, - "password": None, - "a": "1", - "b": "2", - } - - def test_client_creates_connection_pool(self): - r = Redis.from_url("redis://myhost") - assert r.connection_pool.connection_class == Connection - assert r.connection_pool.connection_kwargs == { - "host": "myhost", - "port": 6379, - "db": 0, - "decode_responses": False, - "protocol_version": 3, - "username": None, - "password": None, - "noreply": False, - "noevict": False, - "notouch": False, - } - - -class TestConnectionPoolUnixSocketURLParsing: - def test_defaults(self): - pool = ConnectionPool.from_url("unix:///socket") - assert pool.connection_class == UnixDomainSocketConnection - assert pool.connection_kwargs == { - "path": "/socket", - "db": 0, - "username": None, - "password": None, - } - - def test_password(self): - pool = ConnectionPool.from_url("unix://:mypassword@/socket") - assert pool.connection_class == UnixDomainSocketConnection - assert pool.connection_kwargs == { - "path": "/socket", - "db": 0, - "username": "", - "password": "mypassword", - } - - def test_db_as_argument(self): - pool = ConnectionPool.from_url("unix:///socket", db=1) - assert pool.connection_class == UnixDomainSocketConnection - assert pool.connection_kwargs == { - "path": "/socket", - "db": 1, - "username": None, - "password": None, - } - - def test_db_in_querystring(self): - pool = ConnectionPool.from_url("unix:///socket?db=2", db=1) - assert pool.connection_class == UnixDomainSocketConnection - assert pool.connection_kwargs == { - "path": "/socket", - "db": 2, - "username": None, - "password": None, - } - - def test_extra_querystring_options(self): - pool = ConnectionPool.from_url("unix:///socket?a=1&b=2") - assert pool.connection_class == UnixDomainSocketConnection - assert pool.connection_kwargs == { - "path": "/socket", - "db": 0, - "username": None, - "password": None, - "a": "1", - "b": "2", - } diff --git a/tests/cluster/test_node_manager.py b/tests/cluster/test_node_manager.py deleted file mode 100644 index dad324a90..000000000 --- a/tests/cluster/test_node_manager.py +++ /dev/null @@ -1,372 +0,0 @@ -# python std lib -from __future__ import annotations - -import asyncio -import uuid -from unittest.mock import AsyncMock, Mock, patch - -# 3rd party imports -import pytest - -# rediscluster imports -from coredis.client import Redis -from coredis.credentials import UserPassCredentialProvider -from coredis.exceptions import ConnectionError, RedisClusterException, RedisError -from coredis.pool.nodemanager import HASH_SLOTS, ManagedNode, NodeManager - - -async def test_init_slots_cache_not_all_slots(s, redis_cluster): - """ - Test that if not all slots are covered it should raise an exception - """ - - with patch.object(NodeManager, "get_redis_link") as get_redis_link: - cluster_slots_async = asyncio.Future() - cluster_slots = { - (0, 5459): [ - { - "host": "127.0.0.1", - "port": 7000, - "node_id": str(uuid.uuid4()), - "server_type": "master", - }, - { - "host": "127.0.0.1", - "port": 7003, - "node_id": str(uuid.uuid4()), - "server_type": "slave", - }, - ], - (5461, 10922): [ - { - "host": "127.0.0.1", - "port": 7001, - "node_id": str(uuid.uuid4()), - "server_type": "master", - }, - { - "host": "127.0.0.1", - "port": 7004, - "node_id": str(uuid.uuid4()), - "server_type": "slave", - }, - ], - (10923, 16383): [ - { - "host": "127.0.0.1", - "port": 7002, - "node_id": str(uuid.uuid4()), - "server_type": "master", - }, - { - "host": "127.0.0.1", - "port": 7005, - "node_id": str(uuid.uuid4()), - "server_type": "slave", - }, - ], - } - mock_redis = Mock() - cluster_slots_async.set_result(cluster_slots) - mock_redis.cluster_slots.return_value = cluster_slots_async - - config_get_async = asyncio.Future() - config_get_async.set_result({"cluster-require-full-coverage": "yes"}) - - mock_redis.config_get.return_value = config_get_async - - get_redis_link.return_value = mock_redis - with pytest.raises(RedisClusterException) as ex: - await s.connection_pool.initialize() - - assert str(ex.value).startswith("Not all slots are covered after query all startup_nodes.") - - -async def test_init_slots_cache_not_all_slots_not_require_full_coverage(s, redis_cluster): - """ - Test that if not all slots are covered it should raise an exception - """ - with patch.object(Redis, "cluster_slots", new_callable=AsyncMock) as mock_cluster_slots: - with patch.object(Redis, "config_get", new_callable=AsyncMock) as mock_config_get: - mock_config_get.return_value = {"cluster-require-full-coverage": "no"} - mock_cluster_slots.return_value = { - (0, 5459): [ - { - "host": "127.0.0.1", - "port": 7000, - "node_id": str(uuid.uuid4()), - "server_type": "master", - }, - { - "host": "127.0.0.1", - "port": 7003, - "node_id": str(uuid.uuid4()), - "server_type": "slave", - }, - ], - (5461, 10922): [ - { - "host": "127.0.0.1", - "port": 7001, - "node_id": str(uuid.uuid4()), - "server_type": "master", - }, - { - "host": "127.0.0.1", - "port": 7004, - "node_id": str(uuid.uuid4()), - "server_type": "slave", - }, - ], - (10923, 16383): [ - { - "host": "127.0.0.1", - "port": 7002, - "node_id": str(uuid.uuid4()), - "server_type": "master", - }, - { - "host": "127.0.0.1", - "port": 7005, - "node_id": str(uuid.uuid4()), - "server_type": "slave", - }, - ], - } - - await s.connection_pool.nodes.initialize() - assert 5460 not in s.connection_pool.nodes.slots - - -async def test_init_slots_cache(s, redis_cluster): - """ - Test that slots cache can in initialized and all slots are covered - """ - good_slots_resp = { - (0, 5460): [ - { - "host": "127.0.0.1", - "port": 7000, - "node_id": str(uuid.uuid4()), - "server_type": "master", - }, - { - "host": "127.0.0.1", - "port": 7003, - "node_id": str(uuid.uuid4()), - "server_type": "slave", - }, - ], - (5461, 10922): [ - { - "host": "127.0.0.1", - "port": 7001, - "node_id": str(uuid.uuid4()), - "server_type": "master", - }, - { - "host": "127.0.0.1", - "port": 7004, - "node_id": str(uuid.uuid4()), - "server_type": "slave", - }, - ], - (10923, 16383): [ - { - "host": "127.0.0.1", - "port": 7002, - "node_id": str(uuid.uuid4()), - "server_type": "master", - }, - { - "host": "127.0.0.1", - "port": 7005, - "node_id": str(uuid.uuid4()), - "server_type": "slave", - }, - ], - } - - with patch.object(Redis, "config_get", new_callable=AsyncMock) as mock_config_get: - with patch.object(Redis, "cluster_slots", new_callable=AsyncMock) as mock_cluster_slots: - mock_cluster_slots.return_value = good_slots_resp - mock_config_get.return_value = {"cluster-require-full-coverage": "yes"} - - await s.connection_pool.nodes.initialize() - assert len(s.connection_pool.nodes.slots) == HASH_SLOTS - - for slot_info, node_info in good_slots_resp.items(): - all_hosts = ["127.0.0.1", "127.0.0.2"] - all_ports = [7000, 7001, 7002, 7003, 7004, 7005] - slot_start = slot_info[0] - slot_end = slot_info[1] - - for i in range(slot_start, slot_end + 1): - assert len(s.connection_pool.nodes.slots[i]) == len(node_info) - assert s.connection_pool.nodes.slots[i][0].host in all_hosts - assert s.connection_pool.nodes.slots[i][1].host in all_hosts - assert s.connection_pool.nodes.slots[i][0].port in all_ports - assert s.connection_pool.nodes.slots[i][1].port in all_ports - - assert len(s.connection_pool.nodes.nodes) == 6 - - -async def test_empty_startup_nodes(): - """ - It should not be possible to create a node manager with no nodes specified - """ - with pytest.raises(RedisClusterException): - await NodeManager().initialize() - - with pytest.raises(RedisClusterException): - await NodeManager([]).initialize() - - -async def test_all_nodes(redis_cluster): - """ - Set a list of nodes and it should be possible to iterate over all - """ - n = NodeManager(startup_nodes=[{"host": "127.0.0.1", "port": 7000}]) - await n.initialize() - - nodes = [node for node in n.nodes.values()] - - for i, node in enumerate(n.all_nodes()): - assert node in nodes - - -async def test_all_nodes_primaries(redis_cluster): - """ - Set a list of nodes with random primary/replica config and it shold be possible - to iterate over all of them. - """ - n = NodeManager( - startup_nodes=[ - {"host": "127.0.0.1", "port": 7000}, - {"host": "127.0.0.1", "port": 7001}, - ] - ) - await n.initialize() - - nodes = [node for node in n.nodes.values() if node.server_type == "primary"] - - for node in n.all_primaries(): - assert node in nodes - - -async def test_cluster_slots_error(redis_cluster): - """ - Check that exception is raised if initialize can't execute - 'CLUSTER SLOTS' command. - """ - with patch.object(Redis, "execute_command") as execute_command_mock: - execute_command_mock.side_effect = RedisError("foobar") - - n = NodeManager(startup_nodes=[{"host": "6.6.6.6", "port": 1234}]) - - with pytest.raises(RedisClusterException): - await n.initialize() - - -def test_set_node(): - """ - Test to update data in a slot. - """ - expected = ManagedNode(host="127.0.0.1", port=7000, server_type="primary") - n = NodeManager(startup_nodes=[]) - assert len(n.slots) == 0, "no slots should exist" - res = n.set_node(host="127.0.0.1", port=7000, server_type="primary") - assert res == expected - assert n.nodes == {expected.name: expected} - - -async def test_reset(redis_cluster): - """ - Test that reset method resets variables back to correct default values. - """ - - n = NodeManager(startup_nodes=[]) - n.initialize = AsyncMock() - await n.reset() - assert n.initialize.call_count == 1 - - -async def test_cluster_one_instance(redis_cluster): - """ - If the cluster exists of only 1 node then there is some hacks that must - be validated they work. - """ - with patch.object(Redis, "cluster_slots", new_callable=AsyncMock) as mock_cluster_slots: - with patch.object(Redis, "config_get", new_callable=AsyncMock) as mock_config_get: - mock_config_get.return_value = {"cluster-require-full-coverage": "yes"} - mock_cluster_slots.return_value = { - (0, 16383): [ - { - "host": "", - "port": 7006, - "node_id": str(uuid.uuid4()), - "server_type": "master", - } - ], - } - - n = NodeManager(startup_nodes=[{"host": "127.0.0.1", "port": 7006}]) - await n.initialize() - - del n.nodes["127.0.0.1:7006"].node_id - assert n.nodes == { - "127.0.0.1:7006": ManagedNode(host="127.0.0.1", port=7006, server_type="primary") - } - assert len(n.slots) == 16384 - - for i in range(0, 16384): - assert n.slots[i] == [ - ManagedNode( - host="127.0.0.1", - port=7006, - server_type="primary", - ) - ] - - -async def test_initialize_follow_cluster(redis_cluster): - n = NodeManager( - nodemanager_follow_cluster=True, - startup_nodes=[{"host": "127.0.0.1", "port": 7000}], - ) - n.orig_startup_nodes = None - await n.initialize() - - -async def test_init_with_down_node(redis_cluster): - """ - If I can't connect to one of the nodes, everything should still work. - But if I can't connect to any of the nodes, exception should be thrown. - """ - - def get_redis_link(host, port, decode_responses=False): - if port == 7000: - raise ConnectionError("mock connection error for 7000") - - return Redis(host=host, port=port, decode_responses=decode_responses) - - with patch.object(NodeManager, "get_redis_link", side_effect=get_redis_link): - n = NodeManager(startup_nodes=[{"host": "127.0.0.1", "port": 7000}]) - with pytest.raises(RedisClusterException) as e: - await n.initialize() - assert "Redis Cluster cannot be connected" in str(e.value) - - -async def test_cluster_initialization_fail(redis_cluster_auth, cloner): - with pytest.raises(RedisClusterException, match="invalid username-password pair"): - await cloner(redis_cluster_auth, password="wrong") - - -async def test_cluster_initialization_credential_provider_fail( - redis_cluster_auth_cred_provider, cloner -): - with pytest.raises(RedisClusterException, match="invalid username-password pair"): - await cloner( - redis_cluster_auth_cred_provider, - credential_provider=UserPassCredentialProvider(password="wrong"), - ) diff --git a/tests/cluster/test_pipeline.py b/tests/cluster/test_pipeline.py index 2058ba3ff..a463dd8d1 100644 --- a/tests/cluster/test_pipeline.py +++ b/tests/cluster/test_pipeline.py @@ -4,6 +4,7 @@ import pytest +from coredis._concurrency import gather from coredis.exceptions import ( AuthorizationError, ClusterCrossSlotError, @@ -20,80 +21,73 @@ @targets("redis_cluster") class TestPipeline: async def test_empty_pipeline(self, client): - async with await client.pipeline() as pipe: - assert await pipe.execute() == () - - async def test_pipeline(self, client): - async with await client.pipeline() as pipe: - pipe.set("a", "a1") - pipe.get("a") - pipe.zadd("z", dict(z1=1)) - pipe.zadd("z", dict(z2=4)) - pipe.zincrby("z", "z1", 1) - pipe.zrange("z", 0, 5, withscores=True) - assert await pipe.execute() == ( - True, - "a1", - True, - True, - 2.0, - (("z1", 2.0), ("z2", 4)), - ) + async with client.pipeline(): + pass + + async def test_pipeline_simple(self, client): + async with client.pipeline() as pipe: + a = pipe.set("a", "a1") + b = pipe.get("a") + c = pipe.zadd("z", dict(z1=1)) + d = pipe.zadd("z", dict(z2=4)) + e = pipe.zincrby("z", "z1", 1) + f = pipe.zrange("z", 0, 5, withscores=True) + assert await gather(a, b, c, d, e, f) == ( + True, + "a1", + True, + True, + 2.0, + (("z1", 2.0), ("z2", 4)), + ) async def test_pipeline_length(self, client): - async with await client.pipeline() as pipe: + async with client.pipeline() as pipe: # Initially empty. assert len(pipe) == 0 assert pipe - # Fill 'er up! pipe.set("a", "a1") pipe.set("b", "b1") pipe.set("c", "c1") assert len(pipe) == 3 - assert pipe - - # Execute calls reset(), so empty once again. - await pipe.execute() - assert len(pipe) == 0 - assert pipe async def test_pipeline_no_transaction(self, client): - async with await client.pipeline(transaction=False) as pipe: - pipe.set("a", "a1") - pipe.set("b", "b1") - pipe.set("c", "c1") - assert await pipe.execute() == ( - True, - True, - True, - ) - assert await client.get("a") == "a1" - assert await client.get("b") == "b1" - assert await client.get("c") == "c1" + async with client.pipeline(transaction=False) as pipe: + a = pipe.set("a", "a1") + b = pipe.set("b", "b1") + c = pipe.set("c", "c1") + assert await gather(a, b, c) == ( + True, + True, + True, + ) + assert await client.get("a") == "a1" + assert await client.get("b") == "b1" + assert await client.get("c") == "c1" async def test_pipeline_no_permission(self, client, user_client): no_perm_client = await user_client("testuser", "on", "+@all", "-MULTI") - async with await no_perm_client.pipeline(transaction=True) as pipe: - pipe.get("fubar") + async with no_perm_client: with pytest.raises(AuthorizationError): - await pipe.execute() + async with no_perm_client.pipeline(transaction=True) as pipe: + pipe.get("fubar") async def test_unwatch(self, client): await client.set("a{fubar}", "1") await client.set("b{fubar}", "2") - async with await client.pipeline() as pipe: + async with client.pipeline() as pipe: await pipe.watch("a{fubar}", "b{fubar}") await client.set("b{fubar}", "3") await pipe.unwatch() assert not pipe.watching - pipe.get("a{fubar}") - assert await pipe.execute() == ("1",) + res = pipe.get("a{fubar}") + assert await res == "1" @pytest.mark.xfail async def test_pipeline_transaction_with_watch_on_construction(self, client): - pipe = await client.pipeline(transaction=True, watches=["a{fu}"]) + pipe = client.pipeline(transaction=True, watches=["a{fu}"]) async def overwrite(): i = 0 @@ -105,63 +99,60 @@ async def overwrite(): except Exception: break - [pipe.set("a{fu}", -1 * i) for i in range(1000)] - task = asyncio.create_task(overwrite()) try: await asyncio.sleep(0.1) with pytest.raises(WatchError): - await pipe.execute() + async with pipe: + [pipe.set("a{fu}", -1 * i) for i in range(1000)] finally: task.cancel() async def test_pipeline_transaction_with_watch(self, client): - pipe = await client.pipeline(transaction=False) - await pipe.watch("a{fu}") - await pipe.watch("b{fu}") - pipe.multi() - await client.set("d{fu}", 1) - pipe.set("a{fu}", 2) - assert (True,) == await pipe.execute() - - async def test_pipeline_transaction_with_watch_inline_fail(self, client): - async with await client.pipeline(transaction=False) as pipe: + async with client.pipeline(transaction=False) as pipe: await pipe.watch("a{fu}") await pipe.watch("b{fu}") pipe.multi() - await client.set("a{fu}", 1) - pipe.set("a{fu}", 2) - with pytest.raises(WatchError): - await pipe.execute() + await client.set("d{fu}", 1) + res = pipe.set("a{fu}", 2) + assert await res + + async def test_pipeline_transaction_with_watch_inline_fail(self, client): + with pytest.raises(WatchError): + async with client.pipeline(transaction=False) as pipe: + await pipe.watch("a{fu}") + await pipe.watch("b{fu}") + pipe.multi() + await client.set("a{fu}", 1) + pipe.set("a{fu}", 2) async def test_pipeline_transaction(self, client): - async with await client.pipeline(transaction=True) as pipe: - pipe.set("a{fu}", "a1") - pipe.set("b{fu}", "b1") - pipe.set("c{fu}", "c1") - assert await pipe.execute() == ( - True, - True, - True, - ) - assert await client.get("a{fu}") == "a1" - assert await client.get("b{fu}") == "b1" - assert await client.get("c{fu}") == "c1" + async with client.pipeline(transaction=True) as pipe: + a = pipe.set("a{fu}", "a1") + b = pipe.set("b{fu}", "b1") + c = pipe.set("c{fu}", "c1") + assert await gather(a, b, c) == ( + True, + True, + True, + ) + assert await client.get("a{fu}") == "a1" + assert await client.get("b{fu}") == "b1" + assert await client.get("c{fu}") == "c1" async def test_pipeline_transaction_cross_slot(self, client): with pytest.raises(ClusterTransactionError): - async with await client.pipeline(transaction=True) as pipe: + async with client.pipeline(transaction=True) as pipe: pipe.set("a{fu}", "a1") pipe.set("b{fu}", "b1") pipe.set("c{fu}", "c1") pipe.set("a{bar}", "fail!") - await pipe.execute() assert await client.exists(["a{fu}", "b{fu}", "c{fu}"]) == 0 assert await client.exists(["a{bar}"]) == 0 async def test_pipeline_eval(self, client): - async with await client.pipeline(transaction=False) as pipe: - pipe.eval( + async with client.pipeline(transaction=False) as pipe: + eval_res = pipe.eval( "return {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}", [ "A{foo}", @@ -172,11 +163,11 @@ async def test_pipeline_eval(self, client): "second", ], ) - res = (await pipe.execute())[0] - assert res[0] == "A{foo}" - assert res[1] == "B{foo}" - assert res[2] == "first" - assert res[3] == "second" + res = await eval_res + assert res[0] == "A{foo}" + assert res[1] == "B{foo}" + assert res[2] == "first" + assert res[3] == "second" async def test_exec_error_in_response(self, client): """ @@ -184,73 +175,54 @@ async def test_exec_error_in_response(self, client): to the list of returned values """ await client.set("c", "a") - async with await client.pipeline() as pipe: - pipe.set("a", "1") - pipe.set("b", 2) - # pipe.set("b", "2") - pipe.lpush("c", ["3"]) - pipe.set("d", "4") - result = await pipe.execute(raise_on_error=False) - - assert result[0] - assert await client.get("a") == "1" - assert result[1] - assert await client.get("b") == "2" - - # we can't lpush to a key that's a string value, so this should - # be a ResponseError exception - assert isinstance(result[2], ResponseError) - assert await client.get("c") == "a" - - # since this isn't a transaction, the other commands after the - # error are still executed - assert result[3] - assert await client.get("d") == "4" - - # make sure the pipe was restored to a working state - pipe.set("z", "zzz") - assert await pipe.execute() == (True,) - assert await client.get("z") == "zzz" + async with client.pipeline(raise_on_error=False) as pipe: + a = pipe.set("a", "1") + b = pipe.set("b", 2) + c = pipe.lpush("c", ["3"]) + d = pipe.set("d", "4") + + assert await a + assert await client.get("a") == "1" + assert await b + assert await client.get("b") == "2" + + # we can't lpush to a key that's a string value, so this should + # be a ResponseError exception + assert isinstance(await c, ResponseError) + assert await client.get("c") == "a" + + # since this isn't a transaction, the other commands after the + # error are still executed + assert await d + assert await client.get("d") == "4" async def test_exec_error_raised(self, client): await client.set("c", "a") - async with await client.pipeline() as pipe: - pipe.set("a", "1") - pipe.set("b", "2") - pipe.lpush("c", ["3"]) - pipe.set("d", "4") - with pytest.raises(ResponseError) as ex: - await pipe.execute() - assert str(ex.value).startswith("Command # 3 (LPUSH c 3) of pipeline caused error: ") - - # make sure the pipe was restored to a working state - pipe.set("z", "zzz") - assert await pipe.execute() == (True,) - assert await client.get("z") == "zzz" + with pytest.raises(ResponseError) as ex: + async with client.pipeline() as pipe: + pipe.set("a", "1") + pipe.set("b", "2") + pipe.lpush("c", ["3"]) + pipe.set("d", "4") + assert str(ex.value).startswith("Command # 3 (LPUSH c 3) of pipeline caused error: ") async def test_parse_error_raised(self, client): - async with await client.pipeline() as pipe: - # the zrem is invalid because we don't pass any keys to it - pipe.set("a", "1") - pipe.zrem("b", []) - pipe.set("b", "2") - with pytest.raises(ResponseError) as ex: - await pipe.execute() - - assert str(ex.value).startswith("Command # 2 (ZREM b) of pipeline caused error: ") + with pytest.raises(ResponseError) as ex: + async with client.pipeline() as pipe: + # the zrem is invalid because we don't pass any keys to it + pipe.set("a", "1") + pipe.zrem("b", []) + pipe.set("b", "2") - # make sure the pipe was restored to a working state - pipe.set("z", "zzz") - assert await pipe.execute() == (True,) - assert await client.get("z") == "zzz" + assert str(ex.value).startswith("Command # 2 (ZREM b) of pipeline caused error: ") @pytest.mark.parametrize("cluster_remap_keyslots", [("a{fu}", "b{fu}", "c{bar}", "d{bar}")]) async def test_moved_error_retried(self, client, cluster_remap_keyslots, _s): - async with await client.pipeline() as pipe: - pipe.set("a{fu}", 1) - pipe.get("a{fu}") + async with client.pipeline() as pipe: + a = pipe.set("a{fu}", 1) + b = pipe.get("a{fu}") - assert (True, _s("1")) == await pipe.execute() + assert (True, _s("1")) == await gather(a, b) @pytest.mark.parametrize( "function, args, kwargs", @@ -265,9 +237,8 @@ async def test_moved_error_retried(self, client, cluster_remap_keyslots, _s): ) async def test_no_key_command(self, client, function, args, kwargs): with pytest.raises(RedisClusterException) as exc: - async with await client.pipeline() as pipe: + async with client.pipeline() as pipe: function(pipe, *args, **kwargs) - await pipe.execute() exc.match("No way to dispatch (.*?) to Redis Cluster. Missing key") @pytest.mark.parametrize( @@ -279,125 +250,51 @@ async def test_no_key_command(self, client, function, args, kwargs): ) async def test_multi_key_cross_slot_commands(self, client, function, args, kwargs): with pytest.raises(ClusterCrossSlotError) as exc: - async with await client.pipeline() as pipe: + async with client.pipeline() as pipe: function(pipe, *args, **kwargs) - await pipe.execute() exc.match("Keys in request don't hash to the same slot") @pytest.mark.parametrize( "function, args, kwargs, expectation", [ - (ClusterPipeline.bitop, (["a{fu}"], "not", "b{fu}"), {}, (0,)), - (ClusterPipeline.brpoplpush, ("a{fu}", "b{fu}", 1.0), {}, (None,)), + (ClusterPipeline.bitop, (["a{fu}"], "not", "b{fu}"), {}, 0), + (ClusterPipeline.brpoplpush, ("a{fu}", "b{fu}", 1.0), {}, None), ], ) async def test_multi_key_non_cross_slot(self, client, function, args, kwargs, expectation): - async with await client.pipeline() as pipe: + async with client.pipeline() as pipe: pipe.set("x{fu}", 1) - function(pipe, *args, **kwargs) - res = await pipe.execute() - assert res == (True,) + expectation + res = function(pipe, *args, **kwargs) + assert await res == expectation assert await client.get("x{fu}") == "1" async def test_multi_node_pipeline(self, client): - async with await client.pipeline() as pipe: - pipe.set("x{foo}", 1) - pipe.set("x{bar}", 1) - pipe.set("x{baz}", 1) - res = await pipe.execute() - assert res == (True, True, True) + async with client.pipeline() as pipe: + a = pipe.set("x{foo}", 1) + b = pipe.set("x{bar}", 1) + c = pipe.set("x{baz}", 1) + assert (True, True, True) == await gather(a, b, c) async def test_multi_node_pipeline_partially_correct(self, client): await client.lpush("list{baz}", [1, 2, 3]) with pytest.raises(ClusterCrossSlotError) as exc: - async with await client.pipeline() as pipe: + async with client.pipeline() as pipe: pipe.set("x{foo}", 1) pipe.set("x{bar}", 1) pipe.set("x{baz}", 1) pipe.brpoplpush("list{baz}", "list{foo}", 1.0) - await pipe.execute() exc.match("Keys in request don't hash to the same slot") assert await client.get("x{foo}") is None assert await client.get("x{bar}") is None assert await client.get("x{baz}") is None - async def test_transaction_callable(self, client, cloner): - clone = await cloner(client) - - async def _incr(): - for i in range(10): - await clone.incr("a{fubar}") - - await client.set("a{fubar}", "1") - await client.set("b{fubar}", "2") - - async def my_transaction(pipe): - await asyncio.sleep(0) - a_value = await pipe.get("a{fubar}") - b_value = await pipe.get("b{fubar}") - pipe.multi() - pipe.set("c{fubar}", str(int(a_value) + int(b_value))) - - results = await asyncio.gather( - client.transaction(my_transaction, "a{fubar}", "b{fubar}", watch_delay=0.01), - _incr(), - ) - assert results[0] == (True,) - assert int(await client.get("c{fubar}")) > 3 - - async def test_transaction_callable_access_other_node(self, client, cloner): - clone = await cloner(client) - - async def _incr(): - for i in range(10): - await clone.incr("a{fubar}") - - await client.set("a{fubar}", "1") - await client.set("b{fubar}", "2") - await client.set("c{bazbaz}", "3") - - async def my_transaction(pipe): - await asyncio.sleep(0) - a_value = await pipe.get("a{fubar}") - b_value = await pipe.get("b{fubar}") - c_value = await pipe.get("c{bazbaz}") - - pipe.multi() - - pipe.set("c{fubar}", str(int(a_value) + int(b_value) + int(c_value))) - - results = await asyncio.gather( - client.transaction(my_transaction, "a{fubar}", "b{fubar}", watch_delay=0.01), - _incr(), - ) - assert results[0] == (True,) - assert int(await client.get("c{fubar}")) > 3 - - async def test_transaction_callable_crossslot_fail(self, client, cloner): - async def my_transaction(pipe): - pipe.multi() - pipe.get("a{bazbaz}") - - with pytest.raises(ClusterCrossSlotError): - await client.transaction( - my_transaction, "a{fubar}", "b{fubar}", "c{bazbaz}", watch_delay=0.01 - ) - - with pytest.raises(ClusterTransactionError): - await client.transaction(my_transaction, "a{fubar}", "b{fubar}", watch_delay=0.01) - async def test_pipeline_timeout(self, client): - await client.hset("hash", {str(i): i for i in range(4096)}) - await client.ping() - pipeline = await client.pipeline(timeout=0.01) - for i in range(20): - pipeline.hgetall("hash") + await client.hset("hash", {str(i): bytes(1024) for i in range(1024)}) with pytest.raises(TimeoutError): - await pipeline.execute() - - await client.ping() - pipeline = await client.pipeline(timeout=5) - for i in range(20): - pipeline.hgetall("hash") - await pipeline.execute() + async with client.pipeline(timeout=0.01) as pipeline: + for _ in range(20): + pipeline.hgetall("hash") + async with client.pipeline(timeout=5) as pipeline: + for _ in range(20): + pipeline.hgetall("hash") diff --git a/tests/cluster/test_pubsub.py b/tests/cluster/test_pubsub.py index 22fb47e09..568bdc63b 100644 --- a/tests/cluster/test_pubsub.py +++ b/tests/cluster/test_pubsub.py @@ -1,16 +1,12 @@ # python std lib from __future__ import annotations -import asyncio import time - -# 3rd party imports from collections import Counter -from contextlib import aclosing +import anyio import pytest -# rediscluster imports from coredis._utils import b, hash_slot from tests.conftest import targets @@ -26,7 +22,7 @@ async def wait_for_message(pubsub, timeout=1, ignore_subscribe_messages=False): if message is not None: return message - await asyncio.sleep(0.01) + await anyio.sleep(0.01) now = time.time() return None @@ -86,7 +82,7 @@ async def _test_subscribe_unsubscribe( for i, key in enumerate(keys): if sharded: - node_key = p.connection_pool.nodes.node_from_slot(hash_slot(b(key)))["node_id"] + node_key = p.connection_pool.nodes.node_from_slot(hash_slot(b(key))).node_id else: node_key = "legacy" counter[node_key] += 1 @@ -98,13 +94,9 @@ async def _test_subscribe_unsubscribe( received.clear() for key in keys: assert await unsub_func(key) is None - - # should be a message for each channel/pattern we just unsubscribed - # from - for i, key in enumerate(keys): if sharded: - node_key = p.connection_pool.nodes.node_from_slot(hash_slot(b(key)))["node_id"] + node_key = p.connection_pool.nodes.node_from_slot(hash_slot(b(key))).node_id else: node_key = "legacy" counter[node_key] -= 1 @@ -117,12 +109,10 @@ async def test_channel_subscribe_unsubscribe(self, redis_cluster): await self._test_subscribe_unsubscribe(**kwargs) @pytest.mark.min_server_version("7.0") - @pytest.mark.xfail async def test_sharded_channel_subscribe_unsubscribe(self, redis_cluster): kwargs = make_subscribe_test_data(redis_cluster.sharded_pubsub(), "channel", sharded=True) await self._test_subscribe_unsubscribe(**kwargs, sharded=True) - @pytest.mark.xfail async def test_pattern_subscribe_unsubscribe(self, redis_cluster): kwargs = make_subscribe_test_data(redis_cluster.pubsub(), "pattern") await self._test_subscribe_unsubscribe(**kwargs) @@ -135,14 +125,11 @@ async def _test_resubscribe_on_reconnection( for key in keys: assert await sub_func(key) is None - - # should be a message for each channel/pattern we just subscribed to - expected = set() received = set() for i, key in enumerate(keys): if sharded: - node_key = p.connection_pool.nodes.node_from_slot(hash_slot(b(key)))["node_id"] + node_key = p.connection_pool.nodes.node_from_slot(hash_slot(b(key))).node_id else: node_key = "legacy" counter[node_key] += 1 @@ -150,25 +137,15 @@ async def _test_resubscribe_on_reconnection( received.add(tuple((await wait_for_message(p)).items())) assert expected == received - - # manually disconnect if sharded: - [c.disconnect() for c in p.shard_connections.values()] + [await c.connection.send_eof() for c in p.shard_connections.values()] else: - p.connection.disconnect() + await p.connection.connection.send_eof() - # calling get_message again reconnects and resubscribes - # note, we may not re-subscribe to channels in exactly the same order - # so we have to do some extra checks to make sure we got them all messages = [] - - # we'll figure this out eventually - if sharded: - await asyncio.sleep(1) - + await anyio.sleep(1) for i, _ in enumerate(keys): messages.append(await wait_for_message(p)) - unique_channels = set() assert len(messages) == len(keys) @@ -230,7 +207,6 @@ async def test_resubscribe_to_channels_on_reconnection(self, redis_cluster): await self._test_resubscribe_on_reconnection(**kwargs) @pytest.mark.min_server_version("7.0") - @pytest.mark.xfail async def test_sharded_resubscribe_to_channels_on_reconnection(self, redis_cluster): kwargs = make_subscribe_test_data(redis_cluster.sharded_pubsub(), "channel", sharded=True) await self._test_resubscribe_on_reconnection(**kwargs, sharded=True) @@ -327,29 +303,6 @@ async def test_ignore_individual_subscribe_messages(self, redis_cluster): assert message is None assert p.subscribed is False - async def test_uninitialized_client(self, redis_cluster, cloner): - client = await cloner(redis_cluster, initialize=False) - async with aclosing(client.pubsub()) as p: - assert not client.connection_pool.initialized - await p.subscribe("foo") - assert p.subscribed - assert await p.get_message(ignore_subscribe_messages=True, timeout=1) is None - await p.unsubscribe() - assert await p.get_message(ignore_subscribe_messages=True, timeout=1) is None - assert not p.subscribed - - @pytest.mark.min_server_version("7.0") - async def test_sharded_pubsub_uninitialized_client(self, redis_cluster, cloner): - client = await cloner(redis_cluster, initialize=False) - async with aclosing(client.sharded_pubsub()) as p: - assert not client.connection_pool.initialized - await p.subscribe("foo") - assert await p.get_message(ignore_subscribe_messages=True, timeout=1) is None - assert p.subscribed - await p.unsubscribe() - assert await p.get_message(ignore_subscribe_messages=True, timeout=1) is None - assert not p.subscribed - class TestPubSubMessages: """ @@ -444,7 +397,6 @@ async def test_unicode_channel_message_handler(self, redis_cluster): assert await wait_for_message(p) is None assert self.message == make_message("message", channel, "test message") - @pytest.mark.xfail async def test_unicode_pattern_message_handler(self, redis_cluster): async with redis_cluster.pubsub(ignore_subscribe_messages=True) as p: pattern = "uni" + chr(4456) + "*" @@ -468,15 +420,13 @@ async def collect(): [messages.append(message) async for message in p] async def unsubscribe(): - await asyncio.sleep(0.1) + await anyio.sleep(0.1) await p.punsubscribe("fu*") await p.unsubscribe("test") - completed, pending = await asyncio.wait( - [asyncio.create_task(collect()), asyncio.create_task(unsubscribe())], timeout=1 - ) - assert all(task.done() for task in completed) - assert not pending + async with anyio.create_task_group() as tg: + tg.start_soon(collect) + tg.start_soon(unsubscribe) assert len(messages) == 20 async def test_sharded_pubsub_message_iterator(self, redis_cluster): @@ -489,14 +439,13 @@ async def collect(): [messages.append(message) async for message in p] async def unsubscribe(): - await asyncio.sleep(0.1) + await anyio.sleep(0.1) await p.unsubscribe("test") - completed, pending = await asyncio.wait( - [asyncio.create_task(collect()), asyncio.create_task(unsubscribe())], timeout=1 - ) - assert all(task.done() for task in completed) - assert not pending + async with anyio.create_task_group() as tg: + tg.start_soon(collect) + tg.start_soon(unsubscribe) + assert len(messages) == 10 async def test_pubsub_handlers(self, redis_cluster): @@ -512,7 +461,7 @@ def handler(message): await redis_cluster.publish("fu", "bar") await redis_cluster.publish("bar", "fu") - await asyncio.sleep(0.1) + await anyio.sleep(0.1) assert messages == {"fu", "bar"} @@ -536,35 +485,27 @@ async def test_pubsub_shardchannels(self, client, _s): @pytest.mark.min_server_version("7.0.0") async def test_pubsub_shardnumsub(self, client, _s): - p1 = client.sharded_pubsub(ignore_subscribe_messages=True) - await p1.subscribe("foo", "bar", "baz") - p2 = client.sharded_pubsub(ignore_subscribe_messages=True) - await p2.subscribe("bar", "baz") - p3 = client.sharded_pubsub(ignore_subscribe_messages=True) - await p3.subscribe("baz") - - channels = {_s("foo"): 1, _s("bar"): 2, _s("baz"): 3} - assert channels == await client.pubsub_shardnumsub("foo", "bar", "baz") - await p1.unsubscribe() - await p2.unsubscribe() - await p3.unsubscribe() - await p1.aclose() - await p2.aclose() - await p3.aclose() + async with ( + client.sharded_pubsub(ignore_subscribe_messages=True) as p1, + client.sharded_pubsub(ignore_subscribe_messages=True) as p2, + client.sharded_pubsub(ignore_subscribe_messages=True) as p3, + ): + await p1.subscribe("foo", "bar", "baz") + await p2.subscribe("bar", "baz") + await p3.subscribe("baz") + + channels = {_s("foo"): 1, _s("bar"): 2, _s("baz"): 3} + assert channels == await client.pubsub_shardnumsub("foo", "bar", "baz") async def test_pubsub_numsub(self, client, _s): - p1 = client.pubsub(ignore_subscribe_messages=True) - await p1.subscribe("foo", "bar", "baz") - p2 = client.pubsub(ignore_subscribe_messages=True) - await p2.subscribe("bar", "baz") - p3 = client.pubsub(ignore_subscribe_messages=True) - await p3.subscribe("baz") - - channels = {_s("foo"): 1, _s("bar"): 2, _s("baz"): 3} - assert channels == await client.pubsub_numsub("foo", "bar", "baz") - await p1.unsubscribe() - await p2.unsubscribe() - await p3.unsubscribe() - await p1.aclose() - await p2.aclose() - await p3.aclose() + async with ( + client.pubsub(ignore_subscribe_messages=True) as p1, + client.pubsub(ignore_subscribe_messages=True) as p2, + client.pubsub(ignore_subscribe_messages=True) as p3, + ): + await p1.subscribe("foo", "bar", "baz") + await p2.subscribe("bar", "baz") + await p3.subscribe("baz") + + channels = {_s("foo"): 1, _s("bar"): 2, _s("baz"): 3} + assert channels == await client.pubsub_numsub("foo", "bar", "baz") diff --git a/tests/cluster/test_scripting.py b/tests/cluster/test_scripting.py index 8f4a8480f..fd3783ff1 100644 --- a/tests/cluster/test_scripting.py +++ b/tests/cluster/test_scripting.py @@ -50,9 +50,10 @@ async def test_eval_ro(self, cloner, client, _s): clone = await cloner(client, read_from_replicas=True) await client.set("a", 2) # 2 * 3 == 6 - assert await clone.eval_ro(multiply_script, ["a"], [3]) == 6 - with pytest.raises(ResponseError, match="Write commands are not allowed"): - await clone.eval_ro(multiply_and_set_script, ["a"], [3]) + async with clone: + assert await clone.eval_ro(multiply_script, ["a"], [3]) == 6 + with pytest.raises(ResponseError, match="Write commands are not allowed"): + await clone.eval_ro(multiply_and_set_script, ["a"], [3]) async def test_eval_same_slot(self, client): await client.set("A{foo}", 2) diff --git a/tests/commands/test_acl.py b/tests/commands/test_acl.py index b882094b0..ce5a91c6a 100644 --- a/tests/commands/test_acl.py +++ b/tests/commands/test_acl.py @@ -14,13 +14,10 @@ async def teardown(client): @targets( "redis_basic", - "redis_basic_resp2", - "redis_basic_blocking", "redis_basic_raw", "redis_auth", "redis_auth_cred_provider", "redis_cluster", - "redis_cluster_blocking", "redis_cluster_raw", "valkey", "redict", diff --git a/tests/commands/test_bitmap.py b/tests/commands/test_bitmap.py index d8d212a28..05617821d 100644 --- a/tests/commands/test_bitmap.py +++ b/tests/commands/test_bitmap.py @@ -9,11 +9,8 @@ @targets( "redis_basic", - "redis_basic_resp2", - "redis_basic_blocking", "redis_basic_raw", "redis_cluster", - "redis_cluster_blocking", "redis_cluster_raw", "valkey", "redict", diff --git a/tests/commands/test_cluster.py b/tests/commands/test_cluster.py index 4634ab52e..c786aee1f 100644 --- a/tests/commands/test_cluster.py +++ b/tests/commands/test_cluster.py @@ -1,7 +1,5 @@ from __future__ import annotations -import asyncio - import pytest from coredis import PureToken @@ -14,27 +12,27 @@ @targets( "redis_cluster", "redis_cluster_noreplica", - "redis_cluster_blocking", "redis_cluster_raw", "redis_cluster_ssl", ) class TestCluster: async def test_addslots(self, client, _s): node = client.connection_pool.get_primary_node_by_slot(1) - client = client.connection_pool.nodes.get_redis_link(node.host, node.port) - with pytest.raises(ResponseError, match="Slot 1 is already busy"): - await client.cluster_addslots([1]) + async with client.connection_pool.nodes.get_redis_link(node.host, node.port) as node_client: + with pytest.raises(ResponseError, match="Slot 1 is already busy"): + await node_client.cluster_addslots([1]) @pytest.mark.min_server_version("7.0.0") async def test_addslots_range(self, client, _s): node = client.connection_pool.get_primary_node_by_slot(1) - client = client.connection_pool.nodes.get_redis_link(node.host, node.port) - with pytest.raises(ResponseError, match="Slot 1 is already busy"): - await client.cluster_addslotsrange([(1, 2)]) + async with client.connection_pool.nodes.get_redis_link(node.host, node.port) as node_client: + with pytest.raises(ResponseError, match="Slot 1 is already busy"): + await node_client.cluster_addslotsrange([(1, 2)]) async def test_asking(self, client, _s): node = client.connection_pool.get_primary_node_by_slot(1) - assert await client.connection_pool.nodes.get_redis_link(node.host, node.port).asking() + async with client.connection_pool.nodes.get_redis_link(node.host, node.port) as node_client: + assert await node_client.asking() async def test_count_failure_reports(self, client, _s): node = client.connection_pool.get_primary_node_by_slot(1) @@ -45,21 +43,20 @@ async def test_count_failure_reports(self, client, _s): async def test_cluster_delslots(self, client, _s): node = client.connection_pool.get_primary_node_by_slot(1) assert await client.cluster_delslots([1]) - assert await client.connection_pool.nodes.get_redis_link( - node.host, node.port - ).cluster_addslots([1]) + async with client.connection_pool.nodes.get_redis_link(node.host, node.port) as node_client: + assert await node_client.cluster_addslots([1]) @pytest.mark.min_server_version("7.0.0") async def test_cluster_delslots_range(self, client, _s): node = client.connection_pool.get_primary_node_by_slot(1) node_last = client.connection_pool.get_primary_node_by_slot(16000) assert await client.cluster_delslotsrange([(1, 2), (16000, 16001)]) - assert await client.connection_pool.nodes.get_redis_link( - node.host, node.port - ).cluster_addslots([1, 2]) - assert await client.connection_pool.nodes.get_redis_link( + async with client.connection_pool.nodes.get_redis_link(node.host, node.port) as node_client: + assert await node_client.cluster_addslots([1, 2]) + async with client.connection_pool.nodes.get_redis_link( node_last.host, node_last.port - ).cluster_addslots([16000, 16001]) + ) as node_client: + assert await node_client.cluster_addslots([16000, 16001]) @pytest.mark.xfail @pytest.mark.replicated_clusteronly @@ -67,25 +64,27 @@ async def test_readonly_explicit(self, client, _s): await client.set("fubar", 1) slot = hash_slot(b"fubar") node = client.connection_pool.get_replica_node_by_slot(slot, replica_only=True) - node_client = client.connection_pool.nodes.get_redis_link(node.host, node.port) - with pytest.raises(MovedError): - await node_client.get("fubar") - await node_client.readonly() - await node_client.get("fubar") == _s(1) - await node_client.readwrite() - with pytest.raises(MovedError): - await node_client.get("fubar") + async with client.connection_pool.nodes.get_redis_link(node.host, node.port) as node_client: + with pytest.raises(MovedError): + await node_client.get("fubar") + await node_client.readonly() + await node_client.get("fubar") == _s(1) + await node_client.readwrite() + with pytest.raises(MovedError): + await node_client.get("fubar") @pytest.mark.replicated_clusteronly async def test_cluster_info(self, client, _s): info = await client.cluster_info() assert info["cluster_state"] == "ok" - info = await list(client.replicas)[0].cluster_info() - assert info["cluster_state"] == "ok" + async with list(client.replicas)[0] as node_client: + info = await node_client.cluster_info() + assert info["cluster_state"] == "ok" - info = await list(client.primaries)[0].cluster_info() - assert info["cluster_state"] == "ok" + async with list(client.primaries)[0] as node_client: + info = await node_client.cluster_info() + assert info["cluster_state"] == "ok" async def test_cluster_keyslot(self, client, _s): slot = await client.cluster_keyslot("a") @@ -112,32 +111,39 @@ async def test_cluster_nodes(self, client, _s): @pytest.mark.replicated_clusteronly async def test_cluster_links(self, client, _s): links = [] + for node in client.primaries: - links.append(await node.cluster_links()) + async with node: + links.append(await node.cluster_links()) + for node in client.replicas: - links.append(await node.cluster_links()) + async with node: + links.append(await node.cluster_links()) assert len(links) > 0 async def test_cluster_meet(self, client, _s): node = list(client.primaries)[0] other = list(client.primaries)[1].connection_pool.connection_kwargs - assert await node.cluster_meet(other["host"], other["port"]) - with pytest.raises(ResponseError, match="Invalid node address"): - await node.cluster_meet("bogus", 6666) + async with node: + assert await node.cluster_meet(other["host"], other["port"]) + with pytest.raises(ResponseError, match="Invalid node address"): + await node.cluster_meet("bogus", 6666) async def test_cluster_my_id(self, client, _s): ids = [] + for node in client.primaries: - ids.append(node.cluster_myid()) + async with node: + ids.append(await node.cluster_myid()) + for node in client.replicas: - ids.append(node.cluster_myid()) - ids = await asyncio.gather(*ids) + async with node: + ids.append(await node.cluster_myid()) known_nodes = (_s(node.node_id) for node in client.connection_pool.nodes.all_nodes()) assert set(ids) == set(known_nodes) @pytest.mark.min_server_version("7.0.0") async def test_cluster_shards(self, client, _s): - await client shards = await client.cluster_shards() assert shards assert _s("slots") in shards[0] diff --git a/tests/commands/test_connection.py b/tests/commands/test_connection.py index 98c8ba4fe..07ac5ed37 100644 --- a/tests/commands/test_connection.py +++ b/tests/commands/test_connection.py @@ -1,19 +1,17 @@ from __future__ import annotations -import asyncio - +import anyio import pytest +from exceptiongroup import catch -import coredis from coredis import PureToken +from coredis.client.basic import Redis from coredis.exceptions import AuthenticationFailureError, ResponseError, UnblockedError from tests.conftest import targets @targets( "redis_basic", - "redis_basic_resp2", - "redis_basic_blocking", "redis_basic_raw", "valkey", "redict", @@ -21,11 +19,11 @@ class TestConnection: @pytest.mark.xfail async def test_bgsave(self, client): - await asyncio.sleep(0.5) + await anyio.sleep(0.5) assert await client.bgsave() with pytest.raises(ResponseError, match="already in progress"): await client.bgsave() - await asyncio.sleep(0.5) + await anyio.sleep(0.5) assert await client.bgsave(schedule=True) async def test_ping(self, client, _s): @@ -37,12 +35,12 @@ async def test_hello_no_args(self, client, _s): assert resp[_s("server")] is not None async def test_hello_extended(self, client, _s): - resp = await client.hello(client.protocol_version) - assert resp[_s("proto")] == client.protocol_version - await client.hello(client.protocol_version, setname="coredis") + resp = await client.hello(3) + assert resp[_s("proto")] == 3 + await client.hello(3, setname="coredis") assert await client.client_getname() == _s("coredis") with pytest.raises(AuthenticationFailureError): - await client.hello(client.protocol_version, username="no", password="body") + await client.hello(3, username="no", password="body") async def test_ping_custom_message(self, client, _s): resp = await client.ping(message="PANG") @@ -77,84 +75,89 @@ async def test_client_no_touch(self, client, _s): assert await client.client_no_touch(PureToken.OFF) async def test_client_tracking(self, client, _s, cloner): - clone = await cloner(client) - clone_connection = await clone.connection_pool.get_connection("tracking") - clone_id = clone_connection.client_id - assert await client.client_tracking(PureToken.ON, redirect=clone_id, noloop=True) - assert clone_id == await client.client_getredir() - assert await client.client_tracking(PureToken.OFF) - assert -1 == await client.client_getredir() - with pytest.raises(ResponseError, match="does not exist"): - clients = await client.client_list() - invalid_client_id = max(c["id"] for c in clients) + 100 - await client.client_tracking(PureToken.ON, redirect=invalid_client_id) - assert await client.client_tracking(PureToken.ON, bcast=True, redirect=clone_id) - assert await client.client_tracking(PureToken.OFF) - assert await client.client_tracking( - PureToken.ON, "fu:", "bar:", bcast=True, redirect=clone_id - ) - assert await client.client_tracking(PureToken.OFF) - with pytest.raises(ResponseError, match="'fu' overlaps"): - assert await client.client_tracking( - PureToken.ON, "fu", "fuu", bcast=True, redirect=clone_id - ) - assert await client.client_tracking(PureToken.ON, optin=True, redirect=clone_id) - with pytest.raises(ResponseError, match="in OPTOUT mode"): - await client.client_caching(PureToken.NO) - assert await client.client_tracking(PureToken.ON, optin=True, redirect=clone_id) - assert await client.client_caching(PureToken.YES) - - with pytest.raises(ResponseError, match="You can't switch"): - await client.client_tracking(PureToken.ON, optout=True, redirect=clone_id) - assert await client.client_tracking(PureToken.OFF) - assert await client.client_tracking(PureToken.ON, optout=True, redirect=clone_id) - with pytest.raises(ResponseError, match="in OPTIN mode"): - await client.client_caching(PureToken.YES) - assert await client.client_tracking(PureToken.ON, optout=True, redirect=clone_id) - assert await client.client_caching(PureToken.NO) + async with await cloner(client) as clone: + async with clone.connection_pool.acquire() as clone_connection: + clone_id = clone_connection.client_id + assert await client.client_tracking(PureToken.ON, redirect=clone_id, noloop=True) + assert clone_id == await client.client_getredir() + assert await client.client_tracking(PureToken.OFF) + assert -1 == await client.client_getredir() + with pytest.raises(ResponseError, match="does not exist"): + clients = await client.client_list() + invalid_client_id = max(c["id"] for c in clients) + 100 + await client.client_tracking(PureToken.ON, redirect=invalid_client_id) + assert await client.client_tracking(PureToken.ON, bcast=True, redirect=clone_id) + assert await client.client_tracking(PureToken.OFF) + assert await client.client_tracking( + PureToken.ON, "fu:", "bar:", bcast=True, redirect=clone_id + ) + assert await client.client_tracking(PureToken.OFF) + with pytest.raises(ResponseError, match="'fu' overlaps"): + assert await client.client_tracking( + PureToken.ON, "fu", "fuu", bcast=True, redirect=clone_id + ) + assert await client.client_tracking(PureToken.ON, optin=True, redirect=clone_id) + with pytest.raises(ResponseError, match="in OPTOUT mode"): + await client.client_caching(PureToken.NO) + assert await client.client_tracking(PureToken.ON, optin=True, redirect=clone_id) + assert await client.client_caching(PureToken.YES) + + with pytest.raises(ResponseError, match="You can't switch"): + await client.client_tracking(PureToken.ON, optout=True, redirect=clone_id) + assert await client.client_tracking(PureToken.OFF) + assert await client.client_tracking(PureToken.ON, optout=True, redirect=clone_id) + with pytest.raises(ResponseError, match="in OPTIN mode"): + await client.client_caching(PureToken.YES) + assert await client.client_tracking(PureToken.ON, optout=True, redirect=clone_id) + assert await client.client_caching(PureToken.NO) async def test_client_getredir(self, client, _s, cloner): assert await client.client_getredir() == -1 clone = await cloner(client) - clone_id = (await clone.client_info())["id"] - assert await client.client_tracking(PureToken.ON, redirect=clone_id) - assert await client.client_getredir() == clone_id + async with clone: + clone_id = (await clone.client_info())["id"] + assert await client.client_tracking(PureToken.ON, redirect=clone_id) + assert await client.client_getredir() == clone_id async def test_client_pause_unpause(self, client, _s, cloner): - clone = await cloner(client) - assert await clone.client_pause(1000) - with pytest.raises(asyncio.TimeoutError): - await asyncio.wait_for(clone.ping(), timeout=0.01) - assert await client.client_unpause() - assert await clone.ping() == _s("PONG") - assert await clone.client_pause(1000, PureToken.WRITE) - assert not await clone.get("fubar") - with pytest.raises(asyncio.TimeoutError): - await asyncio.wait_for(clone.set("fubar", 1), timeout=0.01) - assert await client.client_unpause() - assert await clone.set("fubar", 1) - - @pytest.mark.xfail - async def test_client_unblock(self, client, cloner): - clone = await cloner(client) - client_id = await clone.client_id() - - async def unblock(): - await asyncio.sleep(0.1) - return await client.client_unblock(client_id, PureToken.ERROR) - - sleeper = asyncio.create_task(clone.brpop(["notexist"], 1000)) - unblocker = asyncio.create_task(unblock()) - await asyncio.wait( - [ - sleeper, - unblocker, - ], - return_when=asyncio.FIRST_COMPLETED, - ) - assert isinstance(sleeper.exception(), UnblockedError) - assert unblocker.result() - assert not await client.client_unblock(client_id, PureToken.ERROR) + async with await cloner(client) as clone: + assert await clone.client_pause(1000) + with pytest.raises(TimeoutError): + with anyio.fail_after(0.01): + await clone.ping() + assert await client.client_unpause() + assert await clone.ping() == _s("PONG") + assert await clone.client_pause(1000, PureToken.WRITE) + assert not await clone.get("fubar") + with pytest.raises(TimeoutError): + with anyio.fail_after(0.01): + await clone.set("fubar", 1) + assert await client.client_unpause() + assert await clone.set("fubar", 1) + + async def test_client_unblock(self, client: Redis, cloner): + async with await cloner(client) as clone: + client_id = await clone.client_id() + + async def unblock(): + await anyio.sleep(0.1) + return await client.client_unblock(client_id, PureToken.ERROR) + + async def blocking(): + await clone.brpop(["notexist"], 1000) + + unblocked = False + + def unblocked_raised(_): + nonlocal unblocked + unblocked = True + + with catch({UnblockedError: unblocked_raised}): + async with anyio.create_task_group() as tg: + tg.start_soon(blocking) + tg.start_soon(unblock) + assert unblocked + assert not await client.client_unblock(client_id, PureToken.ERROR) async def test_client_trackinginfo_no_tracking(self, client, _s): info = await client.client_trackinginfo() @@ -185,35 +188,35 @@ async def test_client_kill_fail(self, client, _s): await client.client_kill(ip_port="1.1.1.1:9999") async def test_client_kill_filter(self, client, cloner, _s): - clone = await cloner(client) - clone_id = (await clone.client_info())["id"] - assert await client.client_kill(identifier=clone_id) > 0 - with pytest.raises(ResponseError, match="No such user"): - await client.client_kill(user="noexist") == 0 + async with await cloner(client) as clone: + clone_id = (await clone.client_info())["id"] + assert await client.client_kill(identifier=clone_id) > 0 + with pytest.raises(ResponseError, match="No such user"): + await client.client_kill(user="noexist") == 0 - clone_addr = (await clone.client_info())["addr"] - assert await client.client_kill(addr=clone_addr) == 1 + clone_addr = (await clone.client_info())["addr"] + assert await client.client_kill(addr=clone_addr) == 1 async def test_client_kill_filter_skip_me(self, client, cloner, _s): - clone = await cloner(client) - my_id = (await client.client_info())["id"] - clone_id = (await clone.client_info())["id"] - laddr = (await client.client_info())["laddr"] - resp = await client.client_kill(laddr=laddr, skipme=True) - assert resp > 0 - await clone.ping() - assert clone_id != (await clone.client_info())["id"] - assert my_id == (await client.client_info())["id"] + async with await cloner(client) as clone: + my_id = (await client.client_info())["id"] + clone_id = (await clone.client_info())["id"] + laddr = (await client.client_info())["laddr"] + resp = await client.client_kill(laddr=laddr, skipme=True) + assert resp > 0 + await clone.ping() + assert clone_id != (await clone.client_info())["id"] + assert my_id == (await client.client_info())["id"] @pytest.mark.min_server_version("7.4.0") async def test_client_kill_filter_maxage(self, client, cloner, _s): - clone = await cloner(client) - my_id = (await client.client_info())["id"] - clone_id = (await clone.client_info())["id"] - await asyncio.sleep(1) - assert await client.client_kill(maxage=1, skipme=False) >= 2 - assert clone_id != (await clone.client_info())["id"] - assert my_id != (await client.client_info())["id"] + async with await cloner(client) as clone: + my_id = (await client.client_info())["id"] + clone_id = (await clone.client_info())["id"] + await anyio.sleep(1) + assert await client.client_kill(maxage=1, skipme=False) >= 2 + assert clone_id != (await clone.client_info())["id"] + assert my_id != (await client.client_info())["id"] async def test_client_list_after_client_setname(self, client, _s): with pytest.warns(UserWarning): @@ -239,13 +242,13 @@ async def test_client_setname(self, client, _s): @pytest.mark.novalkey @pytest.mark.noredict - async def test_client_pause(self, client): + async def test_client_pause(self, client, cloner): key = "key_should_expire" - another_client = coredis.Redis() - await client.set(key, "1", px=100) - assert await client.client_pause(100) - res = await another_client.get(key) - assert not res + async with await cloner(client) as another_client: + await client.set(key, "1", px=100) + assert await client.client_pause(100) + res = await another_client.get(key) + assert not res async def test_select(self, client, _s): assert (await client.client_info())["db"] == 0 diff --git a/tests/commands/test_functions.py b/tests/commands/test_functions.py index e6f408ce3..3a365b1c9 100644 --- a/tests/commands/test_functions.py +++ b/tests/commands/test_functions.py @@ -3,7 +3,7 @@ import pytest from coredis import PureToken -from coredis.commands.function import Library +from coredis.commands.function import Library, wraps from coredis.commands.request import CommandRequest from coredis.exceptions import NotBusyError, ResponseError from coredis.typing import KeyT, RedisValueT, StringT @@ -61,11 +61,8 @@ async def simple_library(client): @targets( "redis_basic", - "redis_basic_resp2", - "redis_basic_blocking", "redis_basic_raw", "redis_cluster", - "redis_cluster_blocking", "redis_cluster_raw", "valkey", ) @@ -126,7 +123,6 @@ async def test_dump_restore(self, client, simple_library, _s): "redis_basic", "redis_basic_raw", "redis_cluster", - "redis_cluster_blocking", "redis_cluster_raw", ) @pytest.mark.min_server_version("7.0.0") @@ -157,7 +153,6 @@ async def test_call_library_function(self, client, simple_library, _s): assert await library["return_arg"](args=(1.0, 2.0, 3.0), keys=["A"]) == 10 @pytest.mark.parametrize("client_arguments", [{"readonly": True}]) - @pytest.mark.clusteronly async def test_call_library_function_ro( self, client, simple_library, _s, client_arguments, mocker ): @@ -187,21 +182,18 @@ class Coredis(Library): def __init__(self, client): super().__init__(client, "coredis") - @Library.wraps("echo_key") + @wraps(readonly=True) def echo_key(self, key: KeyT) -> CommandRequest[StringT]: ... - @Library.wraps("return_arg") + @wraps() def return_arg(self, value: RedisValueT) -> CommandRequest[RedisValueT]: ... - @Library.wraps("default_get") - def default_get(self, key: KeyT, value: RedisValueT) -> CommandRequest[RedisValueT]: ... - - @Library.wraps("default_get", key_spec=["quay"]) - def default_get_variadic( - self, quay: str, *values: RedisValueT + @wraps() + def default_get( + self, key: KeyT, *values: RedisValueT ) -> CommandRequest[RedisValueT]: ... - @Library.wraps("hmmerge") + @wraps() def hmmerge( self, key: KeyT, **values: RedisValueT ) -> CommandRequest[list[RedisValueT]]: ... @@ -210,9 +202,9 @@ def hmmerge( assert await lib.echo_key("bar") == _s("bar") assert await lib.return_arg(1) == 10 assert await lib.default_get("bar", "fu") == _s("fu") - assert await lib.default_get_variadic("bar", "fu", "bar", "baz") == _s("fubarbaz") + assert await lib.default_get("bar", "fu", "bar", "baz") == _s("fubarbaz") assert await client.set("bar", "fubar") - assert await lib.default_get_variadic("bar", "fu", "bar", "baz") == _s("fubar") + assert await lib.default_get("bar", "fu", "bar", "baz") == _s("fubar") await client.hset("hbar", {"fu": "whut?"}) assert await lib.hmmerge("hbar", fu="bar", bar="fu", baz="fubar") == [ _s("whut?"), @@ -221,7 +213,6 @@ def hmmerge( ] @pytest.mark.parametrize("client_arguments", [{"readonly": True}]) - @pytest.mark.clusteronly async def test_subclass_wrap_ro_defaults( selfself, client, simple_library, _s, client_arguments, mocker ): @@ -229,10 +220,10 @@ class Coredis(Library): def __init__(self, client): super().__init__(client, "coredis") - @Library.wraps("echo_key") + @wraps(readonly=True) def echo_key(self, key: KeyT) -> CommandRequest[StringT]: ... - @Library.wraps("return_arg") + @wraps() def return_arg(self, value: RedisValueT) -> CommandRequest[RedisValueT]: ... fcall = mocker.spy(client, "fcall") @@ -246,7 +237,6 @@ def return_arg(self, value: RedisValueT) -> CommandRequest[RedisValueT]: ... assert fcall_ro.call_count == 1 @pytest.mark.parametrize("client_arguments", [{"readonly": True}]) - @pytest.mark.clusteronly async def test_subclass_wrap_ro_forced( selfself, client, simple_library, _s, client_arguments, mocker ): @@ -254,26 +244,18 @@ class Coredis(Library): def __init__(self, client): super().__init__(client, "coredis") - @Library.wraps("echo_key", readonly=False) + @wraps(readonly=True) def echo_key(self, key: KeyT) -> CommandRequest[StringT]: ... - @Library.wraps("echo_key", readonly=True) - def echo_key_ro(self, key: KeyT) -> CommandRequest[StringT]: ... - - @Library.wraps("return_arg", readonly=False) + @wraps(readonly=True) def return_arg(self, value: RedisValueT) -> CommandRequest[RedisValueT]: ... - @Library.wraps("return_arg", readonly=True) - def return_arg_ro(self, value: RedisValueT) -> CommandRequest[RedisValueT]: ... - fcall = mocker.spy(client, "fcall") fcall_ro = mocker.spy(client, "fcall_ro") lib = await Coredis(client) assert await lib.echo_key("bar") == _s("bar") - assert await lib.echo_key_ro("bar") == _s("bar") - assert await lib.return_arg(1) == 10 with pytest.raises(ResponseError): - await lib.return_arg_ro(1) == 10 + await lib.return_arg(1) == 10 - assert fcall.call_count == 2 + assert fcall.call_count == 0 assert fcall_ro.call_count == 2 diff --git a/tests/commands/test_generic.py b/tests/commands/test_generic.py index 88e23eaeb..8ce9a0dd8 100644 --- a/tests/commands/test_generic.py +++ b/tests/commands/test_generic.py @@ -12,11 +12,8 @@ @targets( "redis_basic", - "redis_basic_resp2", - "redis_basic_blocking", "redis_basic_raw", "redis_cluster", - "redis_cluster_blocking", "redis_cluster_raw", "redis_cached", "redis_cluster_cached", @@ -36,16 +33,17 @@ async def test_sort_ro(self, client, cloner, _s): await client.set("score{fu}:1", "8") await client.set("score{fu}:2", "3") await client.set("score{fu}:3", "5") - assert await clone.sort_ro("a{fu}") == (_s("1"), _s("2"), _s("3"), _s("4")) - assert await clone.sort_ro("a{fu}", offset=1, count=2) == (_s("2"), _s("3")) - assert await clone.sort_ro("a{fu}", order=PureToken.DESC, offset=1, count=2) == ( - _s("3"), - _s("2"), - ) - assert await clone.sort_ro("a{fu}", alpha=True, offset=1, count=2) == ( - _s("2"), - _s("3"), - ) + async with clone: + assert await clone.sort_ro("a{fu}") == (_s("1"), _s("2"), _s("3"), _s("4")) + assert await clone.sort_ro("a{fu}", offset=1, count=2) == (_s("2"), _s("3")) + assert await clone.sort_ro("a{fu}", order=PureToken.DESC, offset=1, count=2) == ( + _s("3"), + _s("2"), + ) + assert await clone.sort_ro("a{fu}", alpha=True, offset=1, count=2) == ( + _s("2"), + _s("3"), + ) async def test_sort_limited(self, client, _s): await client.rpush("a", ["3", "2", "1", "4"]) @@ -242,58 +240,33 @@ async def test_dump_and_restore_and_replace(self, client, _s): @pytest.mark.novalkey @pytest.mark.noredict async def test_migrate_single_key_with_auth(self, client, redis_auth, _s): - auth_connection = await redis_auth.connection_pool.get_connection() - await client.set("a", "1") - - with pytest.raises(DataError): - await client.migrate("172.17.0.1", auth_connection.port, 0, 100) + async with redis_auth.connection_pool.acquire() as auth_connection: + await client.set("a", "1") - assert not await client.migrate("172.17.0.1", auth_connection.port, 0, 100, "b") - assert await client.migrate("172.17.0.1", auth_connection.port, 0, 100, "a", auth="sekret") - assert await redis_auth.get("a") == "1" - await client.set("b", "2") - assert await client.migrate( - "172.17.0.1", - auth_connection.port, - 0, - 100, - "b", - username="default", - password="sekret", - ) - assert await redis_auth.get("b") == "2" - assert not await client.get("a") - assert not await client.get("b") + with pytest.raises(DataError): + await client.migrate("172.17.0.1", auth_connection.port, 0, 100) - await client.set("c", "3") - assert await client.migrate( - "172.17.0.1", - auth_connection.port, - 0, - 100, - "c", - username="default", - password="sekret", - copy=True, - ) - assert await client.get("c") == _s(3) - assert await redis_auth.get("c") == "3" - await client.set("c", 4) - assert await client.migrate( - "172.17.0.1", - auth_connection.port, - 0, - 100, - "c", - username="default", - password="sekret", - copy=True, - replace=True, - ) - assert await redis_auth.get("c") == "4" + assert not await client.migrate("172.17.0.1", auth_connection.port, 0, 100, "b") + assert await client.migrate( + "172.17.0.1", auth_connection.port, 0, 100, "a", auth="sekret" + ) + assert await redis_auth.get("a") == "1" + await client.set("b", "2") + assert await client.migrate( + "172.17.0.1", + auth_connection.port, + 0, + 100, + "b", + username="default", + password="sekret", + ) + assert await redis_auth.get("b") == "2" + assert not await client.get("a") + assert not await client.get("b") - with pytest.raises(ResponseError, match="BUSYKEY"): - await client.migrate( + await client.set("c", "3") + assert await client.migrate( "172.17.0.1", auth_connection.port, 0, @@ -303,31 +276,58 @@ async def test_migrate_single_key_with_auth(self, client, redis_auth, _s): password="sekret", copy=True, ) - await redis_auth.flushall() - with pytest.raises(ResponseError, match="WRONGPASS"): - await client.migrate( + assert await client.get("c") == _s(3) + assert await redis_auth.get("c") == "3" + await client.set("c", 4) + assert await client.migrate( "172.17.0.1", auth_connection.port, 0, 100, "c", - auth="Sekrets", + username="default", + password="sekret", + copy=True, + replace=True, ) + assert await redis_auth.get("c") == "4" + + with pytest.raises(ResponseError, match="BUSYKEY"): + await client.migrate( + "172.17.0.1", + auth_connection.port, + 0, + 100, + "c", + username="default", + password="sekret", + copy=True, + ) + await redis_auth.flushall() + with pytest.raises(ResponseError, match="WRONGPASS"): + await client.migrate( + "172.17.0.1", + auth_connection.port, + 0, + 100, + "c", + auth="Sekrets", + ) @pytest.mark.nocluster @pytest.mark.novalkey @pytest.mark.noredict async def test_migrate_multiple_keys_with_auth(self, client, redis_auth, _s): - auth_connection = await redis_auth.connection_pool.get_connection() - await client.set("a", "1") - await client.set("c", "2") - assert not await client.migrate("172.17.0.1", auth_connection.port, 0, 100, "d", "b") - assert await client.migrate( - "172.17.0.1", auth_connection.port, 0, 100, "a", "c", auth="sekret" - ) + async with redis_auth.connection_pool.acquire() as auth_connection: + await client.set("a", "1") + await client.set("c", "2") + assert not await client.migrate("172.17.0.1", auth_connection.port, 0, 100, "d", "b") + assert await client.migrate( + "172.17.0.1", auth_connection.port, 0, 100, "a", "c", auth="sekret" + ) - assert await redis_auth.get("a") == "1" - assert await redis_auth.get("c") == "2" + assert await redis_auth.get("a") == "1" + assert await redis_auth.get("c") == "2" @pytest.mark.nocluster async def test_move(self, client, cloner, _s): @@ -335,7 +335,8 @@ async def test_move(self, client, cloner, _s): await client.set("foo", 1) assert await client.move("foo", 1) assert not await client.get("foo") - assert await clone.get("foo") == _s(1) + async with clone: + assert await clone.get("foo") == _s(1) async def test_copy(self, client, _s): await client.set("a{foo}", "foo") @@ -354,7 +355,8 @@ async def test_copy_different_db(self, client, cloner, _s): await client.set("foo", 1) assert await client.copy("foo", "bar", db=1) assert not await client.get("bar") - assert await clone.get("bar") == _s(1) + async with clone: + assert await clone.get("bar") == _s(1) @pytest.mark.min_server_version("7.0.0") async def test_object_encoding_listpack(self, client, _s): @@ -373,6 +375,7 @@ async def test_object_freq(self, client, _s): assert isinstance(await client.object_freq("a"), int) @pytest.mark.novalkey + @pytest.mark.noredict async def test_object_idletime(self, client, _s): await client.set("a", "foo") assert isinstance(await client.object_idletime("a"), int) diff --git a/tests/commands/test_geo.py b/tests/commands/test_geo.py index 50f51edd4..f81cc161e 100644 --- a/tests/commands/test_geo.py +++ b/tests/commands/test_geo.py @@ -2,24 +2,21 @@ import pytest -from coredis import PureToken +from coredis import PureToken, Redis from coredis.exceptions import CommandSyntaxError, DataError from tests.conftest import server_deprecation_warning, targets @targets( "redis_basic", - "redis_basic_resp2", - "redis_basic_blocking", "redis_basic_raw", "redis_cluster", - "redis_cluster_blocking", "redis_cluster_raw", "valkey", "redict", ) class TestGeo: - async def test_geoadd(self, client, _s): + async def test_geoadd(self, client: Redis[str], _s): values = [ (2.1909389952632, 41.433791470673, "place1"), ( @@ -265,7 +262,7 @@ async def test_geosearch_sort(self, client, _s): order=PureToken.DESC, ) == (_s("place2"), _s("place1")) - async def test_geosearch_with(self, client, _s): + async def test_geosearch_with(self, client: Redis[str], _s): values = [ (2.1909389952632, 41.433791470673, "place1"), ( diff --git a/tests/commands/test_hash.py b/tests/commands/test_hash.py index c4eba7738..261161a29 100644 --- a/tests/commands/test_hash.py +++ b/tests/commands/test_hash.py @@ -1,9 +1,9 @@ from __future__ import annotations -import asyncio import datetime import time +import anyio import pytest from coredis import PureToken @@ -13,11 +13,8 @@ @targets( "redis_basic", - "redis_basic_resp2", - "redis_basic_blocking", "redis_basic_raw", "redis_cluster", - "redis_cluster_blocking", "redis_cluster_raw", "redis_cached", "redis_cluster_cached", @@ -77,7 +74,7 @@ async def test_hexpire(self, client, _s): assert (2, 2, -2) == await client.hexpire( "a", datetime.timedelta(seconds=0), ["1", "3", "5"], PureToken.LT ) - await asyncio.sleep(1) + await anyio.sleep(1) assert {_s("2"): _s("2")} == await client.hgetall(_s("a")) @pytest.mark.min_server_version("7.4.0") @@ -98,7 +95,7 @@ async def test_hexpireat(self, client, _s, redis_server_time): ["1", "3", "5"], PureToken.LT, ) - await asyncio.sleep(1) + await anyio.sleep(1) assert {_s("2"): _s("2")} == await client.hgetall(_s("a")) @pytest.mark.min_server_version("7.4.0") @@ -136,7 +133,7 @@ async def test_hpexpire(self, client, _s): assert (2, 2, -2) == await client.hpexpire( "a", datetime.timedelta(milliseconds=0), ["1", "3", "5"], PureToken.LT ) - await asyncio.sleep(1) + await anyio.sleep(1) assert {_s("2"): _s("2")} == await client.hgetall(_s("a")) @pytest.mark.min_server_version("7.4.0") @@ -163,7 +160,7 @@ async def test_hpexpireat(self, client, _s, redis_server_time): ["1", "3", "5"], PureToken.LT, ) - await asyncio.sleep(1) + await anyio.sleep(1) assert {_s("2"): _s("2")} == await client.hgetall(_s("a")) @pytest.mark.min_server_version("7.4.0") diff --git a/tests/commands/test_hyperloglog.py b/tests/commands/test_hyperloglog.py index f3164d14c..5f068f8f4 100644 --- a/tests/commands/test_hyperloglog.py +++ b/tests/commands/test_hyperloglog.py @@ -7,11 +7,8 @@ @targets( "redis_basic", - "redis_basic_resp2", - "redis_basic_blocking", "redis_basic_raw", "redis_cluster", - "redis_cluster_blocking", "redis_cluster_raw", "valkey", "redict", diff --git a/tests/commands/test_list.py b/tests/commands/test_list.py index ea116472c..1f26b9fb5 100644 --- a/tests/commands/test_list.py +++ b/tests/commands/test_list.py @@ -1,20 +1,17 @@ from __future__ import annotations -import asyncio - +import anyio import pytest from coredis import PureToken +from coredis._concurrency import gather from tests.conftest import server_deprecation_warning, targets @targets( "redis_basic", - "redis_basic_resp2", - "redis_basic_blocking", "redis_basic_raw", "redis_cluster", - "redis_cluster_blocking", "redis_cluster_raw", "redis_cached", "redis_cluster_cached", @@ -277,11 +274,12 @@ async def test_blmpop(self, client, cloner, _s): assert result[1] == [_s("6")] async def _delayadd(): - await asyncio.sleep(0.1) + await anyio.sleep(0.1) clone = await cloner(client) - return await clone.rpush("a{foo}", ["42"]) + async with clone: + return await clone.rpush("a{foo}", ["42"]) - result = await asyncio.gather(client.blmpop(["a{foo}"], 1, PureToken.LEFT), _delayadd()) + result = await gather(client.blmpop(["a{foo}"], 1, PureToken.LEFT), _delayadd()) assert result[0][1] == [_s("42")] async def test_blmove(self, client, _s): diff --git a/tests/commands/test_server.py b/tests/commands/test_server.py index ae3611416..19ef50169 100644 --- a/tests/commands/test_server.py +++ b/tests/commands/test_server.py @@ -1,8 +1,8 @@ from __future__ import annotations -import asyncio import datetime +import anyio import pytest from pytest import approx @@ -15,11 +15,8 @@ @targets( "redis_basic", - "redis_basic_resp2", - "redis_basic_blocking", "redis_basic_raw", "redis_cluster", - "redis_cluster_blocking", "valkey", "redict", ) @@ -126,13 +123,14 @@ async def test_flushall(self, client, cloner, _s, mode): await client.set("a", "foo") await client.set("b", "bar") db1 = await cloner(client, connection_kwargs={"db": 1}) - await db1.set("a", "foo") - await db1.set("b", "bar") - assert len(await client.keys()) == 2 - assert len(await db1.keys()) == 2 - assert await client.flushall(mode) - assert len(await client.keys()) == 0 - assert len(await db1.keys()) == 0 + async with db1: + await db1.set("a", "foo") + await db1.set("b", "bar") + assert len(await client.keys()) == 2 + assert len(await db1.keys()) == 2 + assert await client.flushall(mode) + assert len(await client.keys()) == 0 + assert len(await db1.keys()) == 0 @pytest.mark.parametrize( "mode", @@ -337,7 +335,7 @@ async def test_swapdb(self, client, _s): @pytest.mark.xfail async def test_quit(self, client): assert await client.quit() - await asyncio.sleep(0.1) + await anyio.sleep(0.1) assert not client.connection_pool.peek_available().is_connected diff --git a/tests/commands/test_set.py b/tests/commands/test_set.py index 4e69a0513..715c4370d 100644 --- a/tests/commands/test_set.py +++ b/tests/commands/test_set.py @@ -7,11 +7,8 @@ @targets( "redis_basic", - "redis_basic_resp2", - "redis_basic_blocking", "redis_basic_raw", "redis_cluster", - "redis_cluster_blocking", "redis_cluster_raw", "redis_cached", "redis_cluster_cached", diff --git a/tests/commands/test_sorted_set.py b/tests/commands/test_sorted_set.py index 59d00d721..6c11c943f 100644 --- a/tests/commands/test_sorted_set.py +++ b/tests/commands/test_sorted_set.py @@ -1,21 +1,18 @@ from __future__ import annotations -import asyncio - +import anyio import pytest from coredis import PureToken +from coredis._concurrency import gather from coredis.exceptions import CommandSyntaxError, DataError from tests.conftest import server_deprecation_warning, targets @targets( "redis_basic", - "redis_basic_resp2", - "redis_basic_blocking", "redis_basic_raw", "redis_cluster", - "redis_cluster_blocking", "redis_cluster_raw", "redis_cached", "redis_cluster_cached", @@ -717,10 +714,11 @@ async def test_bzmpop(self, client, cloner, _s): async def _delayadd(): clone = await cloner(client) - await asyncio.sleep(0.1) - return await clone.zadd("a{foo}", dict(a1=42)) + async with clone: + await anyio.sleep(0.1) + return await clone.zadd("a{foo}", dict(a1=42)) - result = await asyncio.gather(client.bzmpop(["a{foo}"], 1, PureToken.MIN), _delayadd()) + result = await gather(client.bzmpop(["a{foo}"], 1, PureToken.MIN), _delayadd()) assert result[0][1] == ((_s("a1"), 42.0),) @pytest.mark.nodragonfly diff --git a/tests/commands/test_streams.py b/tests/commands/test_streams.py index 71c3da860..af0aed309 100644 --- a/tests/commands/test_streams.py +++ b/tests/commands/test_streams.py @@ -24,11 +24,8 @@ async def get_stream_message(client, stream, message_id): @targets( "redis_basic", - "redis_basic_resp2", - "redis_basic_blocking", "redis_basic_raw", "redis_cluster", - "redis_cluster_blocking", "redis_cluster_raw", "valkey", "redict", diff --git a/tests/commands/test_string.py b/tests/commands/test_string.py index 0cb01052e..95cda105f 100644 --- a/tests/commands/test_string.py +++ b/tests/commands/test_string.py @@ -11,11 +11,8 @@ @targets( "redis_basic", - "redis_basic_resp2", - "redis_basic_blocking", "redis_basic_raw", "redis_cluster", - "redis_cluster_blocking", "redis_cluster_raw", "redis_cached", "redis_cluster_cached", diff --git a/tests/commands/test_vector_sets.py b/tests/commands/test_vector_sets.py index 5f9214ef2..b610621a4 100644 --- a/tests/commands/test_vector_sets.py +++ b/tests/commands/test_vector_sets.py @@ -34,7 +34,6 @@ async def sample_data(client): "redis_basic_raw", "redis_cluster", "redis_cluster_raw", - "redis_basic_resp2", ) @pytest.mark.min_server_version("8.0.0") class TestVectorSets: diff --git a/tests/conftest.py b/tests/conftest.py index 7d7077cce..4bcab66a9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,6 +7,7 @@ import socket import time from functools import total_ordering +from typing import Any, Generator import pytest import redis @@ -14,13 +15,10 @@ from pytest_lazy_fixtures import lf import coredis -import coredis.connection -import coredis.experimental -import coredis.parser import coredis.sentinel -from coredis import BlockingConnectionPool from coredis._utils import EncodingInsensitiveDict, b, hash_slot, nativestr -from coredis.cache import TrackingCache +from coredis.cache import LRUCache +from coredis.client.basic import Redis from coredis.credentials import UserPassCredentialProvider from coredis.response._callbacks import NoopCallback from coredis.typing import ( @@ -58,12 +56,18 @@ } -@pytest.fixture(scope="session", autouse=True) -def uvloop(): - if os.environ.get("COREDIS_UVLOOP") == "True": - import uvloop +def get_backends(): + backend = os.environ.get("COREDIS_ANYIO_BACKEND", None) or "asyncio" + if backend == "all": + return "asyncio", "trio" + elif backend == "asyncio": + return (("asyncio", {"use_uvloop": os.environ.get("COREDIS_UVLOOP", None) == "True"}),) + return (backend,) - asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + +@pytest.fixture(scope="module", params=get_backends()) +def anyio_backend(request: Any) -> Any: + return request.param @total_ordering @@ -81,7 +85,7 @@ def __lt__(self, other): return True -async def get_module_versions(client): +async def get_module_versions(client: Redis): if str(client) not in MODULE_VERSIONS: MODULE_VERSIONS[str(client)] = {} try: @@ -107,10 +111,10 @@ async def get_version(client): if str(client) not in REDIS_VERSIONS: try: if isinstance(client, coredis.RedisCluster): - await client node = list(client.primaries).pop() - version_string = (await node.info())["redis_version"] - REDIS_VERSIONS[str(client)] = version.parse(version_string) + async with node: + version_string = (await node.info())["redis_version"] + REDIS_VERSIONS[str(client)] = version.parse(version_string) elif isinstance(client, coredis.sentinel.Sentinel): version_string = (await client.sentinels[0].info())["redis_version"] REDIS_VERSIONS[str(client)] = version.parse(version_string) @@ -132,9 +136,10 @@ async def get_version(client): return REDIS_VERSIONS[str(client)] -async def check_test_constraints(request, client, protocol=3): - await get_version(client) - await get_module_versions(client) +async def check_test_constraints(request, client): + async with client: + await get_version(client) + await get_module_versions(client) client_version = REDIS_VERSIONS[str(client)] for marker in request.node.iter_markers(): if marker.name == "min_python" and marker.args: @@ -186,12 +191,6 @@ async def check_test_constraints(request, client, protocol=3): if marker.name == "os" and not marker.args[0].lower() == platform.system().lower(): return pytest.skip(f"Skipped for {platform.system()}") - if protocol == 3 and client_version < version.parse("6.0.0"): - return pytest.skip(f"Skipped RESP3 for {client_version}") - - if marker.name == "noresp3" and protocol == 3: - return pytest.skip("Skipped for RESP3") - if marker.name == "nodragonfly" and SERVER_TYPES.get(str(client)) == "dragonfly": return pytest.skip("Skipped for Dragonfly") @@ -223,7 +222,7 @@ async def set_default_test_config(client, variant=None): await client.acl_log(reset=True) -def get_client_test_args(request): +def get_client_test_args(request) -> dict[str, int]: if "client_arguments" in request.fixturenames: return request.getfixturevalue("client_arguments") @@ -255,14 +254,18 @@ async def remapped_slots(client, request): moves[slot] = destinations[slot].node_id try: for slot in moves.keys(): - [await p.cluster_setslot(slot, node=moves[slot]) for p in client.primaries] + for p in client.primaries: + async with p: + await p.cluster_setslot(slot, node=moves[slot]) yield finally: if originals: await client.flushall() for slot in originals.keys(): - [await p.cluster_setslot(slot, node=originals[slot]) for p in client.primaries] + for p in client.primaries: + async with p: + await p.cluster_setslot(slot, node=originals[slot]) def check_redis_cluster_ready(host, port): @@ -420,11 +423,10 @@ def redis_stack_cluster_server(docker_services): @pytest.fixture(scope="session") -def redis_sentinel_server(docker_services): +def redis_sentinel_server(docker_services) -> Generator[tuple[str, int], Any, None]: docker_services.start("redis-sentinel") docker_services.wait_for_service("redis-sentinel", 26379, ping_socket) - - yield ["localhost", 26379] + yield "localhost", 26379 @pytest.fixture(scope="session") @@ -469,57 +471,17 @@ def redict_server(docker_services): @pytest.fixture async def redis_basic(redis_basic_server, request): - client = coredis.Redis( - "localhost", 6379, decode_responses=True, **get_client_test_args(request) - ) - await check_test_constraints(request, client) - await client.flushall() - await set_default_test_config(client) - - yield client - - client.connection_pool.disconnect() - - -@pytest.fixture -async def redis_basic_resp2(redis_basic_server, request): client = coredis.Redis( "localhost", 6379, decode_responses=True, - protocol_version=2, **get_client_test_args(request), ) await check_test_constraints(request, client) - await client.flushall() - await set_default_test_config(client) - - yield client - - client.connection_pool.disconnect() - - -@pytest.fixture -async def redis_basic_blocking(redis_basic_server, request): - client = coredis.Redis( - "localhost", - 6379, - decode_responses=True, - connection_pool=BlockingConnectionPool( - host="localhost", - port=6379, - decode_responses=True, - **get_client_test_args(request), - ), - **get_client_test_args(request), - ) - await check_test_constraints(request, client) - await client.flushall() - await set_default_test_config(client) - - yield client - - client.connection_pool.disconnect() + async with client: + await client.flushall() + await set_default_test_config(client) + yield client @pytest.fixture @@ -528,29 +490,25 @@ async def redis_stack(redis_stack_server, request): *redis_stack_server, decode_responses=True, **get_client_test_args(request) ) await check_test_constraints(request, client) - await client.flushall() - await set_default_test_config(client) - - yield client - - client.connection_pool.disconnect() + async with client: + await client.flushall() + await set_default_test_config(client) + yield client @pytest.fixture async def redis_stack_raw(redis_stack_server, request): client = coredis.Redis(*redis_stack_server, **get_client_test_args(request)) await check_test_constraints(request, client) - await client.flushall() - await set_default_test_config(client) - - yield client - - client.connection_pool.disconnect() + async with client: + await client.flushall() + await set_default_test_config(client) + yield client @pytest.fixture async def redis_stack_cached(redis_stack_server, request): - cache = TrackingCache(max_size_bytes=-1) + cache = LRUCache() client = coredis.Redis( *redis_stack_server, decode_responses=True, @@ -558,31 +516,22 @@ async def redis_stack_cached(redis_stack_server, request): **get_client_test_args(request), ) await check_test_constraints(request, client) - await client.flushall() - await set_default_test_config(client) - - yield client - client.connection_pool.disconnect() - cache.shutdown() + async with client: + await client.flushall() + await set_default_test_config(client) + yield client @pytest.fixture async def redis_basic_raw(redis_basic_server, request): - client = coredis.Redis( - "localhost", - 6379, - decode_responses=False, - ) - await check_test_constraints(request, client) client = coredis.Redis( "localhost", 6379, decode_responses=False, **get_client_test_args(request) ) - await client.flushall() - await set_default_test_config(client) - - yield client - - client.connection_pool.disconnect() + await check_test_constraints(request, client) + async with client: + await client.flushall() + await set_default_test_config(client) + yield client @pytest.fixture @@ -597,12 +546,10 @@ async def redis_ssl(redis_ssl_server, request): storage_url, decode_responses=True, **get_client_test_args(request) ) await check_test_constraints(request, client) - await client.flushall() - await set_default_test_config(client) - - yield client - - client.connection_pool.disconnect() + async with client: + await client.flushall() + await set_default_test_config(client) + yield client @pytest.fixture @@ -612,12 +559,10 @@ async def redis_ssl_no_client_auth(redis_ssl_server_no_client_auth, request): storage_url, decode_responses=True, **get_client_test_args(request) ) await check_test_constraints(request, client) - await client.flushall() - await set_default_test_config(client) - - yield client - - client.connection_pool.disconnect() + async with client: + await client.flushall() + await set_default_test_config(client) + yield client @pytest.fixture @@ -628,12 +573,10 @@ async def redis_auth(redis_auth_server, request): **get_client_test_args(request), ) await check_test_constraints(request, client) - await client.flushall() - await set_default_test_config(client) - - yield client - - client.connection_pool.disconnect() + async with client: + await client.flushall() + await set_default_test_config(client) + yield client @pytest.fixture @@ -646,12 +589,10 @@ async def redis_auth_cred_provider(redis_auth_server, request): **get_client_test_args(request), ) await check_test_constraints(request, client) - await client.flushall() - await set_default_test_config(client) - - yield client - - client.connection_pool.disconnect() + async with client: + await client.flushall() + await set_default_test_config(client) + yield client @pytest.fixture @@ -662,17 +603,15 @@ async def redis_uds(redis_uds_server, request): **get_client_test_args(request), ) await check_test_constraints(request, client) - await client.flushall() - await set_default_test_config(client) - - yield client - - client.connection_pool.disconnect() + async with client: + await client.flushall() + await set_default_test_config(client) + yield client @pytest.fixture async def redis_cached(redis_basic_server, request): - cache = TrackingCache(max_size_bytes=-1) + cache = LRUCache() client = coredis.Redis( "localhost", 6379, @@ -681,13 +620,10 @@ async def redis_cached(redis_basic_server, request): **get_client_test_args(request), ) await check_test_constraints(request, client) - await client.flushall() - await set_default_test_config(client) - - yield client - - client.connection_pool.disconnect() - cache.shutdown() + async with client: + await client.flushall() + await set_default_test_config(client) + yield client @pytest.fixture @@ -699,17 +635,16 @@ async def redis_cluster(redis_cluster_server, request): **get_client_test_args(request), ) await check_test_constraints(request, cluster) - await cluster - await cluster.flushall() - await cluster.flushdb() + async with cluster: + await cluster.flushall() + await cluster.flushdb() - for primary in cluster.primaries: - await set_default_test_config(primary) + for primary in cluster.primaries: + async with primary: + await set_default_test_config(primary) - async with remapped_slots(cluster, request): - yield cluster - - cluster.connection_pool.disconnect() + async with remapped_slots(cluster, request): + yield cluster @pytest.fixture @@ -722,17 +657,16 @@ async def redis_cluster_auth(redis_cluster_auth_server, request): **get_client_test_args(request), ) await check_test_constraints(request, cluster) - await cluster - await cluster.flushall() - await cluster.flushdb() - - for primary in cluster.primaries: - await set_default_test_config(primary) + async with cluster: + await cluster.flushall() + await cluster.flushdb() - async with remapped_slots(cluster, request): - yield cluster + for primary in cluster.primaries: + async with primary: + await set_default_test_config(primary) - cluster.connection_pool.disconnect() + async with remapped_slots(cluster, request): + yield cluster @pytest.fixture @@ -745,44 +679,16 @@ async def redis_cluster_auth_cred_provider(redis_cluster_auth_server, request): **get_client_test_args(request), ) await check_test_constraints(request, cluster) - await cluster - await cluster.flushall() - await cluster.flushdb() + async with cluster: + await cluster.flushall() + await cluster.flushdb() - for primary in cluster.primaries: - await set_default_test_config(primary) + for primary in cluster.primaries: + async with primary: + await set_default_test_config(primary) - async with remapped_slots(cluster, request): - yield cluster - - cluster.connection_pool.disconnect() - - -@pytest.fixture -async def redis_cluster_blocking(redis_cluster_server, request): - pool = coredis.BlockingClusterConnectionPool( - startup_nodes=[{"host": "localhost", "port": 7000}], - max_connections=32, - decode_responses=True, - **get_client_test_args(request), - ) - cluster = coredis.RedisCluster( - connection_pool=pool, - decode_responses=True, - **get_client_test_args(request), - ) - await check_test_constraints(request, cluster) - await cluster - await cluster.flushall() - await cluster.flushdb() - - for primary in cluster.primaries: - await set_default_test_config(primary) - - async with remapped_slots(cluster, request): - yield cluster - - cluster.connection_pool.disconnect() + async with remapped_slots(cluster, request): + yield cluster @pytest.fixture @@ -794,17 +700,16 @@ async def redis_cluster_noreplica(redis_cluster_noreplica_server, request): **get_client_test_args(request), ) await check_test_constraints(request, cluster) - await cluster - await cluster.flushall() - await cluster.flushdb() + async with cluster: + await cluster.flushall() + await cluster.flushdb() - for primary in cluster.primaries: - await set_default_test_config(primary) + for primary in cluster.primaries: + async with primary: + await set_default_test_config(primary) - async with remapped_slots(cluster, request): - yield cluster - - cluster.connection_pool.disconnect() + async with remapped_slots(cluster, request): + yield cluster @pytest.fixture @@ -820,20 +725,19 @@ async def redis_cluster_ssl(redis_ssl_cluster_server, request): ) await check_test_constraints(request, cluster) - await cluster - await cluster.flushall() - await cluster.flushdb() - - for primary in cluster.primaries: - await set_default_test_config(primary) - yield cluster + async with cluster: + await cluster.flushall() + await cluster.flushdb() - cluster.connection_pool.disconnect() + for primary in cluster.primaries: + async with primary: + await set_default_test_config(primary) + yield cluster @pytest.fixture async def redis_cluster_cached(redis_cluster_server, request): - cache = TrackingCache(max_size_bytes=-1) + cache = LRUCache() cluster = coredis.RedisCluster( "localhost", 7000, @@ -842,16 +746,14 @@ async def redis_cluster_cached(redis_cluster_server, request): **get_client_test_args(request), ) await check_test_constraints(request, cluster) - await cluster - await cluster.flushall() - await cluster.flushdb() + async with cluster: + await cluster.flushall() + await cluster.flushdb() - for primary in cluster.primaries: - await set_default_test_config(primary) - yield cluster - - cluster.connection_pool.disconnect() - cache.shutdown() + for primary in cluster.primaries: + async with primary: + await set_default_test_config(primary) + yield cluster @pytest.fixture @@ -862,15 +764,14 @@ async def redis_cluster_raw(redis_cluster_server, request): **get_client_test_args(request), ) await check_test_constraints(request, cluster) - await cluster - await cluster.flushall() - await cluster.flushdb() - - for primary in cluster.primaries: - await set_default_test_config(primary) - yield cluster + async with cluster: + await cluster.flushall() + await cluster.flushdb() - cluster.connection_pool.disconnect() + for primary in cluster.primaries: + async with primary: + await set_default_test_config(primary) + yield cluster @pytest.fixture @@ -881,33 +782,28 @@ async def redis_stack_cluster(redis_stack_cluster_server, request): **get_client_test_args(request), ) await check_test_constraints(request, cluster) - await cluster - await cluster.flushall() - await cluster.flushdb() - - for primary in cluster.primaries: - await set_default_test_config(primary) + async with cluster: + await cluster.flushall() + await cluster.flushdb() - async with remapped_slots(cluster, request): - yield cluster + for primary in cluster.primaries: + async with primary: + await set_default_test_config(primary) - cluster.connection_pool.disconnect() + async with remapped_slots(cluster, request): + yield cluster @pytest.fixture -async def redis_sentinel(redis_sentinel_server, request): - sentinel = coredis.sentinel.Sentinel( - [redis_sentinel_server], - sentinel_kwargs={}, +async def redis_sentinel(redis_sentinel_server: tuple[str, int], request): + sentinel = coredis.Sentinel( + sentinels=[redis_sentinel_server], + sentinel_kwargs={"connect_timeout": 1}, decode_responses=True, **get_client_test_args(request), ) - master = sentinel.primary_for("mymaster") - await check_test_constraints(request, master) - await set_default_test_config(sentinel) - await master.flushall() - - return sentinel + async with sentinel: + yield sentinel @pytest.fixture @@ -917,29 +813,13 @@ async def redis_sentinel_raw(redis_sentinel_server, request): sentinel_kwargs={}, **get_client_test_args(request), ) - master = sentinel.primary_for("mymaster") - await check_test_constraints(request, master) - await set_default_test_config(sentinel) - await master.flushall() - - return sentinel - - -@pytest.fixture -async def redis_sentinel_resp2(redis_sentinel_server, request): - sentinel = coredis.sentinel.Sentinel( - [redis_sentinel_server], - sentinel_kwargs={}, - decode_responses=True, - protocol_version=2, - **get_client_test_args(request), - ) - master = sentinel.primary_for("mymaster") - await check_test_constraints(request, master) - await set_default_test_config(sentinel) - await master.flushall() - - return sentinel + async with sentinel: + master = sentinel.primary_for("mymaster") + await check_test_constraints(request, master) + async with master: + await set_default_test_config(sentinel) + await master.flushall() + yield sentinel @pytest.fixture @@ -951,13 +831,15 @@ async def redis_sentinel_auth(redis_sentinel_auth_server, request): decode_responses=True, **get_client_test_args(request), ) - master = sentinel.primary_for("mymaster") - await check_test_constraints(request, master) - await set_default_test_config(sentinel) - await master.flushall() - await asyncio.sleep(0.1) + async with sentinel: + master = sentinel.primary_for("mymaster") + await check_test_constraints(request, master) + async with master: + await set_default_test_config(sentinel) + await master.flushall() + await asyncio.sleep(0.1) - return sentinel + yield sentinel @pytest.fixture @@ -969,13 +851,15 @@ async def redis_sentinel_auth_cred_provider(redis_sentinel_auth_server, request) decode_responses=True, **get_client_test_args(request), ) - master = sentinel.primary_for("mymaster") - await check_test_constraints(request, master) - await set_default_test_config(sentinel) - await master.flushall() - await asyncio.sleep(0.1) + async with sentinel: + master = sentinel.primary_for("mymaster") + await check_test_constraints(request, master) + async with master: + await set_default_test_config(sentinel) + await master.flushall() + await asyncio.sleep(0.1) - return sentinel + yield sentinel @pytest.fixture @@ -1041,12 +925,10 @@ async def dragonfly(dragonfly_server, request): **get_client_test_args(request), ) await check_test_constraints(request, client) - await client.flushall() - await set_default_test_config(client, variant="dragonfly") - - yield client - - client.connection_pool.disconnect() + async with client: + await client.flushall() + await set_default_test_config(client, variant="dragonfly") + yield client @pytest.fixture @@ -1057,13 +939,11 @@ async def valkey(valkey_server, request): decode_responses=True, **get_client_test_args(request), ) - await client.flushall() await check_test_constraints(request, client) - await set_default_test_config(client, variant="valkey") - - yield client - - client.connection_pool.disconnect() + async with client: + await client.flushall() + await set_default_test_config(client, variant="valkey") + yield client @pytest.fixture @@ -1074,13 +954,11 @@ async def redict(redict_server, request): decode_responses=True, **get_client_test_args(request), ) - await client.flushall() await check_test_constraints(request, client) - await set_default_test_config(client, variant="redict") - - yield client - - client.connection_pool.disconnect() + async with client: + await client.flushall() + await set_default_test_config(client, variant="redict") + yield client @pytest.fixture(scope="session") @@ -1111,7 +989,6 @@ def module_targets(): ) >= version.parse("8.0.0"): targets = [ "redis_basic", - "redis_basic_resp2", "redis_basic_raw", "redis_cached", "redis_cluster", @@ -1129,10 +1006,10 @@ def module_targets(): def redis_server_time(): async def _get_server_time(client): if isinstance(client, coredis.RedisCluster): - await client node = list(client.primaries).pop() - return await node.time() + async with node: + return await node.time() elif isinstance(client, coredis.Redis): return await client.time() @@ -1162,13 +1039,12 @@ def str_or_bytes(value): @pytest.fixture def cloner(): - async def _cloner(client, initialize=True, connection_kwargs={}, **kwargs): + async def _cloner(client, connection_kwargs={}, **kwargs): if isinstance(client, coredis.client.Redis): c_kwargs = client.connection_pool.connection_kwargs c_kwargs.update(connection_kwargs) c = client.__class__( decode_responses=client.decode_responses, - protocol_version=client.protocol_version, encoding=client.encoding, connection_pool=client.connection_pool.__class__(**c_kwargs), **kwargs, @@ -1178,14 +1054,9 @@ async def _cloner(client, initialize=True, connection_kwargs={}, **kwargs): client.connection_pool.nodes.startup_nodes[0].host, client.connection_pool.nodes.startup_nodes[0].port, decode_responses=client.decode_responses, - protocol_version=client.protocol_version, encoding=client.encoding, **kwargs, ) - - if initialize: - await c.ping() - return c return _cloner diff --git a/tests/modules/test_autocomplete.py b/tests/modules/test_autocomplete.py index 3aed7b29e..098b24d59 100644 --- a/tests/modules/test_autocomplete.py +++ b/tests/modules/test_autocomplete.py @@ -5,6 +5,7 @@ import pytest from coredis import Redis +from coredis._concurrency import gather from coredis.modules.response.types import AutocompleteSuggestion from tests.conftest import module_targets @@ -62,15 +63,17 @@ async def test_suggestions(self, client: Redis, _s): @pytest.mark.parametrize("transaction", [True, False]) async def test_pipeline(self, client: Redis, transaction: bool, _s): - p = await client.pipeline(transaction=transaction) - p.autocomplete.sugadd("suggest", "hello", 1) - p.autocomplete.sugadd("suggest", "hello world", 1) - p.autocomplete.suglen("suggest") - p.autocomplete.sugget("suggest", "hel") - p.autocomplete.sugdel("suggest", "hello") - p.autocomplete.sugdel("suggest", "hello world") - p.autocomplete.suglen("suggest") - assert ( + async with client.pipeline(transaction=transaction) as p: + results = [ + p.autocomplete.sugadd("suggest", "hello", 1), + p.autocomplete.sugadd("suggest", "hello world", 1), + p.autocomplete.suglen("suggest"), + p.autocomplete.sugget("suggest", "hel"), + p.autocomplete.sugdel("suggest", "hello"), + p.autocomplete.sugdel("suggest", "hello world"), + p.autocomplete.suglen("suggest"), + ] + assert await gather(*results) == ( 1, 2, 2, @@ -81,4 +84,4 @@ async def test_pipeline(self, client: Redis, transaction: bool, _s): 1, 1, 0, - ) == await p.execute() + ) diff --git a/tests/modules/test_bloom_filter.py b/tests/modules/test_bloom_filter.py index 97966f6fe..65e11d179 100644 --- a/tests/modules/test_bloom_filter.py +++ b/tests/modules/test_bloom_filter.py @@ -1,10 +1,9 @@ from __future__ import annotations -import asyncio - import pytest from coredis import Redis +from coredis._concurrency import gather from coredis.exceptions import ResponseError from tests.conftest import module_targets @@ -16,7 +15,7 @@ async def test_reserve(self, client: Redis, _s): with pytest.raises(ResponseError): await client.bf.reserve("filter", 0.1, 1000) assert await client.bf.reserve("filter_ex", 0.1, 1000, 3) - info = await asyncio.gather( + info = await gather( client.bf.info("filter"), client.bf.info("filter_ex"), ) @@ -89,14 +88,16 @@ async def test_dump_load(self, client: Redis): @pytest.mark.parametrize("transaction", [True, False]) async def test_pipeline(self, client: Redis, transaction: bool): - p = await client.pipeline(transaction=transaction) - p.bf.add("filter", 1) - p.bf.add("filter", 2) - p.bf.exists("filter", 2) - p.bf.mexists("filter", [1, 2, 3]) - assert ( + async with client.pipeline(transaction=transaction) as p: + results = [ + p.bf.add("filter", 1), + p.bf.add("filter", 2), + p.bf.exists("filter", 2), + p.bf.mexists("filter", [1, 2, 3]), + ] + assert await gather(*results) == ( True, True, True, (True, True, False), - ) == await p.execute() + ) diff --git a/tests/modules/test_count_min_sketch.py b/tests/modules/test_count_min_sketch.py index f5059df8e..eb1a61512 100644 --- a/tests/modules/test_count_min_sketch.py +++ b/tests/modules/test_count_min_sketch.py @@ -1,10 +1,9 @@ from __future__ import annotations -import asyncio - import pytest from coredis import Redis +from coredis._concurrency import gather from coredis.exceptions import ResponseError from tests.conftest import module_targets @@ -14,7 +13,7 @@ class TestCountMinSketch: async def test_init(self, client: Redis, _s): assert await client.cms.initbydim("sketch", 2, 50) assert await client.cms.initbyprob("sketchprob", 0.042, 0.42) - infos = await asyncio.gather(client.cms.info("sketch"), client.cms.info("sketchprob")) + infos = await gather(client.cms.info("sketch"), client.cms.info("sketchprob")) assert infos[0][_s("width")] == 2 assert infos[0][_s("depth")] == 50 assert infos[1][_s("width")] == 48 @@ -98,9 +97,11 @@ async def test_merge_cluster(self, client): @pytest.mark.parametrize("transaction", [True, False]) async def test_pipeline(self, client: Redis, transaction: bool): - p = await client.pipeline(transaction=transaction) - p.cms.initbydim("sketch", 2, 50) - p.cms.incrby("sketch", {"fu": 1, "bar": 2}) - p.cms.incrby("sketch", {"fu": 3}) - p.cms.query("sketch", ["fu", "bar"]) - assert (True, (1, 2), (4,), (4, 2)) == await p.execute() + async with client.pipeline(transaction=transaction) as p: + results = [ + p.cms.initbydim("sketch", 2, 50), + p.cms.incrby("sketch", {"fu": 1, "bar": 2}), + p.cms.incrby("sketch", {"fu": 3}), + p.cms.query("sketch", ["fu", "bar"]), + ] + assert await gather(*results) == (True, (1, 2), (4,), (4, 2)) diff --git a/tests/modules/test_cuckoo_filter.py b/tests/modules/test_cuckoo_filter.py index 01a8d9afe..1d9292ccd 100644 --- a/tests/modules/test_cuckoo_filter.py +++ b/tests/modules/test_cuckoo_filter.py @@ -1,10 +1,9 @@ from __future__ import annotations -import asyncio - import pytest from coredis import Redis +from coredis._concurrency import gather from coredis.exceptions import ResponseError from tests.conftest import module_targets @@ -16,7 +15,7 @@ async def test_reserve(self, client: Redis, _s): with pytest.raises(ResponseError): await client.cf.reserve("filter", 1000) assert await client.cf.reserve("filter_bucket", 1000, 3) - info = await asyncio.gather( + info = await gather( client.cf.info("filter"), client.cf.info("filter_bucket"), ) @@ -87,12 +86,13 @@ async def test_dump_load(self, client: Redis): @pytest.mark.parametrize("transaction", [True, False]) async def test_pipeline(self, client: Redis, transaction: bool): - p = await client.pipeline(transaction=transaction) - p.cf.add("filter", 1) - p.cf.add("filter", 2) - p.cf.exists("filter", 2) - p.cf.mexists("filter", [1, 2, 3]) - p.cf.delete("filter", 2) - p.cf.exists("filter", 2) - - assert (True, True, True, (True, True, False), True, False) == await p.execute() + async with client.pipeline(transaction=transaction) as p: + results = [ + p.cf.add("filter", 1), + p.cf.add("filter", 2), + p.cf.exists("filter", 2), + p.cf.mexists("filter", [1, 2, 3]), + p.cf.delete("filter", 2), + p.cf.exists("filter", 2), + ] + assert await gather(*results) == (True, True, True, (True, True, False), True, False) diff --git a/tests/modules/test_graph.py b/tests/modules/test_graph.py index 84b616a60..44459da77 100644 --- a/tests/modules/test_graph.py +++ b/tests/modules/test_graph.py @@ -5,6 +5,7 @@ import pytest from coredis import PureToken, Redis +from coredis._concurrency import gather from coredis.exceptions import ResponseError from coredis.modules.response.types import GraphNode, GraphQueryResult from tests.conftest import module_targets @@ -220,14 +221,16 @@ async def test_slowlog_reset(self, client: Redis): @pytest.mark.parametrize("transaction", [True, False]) async def test_pipeline(self, client: Redis, transaction): - p = await client.pipeline(transaction=transaction) - p.graph.query("graph", "CREATE (:Node {name: 'A'})") - p.graph.query("graph", "MATCH (n) return n") - assert ( + async with client.pipeline(transaction=transaction) as p: + results = [ + p.graph.query("graph", "CREATE (:Node {name: 'A'})"), + p.graph.query("graph", "MATCH (n) return n"), + ] + assert await gather(*results) == ( GraphQueryResult((), (), stats=ANY), GraphQueryResult( ("n",), ([GraphNode(id=0, labels={"Node"}, properties={"name": "A"})],), stats=ANY, ), - ) == await p.execute() + ) diff --git a/tests/modules/test_json.py b/tests/modules/test_json.py index c199b6b4f..d00c80b87 100644 --- a/tests/modules/test_json.py +++ b/tests/modules/test_json.py @@ -3,6 +3,7 @@ import pytest from coredis import PureToken, Redis +from coredis._concurrency import gather from coredis.exceptions import ResponseError from tests.conftest import module_targets @@ -831,25 +832,27 @@ async def test_debug_memory(self, client: Redis, seed): @pytest.mark.parametrize("transaction", [True, False]) async def test_pipeline(self, client: Redis, transaction: bool): - p = await client.pipeline(transaction=transaction) - p.json.set( - "key", - LEGACY_ROOT_PATH, - {"a": 1, "b": [2], "c": {"d": "3"}, "e": {"f": [{"g": 4, "h": True}]}}, - ) - p.json.numincrby("key", "$.a", 1) - p.json.arrappend("key", [1], "..*") - p.json.strappend("key", "bar", "..*") - p.json.toggle("key", "..*") - p.json.toggle("key", "..*") - assert ( + async with client.pipeline(transaction=transaction) as p: + results = [ + p.json.set( + "key", + LEGACY_ROOT_PATH, + {"a": 1, "b": [2], "c": {"d": "3"}, "e": {"f": [{"g": 4, "h": True}]}}, + ), + p.json.numincrby("key", "$.a", 1), + p.json.arrappend("key", [1], "..*"), + p.json.strappend("key", "bar", "..*"), + p.json.toggle("key", "..*"), + p.json.toggle("key", "..*"), + ] + assert await gather(*results) == ( True, [2], 2, 4, False, True, - ) == await p.execute() + ) assert { "a": 2, "b": [2, 1], diff --git a/tests/modules/test_search.py b/tests/modules/test_search.py index 5050b3791..f26a5cf6f 100644 --- a/tests/modules/test_search.py +++ b/tests/modules/test_search.py @@ -7,6 +7,7 @@ import pytest from coredis import PureToken, Redis +from coredis._concurrency import gather from coredis.exceptions import ResponseError from coredis.modules.response.types import ( SearchAggregationResult, @@ -608,14 +609,16 @@ async def test_pipeline(self, client: Redis, _s): on=PureToken.HASH, prefixes=["{search}:"], ) - p = await client.pipeline() - p.hset("{search}:doc:1", {"name": "hello"}) - p.hset("{search}:doc:2", {"name": "world"}) - p.search.search( - "{search}:idx", - "@name:hello", - ) - assert ( + async with client.pipeline() as p: + results = [ + p.hset("{search}:doc:1", {"name": "hello"}), + p.hset("{search}:doc:2", {"name": "world"}), + p.search.search( + "{search}:idx", + "@name:hello", + ), + ] + assert await gather(*results) == ( 1, 1, SearchResult( @@ -626,7 +629,7 @@ async def test_pipeline(self, client: Redis, _s): ), ), ), - ) == await p.execute() + ) @pytest.mark.min_module_version("search", "2.6.1") @@ -853,15 +856,18 @@ async def test_pipeline(self, client: Redis, _s): on=PureToken.HASH, prefixes=["{search}:"], ) - p = await client.pipeline() - p.hset("{search}:doc:1", {"name": "hello"}) - p.hset("{search}:doc:2", {"name": "world"}) - p.search.aggregate( - "{search}:idx", - "*", - transforms=[Group("@name", [Reduce("count", [0], "count")])], - ) - assert ( + async with client.pipeline() as p: + results = [ + p.hset("{search}:doc:1", {"name": "hello"}), + p.hset("{search}:doc:2", {"name": "world"}), + p.search.aggregate( + "{search}:idx", + "*", + transforms=[Group("@name", [Reduce("count", [0], "count")])], + ), + ] + + assert await gather(*results) == ( 1, 1, SearchAggregationResult( @@ -871,4 +877,4 @@ async def test_pipeline(self, client: Redis, _s): ], None, ), - ) == await p.execute() + ) diff --git a/tests/modules/test_tdigest.py b/tests/modules/test_tdigest.py index abee050d5..782397dd9 100644 --- a/tests/modules/test_tdigest.py +++ b/tests/modules/test_tdigest.py @@ -1,10 +1,9 @@ from __future__ import annotations -import asyncio - import pytest from coredis import Redis +from coredis._concurrency import gather from tests.conftest import module_targets @@ -14,7 +13,7 @@ class TestTdigest: async def test_create(self, client: Redis, _s): await client.tdigest.create("digest") await client.tdigest.create("digest_lowcompress", 1) - info = await asyncio.gather( + info = await gather( client.tdigest.info("digest"), client.tdigest.info("digest_lowcompress"), ) @@ -88,11 +87,13 @@ async def test_merge(self, client: Redis, _s): @pytest.mark.parametrize("transaction", [True, False]) async def test_pipeline(self, client: Redis, transaction: bool): - p = await client.pipeline(transaction=transaction) - p.tdigest.create("digest1{a}") - p.tdigest.create("digest2{a}") - p.tdigest.add("digest1{a}", [1, 2, 3]) - p.tdigest.add("digest2{a}", [4, 5, 6]) - p.tdigest.merge("digest1{a}", ["digest2{a}"]) - p.tdigest.quantile("digest1{a}", [0, 0.5, 1]) - assert (True, True, True, True, True, (1.0, 4.0, 6.0)) == await p.execute() + async with client.pipeline(transaction=transaction) as p: + results = [ + p.tdigest.create("digest1{a}"), + p.tdigest.create("digest2{a}"), + p.tdigest.add("digest1{a}", [1, 2, 3]), + p.tdigest.add("digest2{a}", [4, 5, 6]), + p.tdigest.merge("digest1{a}", ["digest2{a}"]), + p.tdigest.quantile("digest1{a}", [0, 0.5, 1]), + ] + assert await gather(*results) == (True, True, True, True, True, (1.0, 4.0, 6.0)) diff --git a/tests/modules/test_timeseries.py b/tests/modules/test_timeseries.py index bad5bfe76..7fa0287f7 100644 --- a/tests/modules/test_timeseries.py +++ b/tests/modules/test_timeseries.py @@ -1,13 +1,14 @@ from __future__ import annotations -import asyncio import math import time from datetime import datetime, timedelta +import anyio import pytest from coredis import PureToken, Redis +from coredis._concurrency import gather from tests.conftest import module_targets @@ -137,7 +138,7 @@ async def test_madd(self, client: Redis): async def test_incrby(self, client: Redis, _s): for _ in range(100): assert await client.timeseries.incrby("ts1", 1) - await asyncio.sleep(0.001) + await anyio.sleep(0.001) assert 100 == (await client.timeseries.get("ts1"))[1] assert await client.timeseries.incrby("ts2", 1.5, timestamp=5) @@ -170,7 +171,7 @@ async def test_incrby(self, client: Redis, _s): async def test_decrby(self, client: Redis, _s): for _ in range(100): assert await client.timeseries.decrby("ts1", 1) - await asyncio.sleep(0.001) + await anyio.sleep(0.001) assert -100 == (await client.timeseries.get("ts1"))[1] assert await client.timeseries.decrby("ts2", 1.5, timestamp=5) @@ -719,8 +720,10 @@ async def test_uncompressed(self, client: Redis, _s): @pytest.mark.parametrize("transaction", [True, False]) async def test_pipeline(self, client: Redis, transaction: bool): - p = await client.pipeline(transaction=transaction) - p.timeseries.create("ts") - p.timeseries.add("ts", 1, 1) - p.timeseries.get("ts") - assert (True, 1, (1, 1.0)) == await p.execute() + async with client.pipeline(transaction=transaction) as p: + results = [ + p.timeseries.create("ts"), + p.timeseries.add("ts", 1, 1), + p.timeseries.get("ts"), + ] + assert await gather(*results) == (True, 1, (1, 1.0)) diff --git a/tests/modules/test_topk.py b/tests/modules/test_topk.py index e5da04a74..b0713775d 100644 --- a/tests/modules/test_topk.py +++ b/tests/modules/test_topk.py @@ -1,10 +1,9 @@ from __future__ import annotations -import asyncio - import pytest from coredis import Redis +from coredis._concurrency import gather from tests.conftest import module_targets @@ -13,7 +12,7 @@ class TestTopK: async def test_reserve(self, client: Redis, _s): assert await client.topk.reserve("topk", 3) assert await client.topk.reserve("topkcustom", 3, 16, 14, 0.8) - infos = await asyncio.gather(client.topk.info("topk"), client.topk.info("topkcustom")) + infos = await gather(client.topk.info("topk"), client.topk.info("topkcustom")) assert infos[0][_s("width")] == 8 assert infos[0][_s("depth")] == 7 assert infos[1][_s("width")] == 16 @@ -49,8 +48,10 @@ async def test_query(self, client: Redis, _s): @pytest.mark.parametrize("transaction", [True, False]) async def test_pipeline(self, client: Redis, transaction: bool): - p = await client.pipeline(transaction=transaction) - p.topk.reserve("topk", 3) - p.topk.add("topk", ["1", "2", "3"]) - p.topk.query("topk", ["1", "2", "3"]) - assert (True, (None, None, None), (True, True, True)) == await p.execute() + async with client.pipeline(transaction=transaction) as p: + results = [ + p.topk.reserve("topk", 3), + p.topk.add("topk", ["1", "2", "3"]), + p.topk.query("topk", ["1", "2", "3"]), + ] + assert await gather(*results) == (True, (None, None, None), (True, True, True)) diff --git a/tests/recipes/credentials/test_elasticache_iam_provider.py b/tests/recipes/credentials/test_elasticache_iam_provider.py index bcb068039..99f71deee 100644 --- a/tests/recipes/credentials/test_elasticache_iam_provider.py +++ b/tests/recipes/credentials/test_elasticache_iam_provider.py @@ -2,7 +2,7 @@ from moto import mock_aws -from coredis.recipes.credentials import ElastiCacheIAMProvider +from coredis.recipes import ElastiCacheIAMProvider class TestElastiCacheIAMProvider: diff --git a/tests/recipes/locks/test_lua_lock.py b/tests/recipes/locks/test_lua_lock.py index 1db8db417..65d7b7c9b 100644 --- a/tests/recipes/locks/test_lua_lock.py +++ b/tests/recipes/locks/test_lua_lock.py @@ -2,12 +2,11 @@ import time import uuid -from unittest.mock import PropertyMock import pytest from coredis.exceptions import LockError -from coredis.recipes.locks import LuaLock +from coredis.recipes import Lock from tests.conftest import targets @@ -25,7 +24,7 @@ def lock_name(): ) class TestLock: async def test_lock(self, client, _s, lock_name): - lock = LuaLock(client, lock_name, blocking=False) + lock = Lock(client, lock_name, blocking=False) assert await lock.acquire() assert await client.get(lock_name) == _s(lock.local.get()) assert await client.ttl(lock_name) == -1 @@ -33,8 +32,8 @@ async def test_lock(self, client, _s, lock_name): assert await client.get(lock_name) is None async def test_competing_locks(self, client, lock_name): - lock1 = LuaLock(client, lock_name, blocking=False) - lock2 = LuaLock(client, lock_name, blocking=False) + lock1 = Lock(client, lock_name, blocking=False) + lock2 = Lock(client, lock_name, blocking=False) assert await lock1.acquire() assert not await lock2.acquire() await lock1.release() @@ -43,13 +42,13 @@ async def test_competing_locks(self, client, lock_name): await lock2.release() async def test_timeout(self, client, lock_name): - lock = LuaLock(client, lock_name, timeout=10, blocking=False) + lock = Lock(client, lock_name, timeout=10, blocking=False) assert await lock.acquire() assert 8 < await client.ttl(lock_name) <= 10 await lock.release() async def test_float_timeout(self, client, lock_name): - lock = LuaLock( + lock = Lock( client, lock_name, blocking=False, @@ -60,9 +59,9 @@ async def test_float_timeout(self, client, lock_name): await lock.release() async def test_blocking_timeout(self, client, lock_name): - lock1 = LuaLock(client, lock_name, blocking=False) + lock1 = Lock(client, lock_name, blocking=False) assert await lock1.acquire() - lock2 = LuaLock( + lock2 = Lock( client, lock_name, blocking_timeout=0.2, @@ -72,21 +71,10 @@ async def test_blocking_timeout(self, client, lock_name): assert (time.time() - start) > 0.2 await lock1.release() - @pytest.mark.replicated_clusteronly - async def test_lock_replication_failed(self, client, mocker, lock_name): - replication_factor = mocker.patch( - "coredis.recipes.locks.LuaLock.replication_factor", - new_callable=PropertyMock, - ) - replication_factor.return_value = 2 - lock1 = LuaLock(client, lock_name, blocking=True, blocking_timeout=1) - with pytest.warns(RuntimeWarning): - assert not await lock1.acquire() - async def test_context_manager(self, client, _s, lock_name): # blocking_timeout prevents a deadlock if the lock can't be acquired # for some reason - async with LuaLock( + async with Lock( client, lock_name, blocking_timeout=0.2, @@ -97,7 +85,7 @@ async def test_context_manager(self, client, _s, lock_name): async def test_high_sleep_raises_error(self, client, lock_name): "If sleep is higher than timeout, it should raise an error" with pytest.raises(LockError): - LuaLock( + Lock( client, lock_name, timeout=1, @@ -105,7 +93,7 @@ async def test_high_sleep_raises_error(self, client, lock_name): ) async def test_releasing_unlocked_lock_raises_error(self, client, lock_name): - lock = LuaLock( + lock = Lock( client, lock_name, ) @@ -113,7 +101,7 @@ async def test_releasing_unlocked_lock_raises_error(self, client, lock_name): await lock.release() async def test_releasing_lock_no_longer_owned_raises_error(self, client, lock_name): - lock = LuaLock(client, lock_name, blocking=False) + lock = Lock(client, lock_name, blocking=False) await lock.acquire() # manually change the token await client.set(lock_name, "a") @@ -123,7 +111,7 @@ async def test_releasing_lock_no_longer_owned_raises_error(self, client, lock_na assert lock.local.get() is None async def test_extend_lock(self, client, lock_name): - lock = LuaLock( + lock = Lock( client, lock_name, blocking=False, @@ -136,7 +124,7 @@ async def test_extend_lock(self, client, lock_name): await lock.release() async def test_extend_lock_float(self, client, lock_name): - lock = LuaLock( + lock = Lock( client, lock_name, blocking=False, @@ -149,7 +137,7 @@ async def test_extend_lock_float(self, client, lock_name): await lock.release() async def test_extending_unlocked_lock_raises_error(self, client, lock_name): - lock = LuaLock( + lock = Lock( client, lock_name, timeout=10, @@ -158,7 +146,7 @@ async def test_extending_unlocked_lock_raises_error(self, client, lock_name): await lock.extend(10) async def test_extending_lock_with_no_timeout_raises_error(self, client, lock_name): - lock = LuaLock(client, lock_name, blocking=False) + lock = Lock(client, lock_name, blocking=False) await client.flushdb() assert await lock.acquire() with pytest.raises(LockError): @@ -167,7 +155,7 @@ async def test_extending_lock_with_no_timeout_raises_error(self, client, lock_na @pytest.mark.xfail async def test_extending_lock_no_longer_owned_raises_error(self, client, lock_name): - lock = LuaLock(client, lock_name, blocking=False) + lock = Lock(client, lock_name, blocking=False) await client.flushdb() assert await lock.acquire() await client.set(lock_name, "a") diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 80c1103c1..accf3ce72 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -1,12 +1,10 @@ from __future__ import annotations -import asyncio - import pytest import coredis from coredis.credentials import UserPassCredentialProvider -from coredis.exceptions import AuthenticationError, ConnectionError, UnknownCommandError +from coredis.exceptions import AuthenticationError @pytest.mark.parametrize( @@ -20,8 +18,9 @@ ) async def test_invalid_authentication(redis_auth, username, password): client = coredis.Redis("localhost", 6389, username=username, password=password) - with pytest.raises(AuthenticationError): - await client.ping() + with pytest.RaisesGroup(AuthenticationError, allow_unwrapped=True, flatten_subgroups=True): + async with client: + await client.ping() @pytest.mark.parametrize( @@ -39,13 +38,15 @@ async def test_invalid_authentication_cred_provider(redis_auth_cred_provider, us 6389, credential_provider=UserPassCredentialProvider(username=username, password=password), ) - with pytest.raises(AuthenticationError): - await client.ping() + with pytest.RaisesGroup(AuthenticationError, allow_unwrapped=True, flatten_subgroups=True): + async with client: + await client.ping() async def test_valid_authentication(redis_auth): client = coredis.Redis("localhost", 6389, password="sekret") - assert await client.ping() + async with client: + assert await client.ping() async def test_valid_authentication_cred_provider(redis_auth_cred_provider): @@ -54,104 +55,13 @@ async def test_valid_authentication_cred_provider(redis_auth_cred_provider): 6389, credential_provider=UserPassCredentialProvider(password="sekret"), ) - assert await client.ping() + async with client: + assert await client.ping() async def test_valid_authentication_delayed(redis_auth): client = coredis.Redis("localhost", 6389) assert client.server_version is None - with pytest.warns(UserWarning): + async with client: await client.auth(password="sekret") - assert await client.ping() - assert client.server_version is not None - - -async def test_legacy_authentication(redis_auth, mocker): - original_request = coredis.connection.BaseConnection.create_request - - async def fake_request(self, command, *args, **kwargs): - fut = asyncio.get_running_loop().create_future() - if command == b"HELLO": - fut.set_exception(UnknownCommandError("fubar")) - return fut - else: - return await original_request(self, command, *args) - - mocker.patch.object(coredis.connection.BaseConnection, "create_request", fake_request) - - with pytest.warns(UserWarning, match="no support for the `HELLO` command"): - with pytest.raises(ConnectionError): - await coredis.Redis("localhost", 6389, password="sekret").ping() - with pytest.raises(AuthenticationError): - await coredis.Redis( - "localhost", - 6389, - username="bogus", - password="sekret", - protocol_version=2, - ).ping() - - assert ( - b"PONG" - == await coredis.Redis("localhost", 6389, password="sekret", protocol_version=2).ping() - ) - assert ( - b"PONG" - == await coredis.Redis( - "localhost", - 6389, - username="default", - password="sekret", - protocol_version=2, - ).ping() - ) - - -async def test_legacy_authentication_cred_provider(redis_auth_cred_provider, mocker): - original_request = coredis.connection.BaseConnection.create_request - - async def fake_request(self, command, *args, **kwargs): - fut = asyncio.get_running_loop().create_future() - if command == b"HELLO": - fut.set_exception(UnknownCommandError("fubar")) - return fut - else: - return await original_request(self, command, *args) - - mocker.patch.object(coredis.connection.BaseConnection, "create_request", fake_request) - - with pytest.warns(UserWarning, match="no support for the `HELLO` command"): - with pytest.raises(ConnectionError): - await coredis.Redis( - "localhost", - 6389, - credential_provider=UserPassCredentialProvider(password="sekret"), - ).ping() - with pytest.raises(AuthenticationError): - await coredis.Redis( - "localhost", - 6389, - credential_provider=UserPassCredentialProvider(username="bogus", password="sekret"), - protocol_version=2, - ).ping() - - assert ( - b"PONG" - == await coredis.Redis( - "localhost", - 6389, - credential_provider=UserPassCredentialProvider(password="sekret"), - protocol_version=2, - ).ping() - ) - assert ( - b"PONG" - == await coredis.Redis( - "localhost", - 6389, - credential_provider=UserPassCredentialProvider( - username="default", password="sekret" - ), - protocol_version=2, - ).ping() - ) + assert await client.ping() diff --git a/tests/test_cache.py b/tests/test_cache.py deleted file mode 100644 index 336dcc169..000000000 --- a/tests/test_cache.py +++ /dev/null @@ -1,80 +0,0 @@ -from __future__ import annotations - -import coredis.client -from coredis import BaseConnection -from coredis.cache import AbstractCache, CacheStats -from coredis.typing import RedisValueT, ResponseType -from tests.conftest import targets - - -class DummyCache(AbstractCache): - def __init__(self, dummy={}): - self.dummy = dummy - - async def initialize(self, client: coredis.client.Client) -> AbstractCache: - return self - - @property - def healthy(self) -> bool: - return True - - def get(self, command: bytes, key: bytes, *args: RedisValueT) -> ResponseType: - return self.dummy[key] - - def put(self, command: bytes, key: bytes, *args: RedisValueT, value: ResponseType) -> None: - self.dummy[key] = value - - def reset(self) -> None: - self.dummy.clear() - - def invalidate(self, *keys: RedisValueT) -> None: - for key in keys: - self.dummy.pop(key, None) - - @property - def stats(self) -> CacheStats: - return CacheStats() - - @property - def confidence(self) -> float: - return 100 - - def feedback(self, command: bytes, key: bytes, *args: RedisValueT, match: bool) -> None: - pass - - def get_client_id(self, connection: BaseConnection) -> int | None: - return connection.tracking_client_id - - def shutdown(self) -> None: - self.reset() - - -@targets( - "redis_basic", - "redis_basic_blocking", - "redis_basic_raw", - "redis_cluster", - "redis_cluster_blocking", - "redis_cluster_raw", -) -class TestBasicCache: - async def test_cache_hit(self, client, cloner, _s): - cache = DummyCache({"fubar": _s("1")}) - cached = await cloner(client, cache=cache) - assert _s("1") == await cached.get("fubar") - - async def test_cache_with_no_reply(self, client, cloner, _s): - cache = DummyCache({"fubar": _s("1")}) - cached = await cloner(client, cache=cache) - assert _s("1") == await cached.get("fubar") - with cached.ignore_replies(): - assert await cached.get("fubar") is None - assert _s("1") == await cached.get("fubar") - - async def test_cache_miss(self, client, cloner, _s): - cache = DummyCache({}) - cached = await cloner(client, cache=cache) - assert not await cached.get("fubar") - assert not await cached.get("fubar") - await cached.set("fubar", 1) - assert _s("1") == await cached.get("fubar") diff --git a/tests/test_client.py b/tests/test_client.py index 794fa8941..6504bc9dc 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,17 +1,16 @@ from __future__ import annotations -import asyncio +import re import ssl -from ssl import SSLError -import async_timeout +import anyio import pytest +from anyio import create_task_group, fail_after, sleep from packaging.version import Version import coredis from coredis.exceptions import ( AuthorizationError, - ConnectionError, PersistenceError, ReplicationError, UnknownCommandError, @@ -22,7 +21,6 @@ @targets( "redis_basic", - "redis_basic_blocking", "redis_basic_raw", "redis_ssl", "redis_ssl_no_client_auth", @@ -56,14 +54,14 @@ async def test_set_client_name(self, client, client_arguments): assert (await client.client_info())["name"] == "coredis" async def test_noreply_client(self, client, cloner, _s): - noreply = await cloner(client, noreply=True) - assert not await noreply.set("fubar", 1) - await asyncio.sleep(0.01) - assert await client.get("fubar") == _s("1") - assert not await noreply.delete(["fubar"]) - await asyncio.sleep(0.01) - assert not await client.get("fubar") - assert not await noreply.ping() + async with await cloner(client, noreply=True) as noreply: + assert not await noreply.set("fubar", 1) + await sleep(0.01) + assert await client.get("fubar") == _s("1") + assert not await noreply.delete(["fubar"]) + await sleep(0.01) + assert not await client.get("fubar") + assert not await noreply.ping() @pytest.mark.nodragonfly async def test_noreply_context(self, client, _s): @@ -90,41 +88,39 @@ async def test_decoding_context(self, client): with client.decoding(True, encoding="cp424"): assert "א" == await client.get("fubar") + @pytest.mark.anyio async def test_blocking_task_cancellation(self, client, _s): - awaitable = client.blpop(["nonexistent"], timeout=10) - task = asyncio.ensure_future(awaitable) - await asyncio.sleep(0.5) - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - async with async_timeout.timeout(0.1): + cancelled = False + + async def _runner(): + nonlocal cancelled + try: + return await client.blpop(["nonexistent"], 10) + except anyio.get_cancelled_exc_class(): + cancelled = True + raise + + async with create_task_group() as tg: + tg.start_soon(_runner) + await sleep(0.5) + tg.cancel_scope.cancel() + assert cancelled + with fail_after(0.1): assert _s("PONG") == await client.ping() - @pytest.mark.nodragonfly - async def test_concurrent_initialization(self, client, mocker): - assert await client.client_kill(skipme=False) - client.connection_pool.reset() - connection = await client.connection_pool.get_connection(b"set", acquire=False) - spy = mocker.spy(connection, "perform_handshake") - await asyncio.gather(*[client.set(f"fubar{i}", bytes(2**16)) for i in range(10)]) - assert spy.call_count == 1 - @targets( "redis_cluster", - "redis_cluster_blocking", ) class TestClusterClient: async def test_noreply_client(self, client, cloner, _s): - noreply = await cloner(client, noreply=True) - assert not await noreply.set("fubar", 1) - await asyncio.sleep(0.01) - assert await client.get("fubar") == _s("1") - assert not await noreply.delete(["fubar"]) - await asyncio.sleep(0.01) - assert not await client.get("fubar") + async with await cloner(client, noreply=True) as noreply: + assert not await noreply.set("fubar", 1) + await sleep(0.01) + assert await client.get("fubar") == _s("1") + assert not await noreply.delete(["fubar"]) + await sleep(0.01) + assert not await client.get("fubar") async def test_noreply_context(self, client, _s): with client.ignore_replies(): @@ -133,10 +129,12 @@ async def test_noreply_context(self, client, _s): assert await client.get("fubar") == _s(1) async def test_ensure_replication_unavailable(self, client, _s, user_client): - no_perm_client = await user_client("testuser", "on", "allkeys", "+@all", "-WAIT") - with pytest.raises(AuthorizationError): - with no_perm_client.ensure_replication(1): - assert await no_perm_client.set("fubar", 1) + async with await user_client( + "testuser", "on", "allkeys", "+@all", "-WAIT" + ) as no_perm_client: + with pytest.raises(AuthorizationError): + with no_perm_client.ensure_replication(1): + assert await no_perm_client.set("fubar", 1) async def test_ensure_replication(self, client, _s): with client.ensure_replication(1): @@ -149,10 +147,12 @@ async def test_ensure_replication(self, client, _s): @pytest.mark.min_server_version("7.1.240") async def test_ensure_persistence_unavailable(self, client, _s, user_client): - no_perm_client = await user_client("testuser", "on", "allkeys", "+@all", "-WAITAOF") - with pytest.raises(AuthorizationError): - with no_perm_client.ensure_persistence(1, 1, 2000): - await no_perm_client.set("fubar", 1) + async with await user_client( + "testuser", "on", "allkeys", "+@all", "-WAITAOF" + ) as no_perm_client: + with pytest.raises(AuthorizationError): + with no_perm_client.ensure_persistence(1, 1, 2000): + await no_perm_client.set("fubar", 1) @pytest.mark.min_server_version("7.1.240") async def test_ensure_persistence(self, client, _s): @@ -174,48 +174,48 @@ async def test_decoding_context(self, client): class TestSSL: async def test_explicit_ssl_parameters(self, redis_ssl_server): - client = coredis.Redis( + async with coredis.Redis( port=8379, ssl=True, ssl_keyfile="./tests/tls/client.key", ssl_certfile="./tests/tls/client.crt", ssl_ca_certs="./tests/tls/ca.crt", - ) - assert await client.ping() == b"PONG" + ) as client: + assert await client.ping() == b"PONG" async def test_explicit_ssl_context(self, redis_ssl_server): context = ssl.create_default_context() context.check_hostname = False context.verify_mode = ssl.CERT_NONE context.load_cert_chain(certfile="./tests/tls/client.crt", keyfile="./tests/tls/client.key") - client = coredis.Redis( + async with coredis.Redis( port=8379, ssl_context=context, - ) - assert await client.ping() == b"PONG" + ) as client: + assert await client.ping() == b"PONG" async def test_cluster_explicit_ssl_parameters(self, redis_ssl_cluster_server): - client = coredis.RedisCluster( + async with coredis.RedisCluster( "localhost", port=8301, ssl=True, ssl_keyfile="./tests/tls/client.key", ssl_certfile="./tests/tls/client.crt", ssl_ca_certs="./tests/tls/ca.crt", - ) - assert await client.ping() == b"PONG" + ) as client: + assert await client.ping() == b"PONG" async def test_cluster_explicit_ssl_context(self, redis_ssl_cluster_server): context = ssl.create_default_context() context.check_hostname = False context.verify_mode = ssl.CERT_NONE context.load_cert_chain(certfile="./tests/tls/client.crt", keyfile="./tests/tls/client.key") - client = coredis.RedisCluster( + async with coredis.RedisCluster( "localhost", 8301, ssl_context=context, - ) - assert await client.ping() == b"PONG" + ) as client: + assert await client.ping() == b"PONG" async def test_invalid_ssl_parameters(self, redis_ssl_server): context = ssl.create_default_context() @@ -225,37 +225,48 @@ async def test_invalid_ssl_parameters(self, redis_ssl_server): certfile="./tests/tls/invalid-client.crt", keyfile="./tests/tls/invalid-client.key", ) - client = coredis.Redis( - port=8379, - ssl_context=context, - ) - with pytest.raises(ConnectionError, match="decrypt error") as exc_info: - await client.ping() - assert isinstance(exc_info.value.__cause__, SSLError) + + with pytest.RaisesGroup( + pytest.RaisesExc(ssl.SSLError, match=re.escape("decrypt error")), flatten_subgroups=True + ): + async with coredis.Redis( + port=8379, + ssl_context=context, + ): + pass async def test_ssl_no_verify_client(self, redis_ssl_server_no_client_auth): - client = coredis.Redis(port=7379, ssl=True, ssl_cert_reqs="required") - with pytest.raises(ConnectionError, match="certificate verify failed"): - await client.ping() - client = coredis.Redis(port=7379, ssl=True, ssl_cert_reqs="none") - assert await client.ping() == b"PONG" + with pytest.RaisesGroup( + pytest.RaisesExc( + ssl.SSLCertVerificationError, match=re.escape("certificate verify failed") + ), + flatten_subgroups=True, + ): + async with coredis.Redis(port=7379, ssl=True, ssl_cert_reqs="required") as client: + await client.ping() + async with coredis.Redis(port=7379, ssl=True, ssl_cert_reqs="none") as client: + assert await client.ping() == b"PONG" class TestFromUrl: async def test_basic_client(self, redis_basic_server): - client = coredis.Redis.from_url(f"redis://{redis_basic_server[0]}:{redis_basic_server[1]}") - assert b"PONG" == await client.ping() - client = coredis.Redis.from_url( + async with coredis.Redis.from_url( + f"redis://{redis_basic_server[0]}:{redis_basic_server[1]}" + ) as client: + assert b"PONG" == await client.ping() + async with coredis.Redis.from_url( f"redis://{redis_basic_server[0]}:{redis_basic_server[1]}", decode_responses=True, - ) - assert "PONG" == await client.ping() + ) as client: + assert "PONG" == await client.ping() async def test_uds_client(self, redis_uds_server): - client = coredis.Redis.from_url(f"redis://{redis_uds_server}") - assert b"PONG" == await client.ping() - client = coredis.Redis.from_url(f"redis://{redis_uds_server}", decode_responses=True) - assert "PONG" == await client.ping() + async with coredis.Redis.from_url(f"unix://{redis_uds_server}") as client: + assert b"PONG" == await client.ping() + async with coredis.Redis.from_url( + f"unix://{redis_uds_server}", decode_responses=True + ) as client: + assert "PONG" == await client.ping() @pytest.mark.parametrize( "cert_reqs", @@ -276,21 +287,21 @@ async def test_ssl_client(self, redis_ssl_server, cert_reqs): ) if cert_reqs is not None: storage_url += f"&ssl_cert_reqs={cert_reqs}" - client = coredis.Redis.from_url(storage_url) - assert b"PONG" == await client.ping() - client = coredis.Redis.from_url(storage_url, decode_responses=True) - assert "PONG" == await client.ping() + async with coredis.Redis.from_url(storage_url) as client: + assert b"PONG" == await client.ping() + async with coredis.Redis.from_url(storage_url, decode_responses=True) as client: + assert "PONG" == await client.ping() async def test_cluster_client(self, redis_cluster_server): - client = coredis.RedisCluster.from_url( + async with coredis.RedisCluster.from_url( f"redis://{redis_cluster_server[0]}:{redis_cluster_server[1]}" - ) - assert b"PONG" == await client.ping() - client = coredis.RedisCluster.from_url( + ) as client: + assert b"PONG" == await client.ping() + async with coredis.RedisCluster.from_url( f"redis://{redis_cluster_server[0]}:{redis_cluster_server[1]}", decode_responses=True, - ) - assert "PONG" == await client.ping() + ) as client: + assert "PONG" == await client.ping() @pytest.mark.parametrize( "cert_reqs", @@ -311,7 +322,7 @@ async def test_cluster_ssl_client(self, redis_ssl_cluster_server, cert_reqs): ) if cert_reqs is not None: storage_url += f"&ssl_cert_reqs={cert_reqs}" - client = coredis.RedisCluster.from_url(storage_url) - assert b"PONG" == await client.ping() - client = coredis.RedisCluster.from_url(storage_url, decode_responses=True) - assert "PONG" == await client.ping() + async with coredis.RedisCluster.from_url(storage_url) as client: + assert b"PONG" == await client.ping() + async with coredis.RedisCluster.from_url(storage_url, decode_responses=True) as client: + assert "PONG" == await client.ping() diff --git a/tests/test_connection.py b/tests/test_connection.py index 1c0aea73d..24151d629 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,67 +1,73 @@ from __future__ import annotations -import asyncio import socket import pytest +from anyio import create_task_group +from anyio.abc import SocketAttribute from coredis import Connection, UnixDomainSocketConnection from coredis.credentials import UserPassCredentialProvider from coredis.exceptions import TimeoutError -pytest_marks = pytest.mark.asyncio - async def test_connect_tcp(redis_basic): conn = Connection() assert conn.host == "127.0.0.1" assert conn.port == 6379 assert str(conn) == "Connection" - request = await conn.create_request(b"PING") - res = await request - assert res == b"PONG" - assert conn._transport is not None - conn.disconnect() - assert conn._transport is None + async with create_task_group() as tg: + await tg.start(conn.run) + request = await conn.create_request(b"PING") + res = await request + assert res == b"PONG" + assert conn._connection is not None + tg.cancel_scope.cancel() -async def test_connect_cred_provider(redis_auth_cred_provider): +@pytest.mark.xfail +async def test_connect_cred_provider(redis_auth_server): conn = Connection( credential_provider=UserPassCredentialProvider(password="sekret"), host="localhost", port=6389, ) - assert conn.host == "localhost" - assert conn.port == 6389 - assert str(conn) == "Connection" - request = await conn.create_request(b"PING") - res = await request - assert res == b"PONG" - assert conn._transport is not None - conn.disconnect() - assert conn._transport is None + async with create_task_group() as tg: + await tg.start(conn.run) + request = await conn.create_request(b"PING") + res = await request + assert res == b"PONG" + tg.cancel_scope.cancel() @pytest.mark.os("linux") async def test_connect_tcp_keepalive_options(redis_basic): conn = Connection( socket_keepalive=True, - socket_keepalive_options={ - socket.TCP_KEEPIDLE: 1, - socket.TCP_KEEPINTVL: 1, - socket.TCP_KEEPCNT: 3, - }, + socket_keepalive_options={socket.TCP_KEEPINTVL: 1, socket.TCP_KEEPCNT: 3}, + ) + async with create_task_group() as tg: + await tg.start(conn.run) + sock = conn.connection.extra(SocketAttribute.raw_socket) + assert sock.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE) == 1 + for k, v in ((socket.TCP_KEEPINTVL, 1), (socket.TCP_KEEPCNT, 3)): + assert sock.getsockopt(socket.SOL_TCP, k) == v + tg.cancel_scope.cancel() + + +@pytest.mark.os("darwin") +async def test_connect_tcp_keepalive_options_mac(redis_basic): + conn = Connection( + socket_keepalive=True, + socket_keepalive_options={socket.TCP_KEEPINTVL: 1, socket.TCP_KEEPCNT: 3}, ) - await conn._connect() - sock = conn._transport.get_extra_info("socket") - assert sock.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE) == 1 - for k, v in ( - (socket.TCP_KEEPIDLE, 1), - (socket.TCP_KEEPINTVL, 1), - (socket.TCP_KEEPCNT, 3), - ): - assert sock.getsockopt(socket.SOL_TCP, k) == v - conn.disconnect() + async with create_task_group() as tg: + await tg.start(conn.run) + sock = conn.connection.extra(SocketAttribute.raw_socket) + assert sock.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE) == 8 + for k, v in ((socket.TCP_KEEPINTVL, 1), (socket.TCP_KEEPCNT, 3)): + assert sock.getsockopt(socket.SOL_TCP, k) == v + tg.cancel_scope.cancel() @pytest.mark.parametrize("option", ["UNKNOWN", 999]) @@ -69,46 +75,28 @@ async def test_connect_tcp_wrong_socket_opt_raises(option, redis_basic): conn = Connection(socket_keepalive=True, socket_keepalive_options={option: 1}) with pytest.raises((socket.error, TypeError)): await conn._connect() - # verify that the connection isn't left open - assert conn._transport.is_closing() # only test during dev async def test_connect_unix_socket(redis_uds): path = "/tmp/coredis.redis.sock" conn = UnixDomainSocketConnection(path) - await conn.connect() - assert conn.path == path - assert str(conn) == f"UnixDomainSocketConnection" - req = await conn.create_request(b"PING") - res = await req - assert res == b"PONG" - assert conn._transport is not None - conn.disconnect() - assert conn._transport is None + async with create_task_group() as tg: + await tg.start(conn.run) + assert conn.path == path + assert str(conn) == f"UnixDomainSocketConnection" + req = await conn.create_request(b"PING") + res = await req + assert res == b"PONG" + assert conn._connection is not None + tg.cancel_scope.cancel() async def test_stream_timeout(redis_basic): conn = Connection(stream_timeout=0.01) - await conn.connect() is None - req = await conn.create_request(b"debug", "sleep", 0.05) - with pytest.raises(TimeoutError): - await req - - -async def test_lag(redis_basic): - connection = await redis_basic.connection_pool.get_connection(b"ping") - assert connection.lag == 0 - ping_request = await connection.create_request(b"ping") - assert connection.lag != 0 - await ping_request - assert connection.lag == 0 - - -async def test_estimated_time_to_idle(redis_basic): - connection = await redis_basic.connection_pool.get_connection(b"ping") - assert connection.estimated_time_to_idle == 0 - requests = [await connection.create_request(b"ping") for _ in range(10)] - assert connection.estimated_time_to_idle > 0 - await asyncio.gather(*requests) - assert connection.estimated_time_to_idle == 0 + async with create_task_group() as tg: + await tg.start(conn.run) + req = await conn.create_request(b"debug", "sleep", 0.05) + with pytest.raises(TimeoutError): + await req + tg.cancel_scope.cancel() diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py deleted file mode 100644 index 2817d56b6..000000000 --- a/tests/test_connection_pool.py +++ /dev/null @@ -1,627 +0,0 @@ -from __future__ import annotations - -import asyncio -import os -import re -import ssl -from collections import deque - -import pytest - -import coredis -from coredis._utils import query_param_to_bool -from coredis.exceptions import ( - ConnectionError, - RedisError, -) - - -class DummyConnection: - description = "DummyConnection<>" - - def __init__(self, **kwargs): - self.kwargs = kwargs - self.pid = os.getpid() - self.awaiting_response = False - self.is_connected = False - self.needs_handshake = True - self._last_error = None - self._requests = deque() - self.average_response_time = 0.0 - self.lag = 0.0 - self.requests_pending = 0 - self.requests_processed = 0 - self.estimated_time_to_idle = 0 - self.latency = 0 - - async def connect(self): - self.is_connected = True - - def disconnect(self): - self.is_connected = False - self._last_error = None - - async def perform_handshake(self) -> None: - self.needs_handshake = False - - -@pytest.fixture(autouse=True) -def setup(redis_basic): - pass - - -class TestConnectionPool: - def get_pool( - self, - connection_kwargs=None, - max_connections=None, - connection_class=DummyConnection, - ): - connection_kwargs = connection_kwargs or {} - pool = coredis.ConnectionPool( - connection_class=connection_class, - max_connections=max_connections, - **connection_kwargs, - ) - - return pool - - async def test_connection_creation(self): - connection_kwargs = {"foo": "bar", "biz": "baz"} - pool = self.get_pool(connection_kwargs=connection_kwargs) - connection = await pool.get_connection() - assert isinstance(connection, DummyConnection) - assert connection.kwargs == connection_kwargs - - async def test_multiple_connections(self): - pool = self.get_pool() - c1 = await pool.get_connection() - c2 = await pool.get_connection() - assert c1 != c2 - - async def test_max_connections(self): - pool = self.get_pool(max_connections=2) - await pool.get_connection() - await pool.get_connection() - with pytest.raises(ConnectionError): - await pool.get_connection() - - async def test_pool_disconnect(self): - pool = self.get_pool(max_connections=3) - c1 = await pool.get_connection() - c2 = await pool.get_connection() - c3 = await pool.get_connection() - pool.release(c3) - pool.disconnect() - assert not c1.is_connected - assert not c2.is_connected - assert not c3.is_connected - - async def test_reuse_previously_released_connection(self): - pool = self.get_pool() - c1 = await pool.get_connection() - await c1.connect() - pool.release(c1) - c2 = await pool.get_connection() - assert c1 == c2 - - def test_repr_contains_db_info_tcp(self): - connection_kwargs = {"host": "localhost", "port": 6379, "db": 1} - pool = self.get_pool( - connection_kwargs=connection_kwargs, connection_class=coredis.Connection - ) - expected = "ConnectionPool>" - assert repr(pool) == expected - - def test_repr_contains_db_info_unix(self): - connection_kwargs = {"path": "/abc", "db": 1} - pool = self.get_pool( - connection_kwargs=connection_kwargs, - connection_class=coredis.UnixDomainSocketConnection, - ) - expected = "ConnectionPool>" - assert repr(pool) == expected - - @pytest.mark.xfail - async def test_connection_idle_check(self): - rs = coredis.Redis( - host="127.0.0.1", - port=6379, - db=0, - max_idle_time=0.2, - idle_check_interval=0.1, - ) - await rs.info() - assert len(rs.connection_pool._available_connections) == 1 - assert len(rs.connection_pool._in_use_connections) == 0 - conn = rs.connection_pool._available_connections[0] - last_active_at = conn.last_active_at - await asyncio.sleep(0.3) - assert len(rs.connection_pool._available_connections) == 0 - assert len(rs.connection_pool._in_use_connections) == 0 - assert last_active_at == conn.last_active_at - assert conn._transport is None - - -class TestBlockingConnectionPool: - def get_pool( - self, - connection_kwargs=None, - max_connections=None, - connection_class=DummyConnection, - timeout=None, - ): - connection_kwargs = connection_kwargs or {} - pool = coredis.BlockingConnectionPool( - connection_class=connection_class, - max_connections=max_connections, - timeout=timeout, - **connection_kwargs, - ) - - return pool - - async def test_connection_creation(self): - connection_kwargs = {"foo": "bar", "biz": "baz"} - pool = self.get_pool(connection_kwargs=connection_kwargs) - connection = await pool.get_connection() - assert isinstance(connection, DummyConnection) - assert connection.kwargs == connection_kwargs - - async def test_multiple_connections(self): - pool = self.get_pool() - c1 = await pool.get_connection() - c2 = await pool.get_connection() - assert c1 != c2 - - async def test_max_connections_timeout(self): - pool = self.get_pool(max_connections=2, timeout=0.1) - await pool.get_connection() - await pool.get_connection() - with pytest.raises(ConnectionError): - await pool.get_connection() - - async def test_max_connections_no_timeout(self): - pool = self.get_pool(max_connections=2) - await pool.get_connection() - released_conn = await pool.get_connection() - - def releaser(): - pool.release(released_conn) - - loop = asyncio.get_running_loop() - loop.call_later(0.2, releaser) - new_conn = await pool.get_connection() - assert new_conn == released_conn - - async def test_reuse_previously_released_connection(self): - pool = self.get_pool() - c1 = await pool.get_connection() - pool.release(c1) - c2 = await pool.get_connection() - assert c1 == c2 - - async def test_pool_disconnect(self): - pool = self.get_pool() - c1 = await pool.get_connection() - c2 = await pool.get_connection() - c3 = await pool.get_connection() - pool.release(c3) - pool.disconnect() - assert not c1.is_connected - assert not c2.is_connected - assert not c3.is_connected - - def test_repr_contains_db_info_tcp(self): - connection_kwargs = {"host": "localhost", "port": 6379, "db": 1} - pool = self.get_pool( - connection_kwargs=connection_kwargs, connection_class=coredis.Connection - ) - expected = "BlockingConnectionPool>" - assert repr(pool) == expected - - def test_repr_contains_db_info_unix(self): - connection_kwargs = {"path": "/abc", "db": 1} - pool = self.get_pool( - connection_kwargs=connection_kwargs, - connection_class=coredis.UnixDomainSocketConnection, - ) - expected = "BlockingConnectionPool>" - assert repr(pool) == expected - - @pytest.mark.xfail - async def test_connection_idle_check(self): - rs = coredis.Redis( - host="127.0.0.1", - port=6379, - db=0, - connection_pool=coredis.BlockingConnectionPool( - max_idle_time=0.2, idle_check_interval=0.1, host="127.0.01", port=6379 - ), - ) - await rs.info() - assert len(rs.connection_pool._in_use_connections) == 0 - conn = await rs.connection_pool.get_connection() - last_active_at = conn.last_active_at - rs.connection_pool.release(conn) - await asyncio.sleep(0.3) - assert len(rs.connection_pool._in_use_connections) == 0 - assert last_active_at == conn.last_active_at - assert conn._transport is None - new_conn = await rs.connection_pool.get_connection() - assert conn == new_conn - - -class TestConnectionPoolURLParsing: - def test_defaults(self): - pool = coredis.ConnectionPool.from_url("redis://localhost") - assert pool.connection_class == coredis.Connection - assert pool.connection_kwargs == { - "host": "localhost", - "port": 6379, - "db": 0, - "username": None, - "password": None, - } - - def test_hostname(self): - pool = coredis.ConnectionPool.from_url("redis://myhost") - assert pool.connection_class == coredis.Connection - assert pool.connection_kwargs == { - "host": "myhost", - "port": 6379, - "db": 0, - "username": None, - "password": None, - } - - def test_quoted_hostname(self): - pool = coredis.ConnectionPool.from_url( - "redis://my %2F host %2B%3D+", decode_components=True - ) - assert pool.connection_class == coredis.Connection - assert pool.connection_kwargs == { - "host": "my / host +=+", - "port": 6379, - "db": 0, - "username": None, - "password": None, - } - - def test_port(self): - pool = coredis.ConnectionPool.from_url("redis://localhost:6380") - assert pool.connection_class == coredis.Connection - assert pool.connection_kwargs == { - "host": "localhost", - "port": 6380, - "db": 0, - "username": None, - "password": None, - } - - def test_password(self): - pool = coredis.ConnectionPool.from_url("redis://:mypassword@localhost") - assert pool.connection_class == coredis.Connection - assert pool.connection_kwargs == { - "host": "localhost", - "port": 6379, - "db": 0, - "username": "", - "password": "mypassword", - } - - def test_quoted_password(self): - pool = coredis.ConnectionPool.from_url( - "redis://:%2Fmypass%2F%2B word%3D%24+@localhost", decode_components=True - ) - assert pool.connection_class == coredis.Connection - assert pool.connection_kwargs == { - "host": "localhost", - "port": 6379, - "db": 0, - "username": None, - "password": "/mypass/+ word=$+", - } - - def test_db_as_argument(self): - pool = coredis.ConnectionPool.from_url("redis://localhost", db="1") - assert pool.connection_class == coredis.Connection - assert pool.connection_kwargs == { - "host": "localhost", - "port": 6379, - "db": 1, - "username": None, - "password": None, - } - - def test_db_in_path(self): - pool = coredis.ConnectionPool.from_url("redis://localhost/2", db="1") - assert pool.connection_class == coredis.Connection - assert pool.connection_kwargs == { - "host": "localhost", - "port": 6379, - "db": 2, - "username": None, - "password": None, - } - - def test_db_in_querystring(self): - pool = coredis.ConnectionPool.from_url("redis://localhost/2?db=3", db="1") - assert pool.connection_class == coredis.Connection - assert pool.connection_kwargs == { - "host": "localhost", - "port": 6379, - "db": 3, - "username": None, - "password": None, - } - - def test_extra_typed_querystring_options(self): - pool = coredis.ConnectionPool.from_url( - "redis://localhost/2?stream_timeout=20&connect_timeout=10" - ) - - assert pool.connection_class == coredis.Connection - assert pool.connection_kwargs == { - "host": "localhost", - "port": 6379, - "db": 2, - "stream_timeout": 20.0, - "connect_timeout": 10.0, - "username": None, - "password": None, - } - - def test_boolean_parsing(self): - for expected, value in ( - (None, None), - (None, ""), - (False, 0), - (False, "0"), - (False, "f"), - (False, "F"), - (False, "False"), - (False, "n"), - (False, "N"), - (False, "No"), - (True, 1), - (True, "1"), - (True, "y"), - (True, "Y"), - (True, "Yes"), - ): - assert expected is query_param_to_bool(value) - - def test_invalid_extra_typed_querystring_options(self): - import warnings - - with warnings.catch_warnings(record=True) as warning_log: - coredis.ConnectionPool.from_url( - "redis://localhost/2?stream_timeout=_&connect_timeout=abc" - ) - # Compare the message values - assert [str(m.message) for m in sorted(warning_log, key=lambda log: str(log.message))] == [ - "Invalid value for `connect_timeout` in connection URL.", - "Invalid value for `stream_timeout` in connection URL.", - ] - - def test_max_connections_querystring_option(self): - pool = coredis.ConnectionPool.from_url("redis://localhost?max_connections=32") - assert pool.max_connections == 32 - - def test_max_idle_times_querystring_option(self): - pool = coredis.ConnectionPool.from_url("redis://localhost?max_idle_time=5") - assert pool.max_idle_time == 5 - - def test_idle_check_interval_querystring_option(self): - pool = coredis.ConnectionPool.from_url("redis://localhost?idle_check_interval=1") - assert pool.idle_check_interval == 1 - - def test_extra_querystring_options(self): - pool = coredis.ConnectionPool.from_url("redis://localhost?a=1&b=2") - assert pool.connection_class == coredis.Connection - assert pool.connection_kwargs == { - "host": "localhost", - "port": 6379, - "db": 0, - "username": None, - "password": None, - "a": "1", - "b": "2", - } - - def test_client_creates_connection_pool(self): - r = coredis.Redis.from_url("redis://myhost") - assert r.connection_pool.connection_class == coredis.Connection - assert r.connection_pool.connection_kwargs == { - "host": "myhost", - "port": 6379, - "db": 0, - "decode_responses": False, - "protocol_version": 3, - "username": None, - "password": None, - "noreply": False, - "noevict": False, - "notouch": False, - } - - -class TestConnectionPoolUnixSocketURLParsing: - def test_defaults(self): - pool = coredis.ConnectionPool.from_url("unix:///socket") - assert pool.connection_class == coredis.UnixDomainSocketConnection - assert pool.connection_kwargs == { - "path": "/socket", - "db": 0, - "username": None, - "password": None, - } - - def test_password(self): - pool = coredis.ConnectionPool.from_url("unix://:mypassword@/socket") - assert pool.connection_class == coredis.UnixDomainSocketConnection - assert pool.connection_kwargs == { - "path": "/socket", - "db": 0, - "username": "", - "password": "mypassword", - } - - def test_quoted_password(self): - pool = coredis.ConnectionPool.from_url( - "unix://:%2Fmypass%2F%2B word%3D%24+@/socket", decode_components=True - ) - assert pool.connection_class == coredis.UnixDomainSocketConnection - assert pool.connection_kwargs == { - "path": "/socket", - "db": 0, - "username": None, - "password": "/mypass/+ word=$+", - } - - def test_quoted_path(self): - pool = coredis.ConnectionPool.from_url( - "unix://:mypassword@/my%2Fpath%2Fto%2F..%2F+_%2B%3D%24ocket", - decode_components=True, - ) - assert pool.connection_class == coredis.UnixDomainSocketConnection - assert pool.connection_kwargs == { - "path": "/my/path/to/../+_+=$ocket", - "db": 0, - "username": None, - "password": "mypassword", - } - - def test_db_as_argument(self): - pool = coredis.ConnectionPool.from_url("unix:///socket", db=1) - assert pool.connection_class == coredis.UnixDomainSocketConnection - assert pool.connection_kwargs == { - "path": "/socket", - "db": 1, - "username": None, - "password": None, - } - - def test_db_in_querystring(self): - pool = coredis.ConnectionPool.from_url("unix:///socket?db=2", db=1) - assert pool.connection_class == coredis.UnixDomainSocketConnection - assert pool.connection_kwargs == { - "path": "/socket", - "db": 2, - "username": None, - "password": None, - } - - def test_max_connections_querystring_option(self): - pool = coredis.ConnectionPool.from_url("unix:///localhost?max_connections=32") - assert pool.max_connections == 32 - - def test_max_idle_times_querystring_option(self): - pool = coredis.ConnectionPool.from_url("unix:///localhost?max_idle_time=5") - assert pool.max_idle_time == 5 - - def test_idle_check_interval_querystring_option(self): - pool = coredis.ConnectionPool.from_url("unix:///localhost?idle_check_interval=1") - assert pool.idle_check_interval == 1 - - def test_extra_querystring_options(self): - pool = coredis.ConnectionPool.from_url("unix:///socket?a=1&b=2") - assert pool.connection_class == coredis.UnixDomainSocketConnection - assert pool.connection_kwargs == { - "path": "/socket", - "db": 0, - "username": None, - "password": None, - "a": "1", - "b": "2", - } - - -class TestSSLConnectionURLParsing: - def test_defaults(self): - pool = coredis.ConnectionPool.from_url("rediss://localhost") - assert pool.connection_class == coredis.Connection - assert pool.connection_kwargs.pop("ssl_context") is not None - assert pool.connection_kwargs == { - "host": "localhost", - "port": 6379, - "db": 0, - "username": None, - "password": None, - } - - @pytest.mark.parametrize( - "query_param, expected", - [ - ( - "none", - ssl.CERT_NONE, - ), - ( - "optional", - ssl.CERT_OPTIONAL, - ), - ("required", ssl.CERT_REQUIRED), - (None, ssl.CERT_OPTIONAL), - ], - ) - async def test_cert_reqs_options(self, query_param, expected): - uri = "rediss://?ssl_keyfile=./tests/tls/client.key&ssl_certfile=./tests/tls/client.crt" - if query_param: - uri += f"&ssl_cert_reqs={query_param}" - pool = coredis.ConnectionPool.from_url(uri) - assert (await pool.get_connection()).ssl_context.verify_mode == expected - - -class TestConnection: - async def test_on_connect_error(self): - """ - An error in Connection.on_connect should disconnect from the server - see for details: https://github.com/andymccurdy/redis-py/issues/368 - """ - # this assumes the Redis server being tested against doesn't have - # 9999 databases ;) - bad_connection = coredis.Redis(db=9999) - # an error should be raised on connect - with pytest.raises(RedisError): - await bad_connection.info() - pool = bad_connection.connection_pool - assert not pool._available_connections[0].is_connected - - async def test_busy_loading_from_pipeline(self): - """ - BusyLoadingErrors should be raised from a pipeline execution - regardless of the raise_on_error flag. - """ - client = coredis.Redis() - pipe = await client.pipeline() - await pipe.create_request( - b"DEBUG", b"ERROR", b"LOADING fake message", callback=lambda r, **k: r - ) - with pytest.raises(RedisError): - await pipe.execute() - pool = client.connection_pool - assert not pipe.connection - assert len(pool._available_connections) == 1 - assert pool._available_connections[0]._transport - - def test_connect_from_url_tcp(self): - connection = coredis.Redis.from_url("redis://localhost") - pool = connection.connection_pool - - assert re.match("(.*)<(.*)<(.*)>>", repr(pool)).groups() == ( - "ConnectionPool", - "Connection", - "host=localhost,port=6379,db=0", - ) - - def test_connect_from_url_unix(self): - connection = coredis.Redis.from_url("unix:///path/to/socket") - pool = connection.connection_pool - - assert re.match("(.*)<(.*)<(.*)>>", repr(pool)).groups() == ( - "ConnectionPool", - "UnixDomainSocketConnection", - "path=/path/to/socket,db=0", - ) diff --git a/tests/test_encoding.py b/tests/test_encoding.py index bafbaecf7..56553012e 100644 --- a/tests/test_encoding.py +++ b/tests/test_encoding.py @@ -10,8 +10,9 @@ @pytest.fixture async def redis_no_decode(redis_basic_server): client = coredis.Redis() - await client.flushdb() - return client + async with client: + await client.flushdb() + yield client class TestEncoding: diff --git a/tests/test_lru_cache.py b/tests/test_lru_cache.py index e13bd1878..6d87e5497 100644 --- a/tests/test_lru_cache.py +++ b/tests/test_lru_cache.py @@ -3,22 +3,13 @@ import pytest from coredis.cache import LRUCache +from coredis.commands.constants import CommandName class TestLRUCache: def test_max_keys(self): - cache = LRUCache(max_items=1) - cache.insert("a", 1) - cache.insert("b", 1) + cache = LRUCache(max_keys=1) + cache.put(CommandName.GET, "a", value="1") + cache.put(CommandName.GET, "b", value="1") with pytest.raises(KeyError): - cache.get("a") - - @pytest.mark.nopypy - def test_max_bytes(self): - cache = LRUCache(max_bytes=500) - cache.insert("a", bytearray(400)) - cache.insert("b", bytearray(50)) - cache.shrink() - cache.get("b") - with pytest.raises(KeyError): - cache.get("a") + cache.get(CommandName.GET, "a") diff --git a/tests/test_monitor.py b/tests/test_monitor.py deleted file mode 100644 index 08daac243..000000000 --- a/tests/test_monitor.py +++ /dev/null @@ -1,47 +0,0 @@ -from __future__ import annotations - -import asyncio - -from tests.conftest import targets - - -@targets("redis_basic", "redis_basic_blocking") -class TestMonitor: - async def test_explicit_fetch(self, client, cloner): - monitored = await cloner(client) - await monitored.ping() - async with await client.monitor() as monitor: - response = await asyncio.gather(monitor.get_command(), monitored.get("test")) - assert response[0].command == "GET" - response = await asyncio.gather(monitor.get_command(), monitored.get("test2")) - assert response[0].command == "GET" - assert not monitor.monitoring - - async def test_iterator(self, client): - async def delayed(): - await asyncio.sleep(0.1) - return await client.get("test") - - async def collect(): - results = [] - async for command in client.monitor(): - results.append(command) - break - return results - - results = await asyncio.gather(delayed(), collect()) - assert results[1][0].command in ["HELLO", "GET"] - - async def test_monitor_request_handler(self, client, mocker): - cmds = set() - - monitor = await client.monitor(lambda cmd: cmds.add(cmd.command)) - await asyncio.sleep(0.01) - await client.ping() - await asyncio.sleep(0.01) - await monitor.aclose() - assert "PING" in cmds - await asyncio.sleep(0.01) - await client.get("test") - await asyncio.sleep(0.01) - assert "GET" not in cmds diff --git a/tests/test_parsers.py b/tests/test_parsers.py index 6f21a079d..cd836f562 100644 --- a/tests/test_parsers.py +++ b/tests/test_parsers.py @@ -1,8 +1,10 @@ from __future__ import annotations +import math + import pytest +from anyio import create_memory_object_stream -from coredis import BaseConnection from coredis._utils import b from coredis.exceptions import ( ConnectionError, @@ -13,26 +15,14 @@ from coredis.parser import NOT_ENOUGH_DATA, Parser -class DummyConnection(BaseConnection): - def __init__(self, *a, **k): - super().__init__(*a, **k) - - def data_received(self, data): - self._parser.feed(data) - - async def _connect(self) -> None: - pass - - @pytest.fixture -def connection(request): - return DummyConnection(decode_responses=request.getfixturevalue("decode")) +def object_stream(request): + return create_memory_object_stream(math.inf) @pytest.fixture -def parser(connection): - parser = Parser() - parser.on_connect(connection) +def parser(object_stream): + parser = Parser(object_stream[0]) return parser @@ -254,16 +244,14 @@ def test_nested_array(self, parser, decode): ] def test_simple_push_array(self, parser, decode): - parser.feed(b">2\r\n$2\r\nco\r\n$5\r\nredis\r\n") - assert parser.get_response( - decode=decode, encoding="latin-1", push_message_types={b"co"} - ) == [ - self.encoded_value(decode, b"co"), + parser.feed(b">2\r\n$7\r\nmessage\r\n$5\r\nredis\r\n") + parser.get_response(decode=decode, encoding="latin-1") == [ + self.encoded_value(decode, b"message"), self.encoded_value(decode, b"redis"), ] - def test_interleaved_simple_push_array(self, parser, decode): - parser.feed(b":3\r\n>2\r\n:1\r\n:2\r\n:4\r\n") + def test_interleaved_simple_push_array(self, parser, decode, object_stream): + parser.feed(b":3\r\n>2\r\n$7\r\nmessage\r\n$5\r\nredis\r\n:4\r\n") assert ( parser.get_response( decode=decode, @@ -278,7 +266,10 @@ def test_interleaved_simple_push_array(self, parser, decode): ) == 4 ) - assert parser.push_messages.get_nowait() == [1, 2] + assert object_stream[1].receive_nowait() == [ + self.encoded_value(decode, b"message"), + self.encoded_value(decode, b"redis"), + ] def test_nil_map(self, parser, decode): parser.feed(b"%-1\r\n") diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 3dfbfeaf8..e55888ffc 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,10 +1,12 @@ from __future__ import annotations -import asyncio from decimal import Decimal import pytest +from coredis._concurrency import gather +from coredis.client.basic import Redis +from coredis.commands.request import CommandRequest from coredis.exceptions import ( AuthorizationError, RedisError, @@ -12,53 +14,47 @@ TimeoutError, WatchError, ) +from coredis.pipeline import Pipeline from coredis.typing import Serializable from tests.conftest import targets -@targets( - "redis_basic", - "redis_basic_blocking", - "dragonfly", - "valkey", - "redict", -) +@targets("redis_basic", "dragonfly", "valkey", "redict") class TestPipeline: async def test_empty_pipeline(self, client): - async with await client.pipeline() as pipe: - assert await pipe.execute() == () + async with client.pipeline(): + pass + + async def test_pipeline(self, client: Redis[str]): + async with client.pipeline() as pipe: + a = pipe.set("a", "a1") + b = pipe.get("a") + c = pipe.zadd("z", {"z1": 1}) + d = pipe.zadd("z", {"z2": 4}) + e = pipe.zincrby("z", "z1", 1) + f = pipe.zrange("z", 0, 5, withscores=True) + assert await gather(a, b, c, d, e, f) == ( + True, + "a1", + 1, + 1, + 2.0, + (("z1", 2.0), ("z2", 4)), + ) - async def test_pipeline(self, client): - async with await client.pipeline() as pipe: - pipe.set("a", "a1") - pipe.get("a") - pipe.zadd("z", dict(z1=1)) - pipe.zadd("z", dict(z2=4)) - pipe.zincrby("z", "z1", 1) - pipe.zrange("z", 0, 5, withscores=True) - assert await pipe.execute() == ( - True, - "a1", - 1, - 1, - 2.0, - (("z1", 2.0), ("z2", 4)), - ) - - async def test_pipeline_transforms(self, client, _s): + async def test_pipeline_transforms(self, client): client.type_adapter.register( Decimal, lambda v: str(v), lambda v: Decimal(v if isinstance(v, str) else v.decode("utf-8")), ) - pipe = await client.pipeline() - pipe.set("a", Serializable(Decimal(1.23))) - r = pipe.get("a").transform(Decimal) - assert (True, _s(str(Decimal(1.23)))) == await pipe.execute() - assert Decimal(1.23) == await r + async with client.pipeline() as pipe: + a = pipe.set("a", Serializable(Decimal(1.23))) + b = pipe.get("a").transform(Decimal) + assert (True, Decimal(1.23)) == await gather(a, b) async def test_pipeline_length(self, client): - async with await client.pipeline() as pipe: + async with client.pipeline() as pipe: # Initially empty. assert len(pipe) == 0 assert pipe @@ -70,208 +66,168 @@ async def test_pipeline_length(self, client): assert len(pipe) == 3 assert pipe - # Execute calls reset(), so empty once again. - await pipe.execute() - assert len(pipe) == 0 - assert pipe + # Execute calls reset(), so empty once again. + assert len(pipe) == 0 + assert pipe async def test_pipeline_no_transaction(self, client): - async with await client.pipeline(transaction=False) as pipe: - pipe.set("a", "a1") - pipe.set("b", "b1") - pipe.set("c", "c1") - assert await pipe.execute() == (True, True, True) - assert await client.get("a") == "a1" - assert await client.get("b") == "b1" - assert await client.get("c") == "c1" + async with client.pipeline(transaction=False) as pipe: + a = pipe.set("a", "a1") + b = pipe.set("b", "b1") + c = pipe.set("c", "c1") + assert await gather(a, b, c) == (True, True, True) + assert await client.get("a") == "a1" + assert await client.get("b") == "b1" + assert await client.get("c") == "c1" async def test_pipeline_invalid_flow(self, client): - pipe = await client.pipeline(transaction=False) - pipe.multi() - with pytest.raises(RedisError): + async with client.pipeline(transaction=False) as pipe: pipe.multi() + with pytest.raises(RedisError): + pipe.multi() - pipe = await client.pipeline(transaction=False) - pipe.multi() - with pytest.raises(RedisError): - pipe.watch("test") - - pipe = await client.pipeline(transaction=False) - pipe.set("fubar", 1) - with pytest.raises(RedisError): + async with client.pipeline(transaction=False) as pipe: pipe.multi() + with pytest.raises(RedisError): + await pipe.watch("test") + + async with client.pipeline(transaction=False) as pipe: + pipe.set("fubar", 1) + with pytest.raises(RedisError): + pipe.multi() @pytest.mark.nodragonfly - async def test_pipeline_no_permission(self, client, user_client): + async def test_pipeline_no_permission(self, user_client): no_perm_client = await user_client("testuser", "on", "+@all", "-MULTI") - async with await no_perm_client.pipeline(transaction=False) as pipe: - pipe.multi() - pipe.get("fubar") + async with no_perm_client: with pytest.raises(AuthorizationError): - await pipe.execute() + async with no_perm_client.pipeline(transaction=False) as pipe: + pipe.multi() + pipe.get("fubar") async def test_pipeline_no_transaction_watch(self, client): await client.set("a", "0") - async with await client.pipeline(transaction=False) as pipe: + async with client.pipeline(transaction=False) as pipe: await pipe.watch("a") a = await pipe.get("a") - pipe.multi() - pipe.set("a", str(int(a) + 1)) - assert await pipe.execute() == (True,) + b = pipe.set("a", str(int(a) + 1)) + assert await b async def test_pipeline_no_transaction_watch_failure(self, client): await client.set("a", "0") - async with await client.pipeline(transaction=False) as pipe: - await pipe.watch("a") - a = await pipe.get("a") + with pytest.raises(WatchError): + async with client.pipeline(transaction=False) as pipe: + await pipe.watch("a") + a = await pipe.get("a") - await client.set("a", "bad") - - pipe.multi() - pipe.set("a", str(int(a) + 1)) + await client.set("a", "bad") - with pytest.raises(WatchError): - await pipe.execute() + pipe.multi() + pipe.set("a", str(int(a) + 1)) - assert await client.get("a") == "bad" + assert await client.get("a") == "bad" - async def test_exec_error_in_response(self, client): + async def test_exec_error_in_response(self, client: Redis[str]): """ an invalid pipeline command at exec time adds the exception instance to the list of returned values """ await client.set("c", "a") - async with await client.pipeline() as pipe: - pipe.set("a", "1") - pipe.set("b", "2") - pipe.lpush("c", ["3"]) - pipe.set("d", "4") - result = await pipe.execute(raise_on_error=False) - - assert result[0] - assert await client.get("a") == "1" - assert result[1] - assert await client.get("b") == "2" - - # we can't lpush to a key that's a string value, so this should - # be a ResponseError exception - assert isinstance(result[2], ResponseError) - assert await client.get("c") == "a" - - # since this isn't a transaction, the other commands after the - # error are still executed - assert result[3] - assert await client.get("d") == "4" - - # make sure the pipe was restored to a working state - pipe.set("z", "zzz") - assert await pipe.execute() == (True,) - assert await client.get("z") == "zzz" - - async def test_exec_error_in_response_explicit_transaction(self, client): + async with client.pipeline(raise_on_error=False, transaction=False) as pipe: + a = pipe.set("a", "1") + b = pipe.set("b", "2") + c = pipe.lpush("c", ["3"]) + d = pipe.set("d", "4") + + assert await a + assert await client.get("a") == "1" + assert await b + assert await client.get("b") == "2" + + # we can't lpush to a key that's a string value, so this should + # be a ResponseError exception + assert isinstance(await c, ResponseError) + assert await client.get("c") == "a" + + # since this isn't a transaction, the other commands after the + # error are still executed + assert await d + assert await client.get("d") == "4" + + async def test_exec_error_in_response_explicit_transaction(self, client: Redis[str]): """ an invalid pipeline command at exec time adds the exception instance to the list of returned values """ await client.set("c", "a") - async with await client.pipeline(transaction=False) as pipe: + async with client.pipeline(raise_on_error=False, transaction=False) as pipe: pipe.multi() - pipe.set("a", "1") - pipe.set("b", "2") - pipe.lpush("c", ["3"]) - pipe.set("d", "4") - result = await pipe.execute(raise_on_error=False) - - assert result[0] - assert await client.get("a") == "1" - assert result[1] - assert await client.get("b") == "2" - - # we can't lpush to a key that's a string value, so this should - # be a ResponseError exception - assert isinstance(result[2], ResponseError) - assert await client.get("c") == "a" - - # since this isn't a transaction, the other commands after the - # error are still executed - assert result[3] - assert await client.get("d") == "4" - - # make sure the pipe was restored to a working state - pipe.set("z", "zzz") - assert await pipe.execute() == (True,) - assert await client.get("z") == "zzz" + a = pipe.set("a", "1") + b = pipe.set("b", "2") + c = pipe.lpush("c", ["3"]) + d = pipe.set("d", "4") + + assert await a + assert await client.get("a") == "1" + assert await b + assert await client.get("b") == "2" + + # we can't lpush to a key that's a string value, so this should + # be a ResponseError exception + assert isinstance(await c, ResponseError) + assert await client.get("c") == "a" + + # since this isn't a transaction, the other commands after the + # error are still executed + assert await d + assert await client.get("d") == "4" async def test_exec_error_raised(self, client): await client.set("c", "a") - async with await client.pipeline() as pipe: - pipe.set("a", "1") - pipe.set("b", "2") - pipe.lpush("c", ["3"]) - pipe.set("d", "4") - with pytest.raises(ResponseError): - await pipe.execute() - - # make sure the pipe was restored to a working state - pipe.set("z", "zzz") - assert await pipe.execute() == (True,) - assert await client.get("z") == "zzz" + with pytest.raises(ResponseError): + async with client.pipeline() as pipe: + pipe.set("a", "1") + pipe.set("b", "2") + pipe.lpush("c", ["3"]) + pipe.set("d", "4") async def test_exec_error_raised_explicit_transaction(self, client): await client.set("c", "a") - async with await client.pipeline(transaction=False) as pipe: - pipe.multi() - pipe.set("a", "1") - pipe.set("b", "2") - pipe.lpush("c", ["3"]) - pipe.set("d", "4") - with pytest.raises(ResponseError): - await pipe.execute() - - # make sure the pipe was restored to a working state - pipe.set("z", "zzz") - assert await pipe.execute() == (True,) - assert await client.get("z") == "zzz" + with pytest.raises(ResponseError): + async with client.pipeline(transaction=False) as pipe: + pipe.multi() + pipe.set("a", "1") + pipe.set("b", "2") + pipe.lpush("c", ["3"]) + pipe.set("d", "4") @pytest.mark.nodragonfly - async def test_parse_error_raised(self, client): - async with await client.pipeline() as pipe: - # the zrem is invalid because we don't pass any keys to it - pipe.set("a", "1") - pipe.zrem("b", []) - pipe.set("b", "2") - with pytest.raises(ResponseError): - await pipe.execute() - - # make sure the pipe was restored to a working state - pipe.set("z", "zzz") - assert await pipe.execute() == (True,) - assert await client.get("z") == "zzz" + async def test_parse_error_raised(self, client: Redis[str]): + with pytest.raises(ResponseError): + async with client.pipeline() as pipe: + # the zrem is invalid because we don't pass any keys to it + pipe.set("a", "1") + pipe.zrem("b", []) + pipe.set("b", "2") @pytest.mark.nodragonfly - async def test_parse_error_raised_explicit_transaction(self, client): - async with await client.pipeline(transaction=False) as pipe: - pipe.multi() - # the zrem is invalid because we don't pass any keys to it - pipe.set("a", "1") - pipe.zrem("b", []) - pipe.set("b", "2") - with pytest.raises(ResponseError): - await pipe.execute() - - # make sure the pipe was restored to a working state - pipe.set("z", "zzz") - assert await pipe.execute() == (True,) - assert await client.get("z") == "zzz" - - async def test_watch_succeed(self, client): + async def test_parse_error_raised_explicit_transaction(self, client: Redis[str]): + with pytest.raises(ResponseError): + async with client.pipeline(transaction=False) as pipe: + pipe.multi() + # the zrem is invalid because we don't pass any keys to it + pipe.set("a", "1") + pipe.zrem("b", []) + pipe.set("b", "2") + + async def test_watch_succeed(self, client: Redis[str]): await client.set("a", "1") await client.set("b", "2") - async with await client.pipeline() as pipe: + async with client.pipeline() as pipe: await pipe.watch("a", "b") assert pipe.watching a_value = await pipe.get("a") @@ -280,66 +236,71 @@ async def test_watch_succeed(self, client): assert b_value == "2" pipe.multi() - pipe.set("c", "3") - assert await pipe.execute() == (True,) - assert not pipe.watching + res = pipe.set("c", "3") - async def test_watch_failure(self, client): + assert await res + assert not pipe.watching + + async def test_watch_failure(self, client: Redis[str]): await client.set("a", "1") await client.set("b", "2") - async with await client.pipeline() as pipe: - await pipe.watch("a", "b") - await client.set("b", "3") - pipe.multi() - pipe.get("a") - with pytest.raises(WatchError): - await pipe.execute() - - assert not pipe.watching + with pytest.raises(WatchError): + async with client.pipeline() as pipe: + await pipe.watch("a", "b") + await client.set("b", "3") + pipe.multi() + pipe.get("a") - @pytest.mark.xfail - async def test_pipeline_transaction_with_watch_on_construction(self, client): - pipe = await client.pipeline(transaction=True, watches=["a{fu}"]) - - async def overwrite(): - i = 0 - while True: - try: - await client.set("a{fu}", i) - except asyncio.CancelledError: - break - except Exception: - break - - [pipe.set("a{fu}", -1 * i) for i in range(1000)] - - task = asyncio.create_task(overwrite()) - try: - await asyncio.sleep(0.1) - with pytest.raises(WatchError): - await pipe.execute() - finally: - task.cancel() - - async def test_unwatch(self, client): + async def test_unwatch(self, client: Redis[str]): await client.set("a", "1") await client.set("b", "2") - async with await client.pipeline() as pipe: + async with client.pipeline() as pipe: await pipe.watch("a", "b") await client.set("b", "3") await pipe.unwatch() assert not pipe.watching - pipe.get("a") - assert await pipe.execute() == ("1",) + res = pipe.get("a") + assert await res == "1" + + async def test_exec_error_in_no_transaction_pipeline(self, client: Redis[str]): + await client.set("a", "1") + with pytest.raises(ResponseError): + async with client.pipeline(transaction=False) as pipe: + pipe.llen("a") + pipe.expire("a", 100) + + assert await client.get("a") == "1" + + async def test_exec_error_in_no_transaction_pipeline_unicode_command(self, client: Redis[str]): + key = chr(11) + "abcd" + chr(23) + await client.set(key, "1") + with pytest.raises(ResponseError): + async with client.pipeline(transaction=False) as pipe: + pipe.llen(key) + pipe.expire(key, 100) + + assert await client.get(key) == "1" + + async def test_pipeline_timeout(self, client: Redis[str]): + await client.hset("hash", {str(i): bytes(1024) for i in range(1024)}) + with pytest.raises(TimeoutError): + async with client.pipeline(timeout=0.01) as pipe: + for _ in range(20): + pipe.hgetall("hash") + + async with client.pipeline(timeout=5) as pipe: + for _ in range(20): + pipe.hgetall("hash") - async def test_transaction_callable(self, client): + async def test_transaction_callable(self, client: Redis[str]): await client.set("a", "1") await client.set("b", "2") - has_run = [] + has_run = False - async def my_transaction(pipe): + async def my_transaction(pipe: Pipeline[str]) -> CommandRequest[bool]: + nonlocal has_run a_value = await pipe.get("a") assert a_value in ("1", "2") b_value = await pipe.get("b") @@ -347,52 +308,13 @@ async def my_transaction(pipe): # silly run-once code... incr's "a" so WatchError should be raised # forcing this all to run again. this should incr "a" once to "2" - if not has_run: await client.incr("a") - has_run.append("it has") + has_run = True pipe.multi() - pipe.set("c", str(int(a_value) + int(b_value))) + return pipe.set("c", str(int(a_value) + int(b_value))) result = await client.transaction(my_transaction, "a", "b", watch_delay=0.01) - assert result == (True,) + assert await result assert await client.get("c") == "4" - - async def test_exec_error_in_no_transaction_pipeline(self, client): - await client.set("a", "1") - async with await client.pipeline(transaction=False) as pipe: - pipe.llen("a") - pipe.expire("a", 100) - - with pytest.raises(ResponseError): - await pipe.execute() - - assert await client.get("a") == "1" - - async def test_exec_error_in_no_transaction_pipeline_unicode_command(self, client): - key = chr(11) + "abcd" + chr(23) - await client.set(key, "1") - async with await client.pipeline(transaction=False) as pipe: - pipe.llen(key) - pipe.expire(key, 100) - - with pytest.raises(ResponseError): - await pipe.execute() - - assert await client.get(key) == "1" - - async def test_pipeline_timeout(self, client): - await client.hset("hash", {str(i): i for i in range(4096)}) - await client.ping() - pipeline = await client.pipeline(timeout=0.01) - for i in range(20): - pipeline.hgetall("hash") - with pytest.raises(TimeoutError): - await pipeline.execute() - - await client.ping() - pipeline = await client.pipeline(timeout=5) - for i in range(20): - pipeline.hgetall("hash") - await pipeline.execute() diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index ac0b0839d..cf8c219d6 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -1,17 +1,16 @@ from __future__ import annotations -import asyncio import pickle import time -import pytest +import anyio -import coredis -from coredis.exceptions import ConnectionError +from coredis.client.basic import Redis +from coredis.commands.pubsub import PubSub from tests.conftest import targets -async def wait_for_message(pubsub, timeout=0.5, ignore_subscribe_messages=False): +async def wait_for_message(pubsub: PubSub, timeout=0.5, ignore_subscribe_messages=False): now = time.time() timeout = now + timeout @@ -22,7 +21,7 @@ async def wait_for_message(pubsub, timeout=0.5, ignore_subscribe_messages=False) if message is not None: return message - await asyncio.sleep(0.01) + await anyio.sleep(0.01) now = time.time() return None @@ -69,14 +68,7 @@ def make_subscribe_test_data(pubsub, encoder, type): assert False, f"invalid subscribe type: {type}" -@targets( - "redis_basic", - "redis_basic_blocking", - "redis_basic_raw", - "dragonfly", - "valkey", - "redict", -) +@targets("redis_basic", "redis_basic_raw", "dragonfly", "valkey", "redict") class TestPubSubSubscribeUnsubscribe: async def _test_subscribe_unsubscribe( self, @@ -116,9 +108,10 @@ async def test_pattern_subscribe_unsubscribe(self, client, _s): await self._test_subscribe_unsubscribe(**kwargs) async def _test_resubscribe_on_reconnection( - self, p, encoder, sub_type, unsub_type, sub_func, unsub_func, keys + self, p: PubSub, encoder, sub_type, unsub_type, sub_func, unsub_func, keys ): async with p: + p.connection.max_idle_time = 1 for key in keys: assert await sub_func(key) is None # should be a message for each channel/pattern we just subscribed to @@ -126,8 +119,8 @@ async def _test_resubscribe_on_reconnection( for i, key in enumerate(keys): assert await wait_for_message(p) == make_message(sub_type, encoder(key), i + 1) - # manually disconnect - p.connection.disconnect() + # wait for disconnect + await anyio.sleep(2) # calling get_message again reconnects and resubscribes # note, we may not re-subscribe to channels in exactly the same order # so we have to do some extra checks to make sure we got them all @@ -249,7 +242,7 @@ async def test_ignore_individual_subscribe_messages(self, client): assert message is None assert p.subscribed is False - async def test_subscribe_on_construct(self, client, _s): + async def test_subscribe_on_construct(self, client: Redis, _s): handled = [] def handle(message): @@ -262,7 +255,6 @@ def handle(message): patterns=["baz*"], pattern_handlers={"qu*": handle}, ) as pubsub: - assert pubsub.subscribed await client.publish("foo", "bar") await client.publish("bar", "foo") await client.publish("baz", "qux") @@ -275,7 +267,6 @@ def handle(message): ) assert handled == [_s("foo"), _s("quxx")] - assert not pubsub.subscribed @targets("redis_basic", "redis_basic_raw") @@ -418,7 +409,7 @@ def handler(message): await client.publish("fu", "bar") await client.publish("bar", "fu") - await asyncio.sleep(0.1) + await anyio.sleep(0.1) assert messages == {_s("fu"), _s("bar")} @@ -434,26 +425,17 @@ async def collect(): [messages.append(message) async for message in p] async def unsubscribe(): - await asyncio.sleep(0.1) + await anyio.sleep(0.1) await p.punsubscribe("fu*") await p.unsubscribe("test") - completed, pending = await asyncio.wait( - [asyncio.create_task(collect()), asyncio.create_task(unsubscribe())], timeout=1 - ) - assert all(task.done() for task in completed) - assert not pending + with anyio.fail_after(1): + async with anyio.create_task_group() as tg: + tg.start_soon(collect) + tg.start_soon(unsubscribe) assert len(messages) == 20 -class TestPubSubRedisDown: - async def test_channel_subscribe(self): - client = coredis.Redis(host="localhost", port=9999) - p = client.pubsub() - with pytest.raises(ConnectionError): - await p.subscribe("foo") - - @targets("redis_basic", "redis_basic_raw") class TestPubSubPubSubSubcommands: async def test_pubsub_channels(self, client, _s): @@ -464,20 +446,15 @@ async def test_pubsub_channels(self, client, _s): async def test_pubsub_numsub(self, client, _s): p1 = client.pubsub(ignore_subscribe_messages=True) - await p1.subscribe("foo", "bar", "baz") p2 = client.pubsub(ignore_subscribe_messages=True) - await p2.subscribe("bar", "baz") p3 = client.pubsub(ignore_subscribe_messages=True) - await p3.subscribe("baz") - - channels = {_s("foo"): 1, _s("bar"): 2, _s("baz"): 3} - assert channels == await client.pubsub_numsub("foo", "bar", "baz") - await p1.unsubscribe() - await p2.unsubscribe() - await p3.unsubscribe() - p1.close() - p2.close() - p3.close() + async with p1, p2, p3: + await p1.subscribe("foo", "bar", "baz") + await p2.subscribe("bar", "baz") + await p3.subscribe("baz") + + channels = {_s("foo"): 1, _s("bar"): 2, _s("baz"): 3} + assert channels == await client.pubsub_numsub("foo", "bar", "baz") async def test_pubsub_numpat(self, client): pubsub_count = await client.pubsub_numpat() diff --git a/tests/test_scripting.py b/tests/test_scripting.py index 54031030c..627a87832 100644 --- a/tests/test_scripting.py +++ b/tests/test_scripting.py @@ -4,7 +4,9 @@ from beartype.roar import BeartypeCallHintParamViolation from coredis import PureToken +from coredis._concurrency import gather from coredis.client import Client +from coredis.client.basic import Redis from coredis.commands import Script from coredis.exceptions import NoScriptError, NotBusyError, ResponseError from coredis.typing import AnyStr, KeyT, RedisValueT @@ -52,7 +54,7 @@ async def flush_scripts(client): await client.script_flush() -@targets("redis_basic", "redis_basic_blocking") +@targets("redis_basic") class TestScripting: async def test_eval(self, client): await client.set("a", "2") @@ -86,7 +88,7 @@ async def test_script_flush_sync_mode(self, client): assert await client.script_flush(sync_type=PureToken.SYNC) assert await client.script_exists([sha]) == (False,) - async def test_script_object(self, client): + async def test_script_object(self, client: Redis[str]): await client.set("a", "2") multiply = client.register_script(multiply_script) precalculated_sha = multiply.sha @@ -101,17 +103,17 @@ async def test_script_object(self, client): # Test first evalsha block assert await multiply(keys=["a"], args=[3]) == 6 - async def test_script_object_in_pipeline(self, client): + async def test_script_object_in_pipeline(self, client: Redis[str]): multiply = client.register_script(multiply_script) precalculated_sha = multiply.sha assert precalculated_sha - pipe = await client.pipeline() - pipe.set("a", "2") - pipe.get("a") - multiply(keys=["a"], args=[3], client=pipe) - assert await client.script_exists([multiply.sha]) == (False,) + async with client.pipeline() as pipe: + a = pipe.set("a", "2") + b = pipe.get("a") + c = multiply(keys=["a"], args=[3], client=pipe) + assert await client.script_exists([multiply.sha]) == (False,) # [SET worked, GET 'a', result of multiple script] - assert await pipe.execute() == (True, "2", 6) + assert await gather(a, b, c) == (True, "2", 6) # The script should have been loaded by pipe.execute() assert await client.script_exists([multiply.sha]) == (True,) # The precalculated sha should have been the correct one @@ -120,40 +122,35 @@ async def test_script_object_in_pipeline(self, client): # purge the script from redis's cache and re-run the pipeline # the multiply script should be reloaded by pipe.execute() await client.script_flush() - pipe = await client.pipeline() - pipe.set("a", "2") - pipe.get("a") - multiply(keys=["a"], args=[3], client=pipe) - assert await client.script_exists([multiply.sha]) == (False,) + async with client.pipeline() as pipe: + a = pipe.set("a", "2") + b = pipe.get("a") + c = multiply(keys=["a"], args=[3], client=pipe) + assert await client.script_exists([multiply.sha]) == (False,) # [SET worked, GET 'a', result of multiple script] - assert await pipe.execute() == ( - True, - "2", - 6, - ) + assert await gather(a, b, c) == (True, "2", 6) assert await client.script_exists([multiply.sha]) == (True,) async def testscript_flush_eval_msgpack_pipeline_error_in_lua(self, client): msgpack_hello = client.register_script(msgpack_hello_script) assert msgpack_hello.sha - pipe = await client.pipeline() # avoiding a dependency to msgpack, this is the output of # msgpack.dumps({"name": "joe"}) msgpack_message_1 = b"\x81\xa4name\xa3Joe" + async with client.pipeline() as pipe: + res = msgpack_hello(args=[msgpack_message_1], client=pipe) + assert await client.script_exists([msgpack_hello.sha]) == (False,) - msgpack_hello(args=[msgpack_message_1], client=pipe) - - assert await client.script_exists([msgpack_hello.sha]) == (False,) - assert (await pipe.execute())[0] == "hello Joe" + assert await res == "hello Joe" assert await client.script_exists([msgpack_hello.sha]) == (True,) msgpack_hello_broken = client.register_script(msgpack_hello_script_broken) - msgpack_hello_broken(args=[msgpack_message_1], client=pipe) with pytest.raises(ResponseError) as excinfo: - await pipe.execute() - assert excinfo.type == ResponseError + async with client.pipeline() as pipe: + msgpack_hello_broken(args=[msgpack_message_1], client=pipe) + assert excinfo.type == ResponseError async def test_script_kill_no_scripts(self, client): with pytest.raises(NotBusyError): @@ -208,11 +205,11 @@ async def test_wraps_class_method(self, client): class Wrapper: @classmethod - @scrpt.wraps(key_spec=["key"], client_arg="client", runtime_checks=True) + @scrpt.wraps(client_arg="client", runtime_checks=True) async def default_get( cls, client: Client[AnyStr] | None, - key: str, + key: KeyT, default: str = "coredis", ) -> str: ... diff --git a/tests/test_sentinel.py b/tests/test_sentinel.py index 244df13c7..409d0f06d 100644 --- a/tests/test_sentinel.py +++ b/tests/test_sentinel.py @@ -1,216 +1,113 @@ from __future__ import annotations +from unittest.mock import AsyncMock + import pytest import coredis from coredis.exceptions import ( - ConnectionError, PrimaryNotFoundError, ReadOnlyError, ReplicaNotFoundError, ReplicationError, - ResponseError, - TimeoutError, ) from coredis.sentinel import Sentinel, SentinelConnectionPool from tests.conftest import targets -pytestmarks = pytest.mark.asyncio - - -class SentinelTestClient: - def __init__(self, cluster, id): - self.cluster = cluster - self.id = id - - async def sentinel_masters(self): - self.cluster.connection_error_if_down(self) - self.cluster.timeout_if_down(self) - - return {self.cluster.service_name: self.cluster.primary} - async def sentinel_replicas(self, primary_name): - self.cluster.connection_error_if_down(self) - self.cluster.timeout_if_down(self) +async def test_init_compose_sentinel(redis_sentinel: Sentinel): + master = redis_sentinel.primary_for("mymaster") + async with master: + await master.ping() - if primary_name != self.cluster.service_name: - return [] - return self.cluster.replicas +async def test_discover_primary(redis_sentinel: Sentinel, host_ip): + address = await redis_sentinel.discover_primary("mymaster") + assert address == (host_ip, 6380) -class SentinelTestCluster: - def __init__(self, service_name="mymaster", ip="127.0.0.1", port=6379): - self.clients = {} - self.primary = { - "ip": ip, - "port": port, +async def test_discover_primary_error(redis_sentinel: Sentinel, mocker): + with pytest.raises(PrimaryNotFoundError): + await redis_sentinel.discover_primary("xxx") + sentinel_masters = mocker.patch.object( + redis_sentinel.sentinels[0], "sentinel_masters", new_callable=AsyncMock + ) + sentinel_masters.return_value = { + "mymaster": { + "ip": "127.0.0.1", + "port": 6380, "is_master": True, - "is_sdown": False, - "is_odown": False, - "num-other-sentinels": 0, + "is_sdown": True, + "is_odown": True, } - self.service_name = service_name - self.replicas = [] - self.nodes_down = set() - self.nodes_timeout = set() - - def connection_error_if_down(self, node): - if node.id in self.nodes_down: - raise ConnectionError - - def timeout_if_down(self, node): - if node.id in self.nodes_timeout: - raise TimeoutError - - def client(self, host, port, **kwargs): - return SentinelTestClient(self, (host, port)) - - -@pytest.fixture() -def cluster(request): - def teardown(): - coredis.sentinel.Redis = saved_Redis - - cluster = SentinelTestCluster() - saved_Redis = coredis.sentinel.Redis - coredis.sentinel.Redis = cluster.client - request.addfinalizer(teardown) - - return cluster - - -@pytest.fixture() -def sentinel(request, cluster): - return Sentinel([("foo", 26379), ("bar", 26379)]) - - -async def test_discover_primary(sentinel): - address = await sentinel.discover_primary("mymaster") - assert address == ("127.0.0.1", 6379) - - -async def test_discover_primary_error(sentinel): - with pytest.raises(PrimaryNotFoundError): - await sentinel.discover_primary("xxx") - - -async def test_discover_primary_sentinel_down(cluster, sentinel): - # Put first sentinel 'foo' down - cluster.nodes_down.add(("foo", 26379)) - address = await sentinel.discover_primary("mymaster") - assert address == ("127.0.0.1", 6379) - # 'bar' is now first sentinel - assert sentinel.sentinels[0].id == ("bar", 26379) - - -async def test_discover_primary_sentinel_timeout(cluster, sentinel): - # Put first sentinel 'foo' down - cluster.nodes_timeout.add(("foo", 26379)) - address = await sentinel.discover_primary("mymaster") - assert address == ("127.0.0.1", 6379) - # 'bar' is now first sentinel - assert sentinel.sentinels[0].id == ("bar", 26379) - - -async def test_master_min_other_sentinels(cluster): - sentinel = Sentinel([("foo", 26379)], min_other_sentinels=1) - # min_other_sentinels - with pytest.raises(PrimaryNotFoundError): - await sentinel.discover_primary("mymaster") - cluster.primary["num-other-sentinels"] = 2 - address = await sentinel.discover_primary("mymaster") - assert address == ("127.0.0.1", 6379) - - -async def test_master_odown(cluster, sentinel): - cluster.primary["is_odown"] = True - with pytest.raises(PrimaryNotFoundError): - await sentinel.discover_primary("mymaster") - - -async def test_master_sdown(cluster, sentinel): - cluster.primary["is_sdown"] = True - with pytest.raises(PrimaryNotFoundError): - await sentinel.discover_primary("mymaster") - + } + with pytest.RaisesGroup(PrimaryNotFoundError): + async with redis_sentinel.primary_for("mymaster") as primary: + await primary.ping() -async def test_discover_replicas(cluster, sentinel): - assert await sentinel.discover_replicas("mymaster") == [] - cluster.replicas = [ - {"ip": "replica0", "port": 1234, "is_odown": False, "is_sdown": False}, - {"ip": "replica1", "port": 1234, "is_odown": False, "is_sdown": False}, - ] - assert await sentinel.discover_replicas("mymaster") == [ - ("replica0", 1234), - ("replica1", 1234), - ] - - # replica0 -> ODOWN - cluster.replicas[0]["is_odown"] = True - assert await sentinel.discover_replicas("mymaster") == [("replica1", 1234)] - - # replica1 -> SDOWN - cluster.replicas[1]["is_sdown"] = True - assert await sentinel.discover_replicas("mymaster") == [] - - cluster.replicas[0]["is_odown"] = False - cluster.replicas[1]["is_sdown"] = False - - # node0 -> DOWN - cluster.nodes_down.add(("foo", 26379)) - assert await sentinel.discover_replicas("mymaster") == [ - ("replica0", 1234), - ("replica1", 1234), - ] - cluster.nodes_down.clear() - - # node0 -> TIMEOUT - cluster.nodes_timeout.add(("foo", 26379)) - assert await sentinel.discover_replicas("mymaster") == [ - ("replica0", 1234), - ("replica1", 1234), - ] - - -async def test_replica_for_slave_not_found_error(cluster, sentinel): - cluster.primary["is_odown"] = True - replica = sentinel.replica_for("mymaster", db=9) - with pytest.raises(ReplicaNotFoundError): - await replica.ping() - - -async def test_replica_round_robin(cluster, sentinel): - cluster.replicas = [ +async def test_replica_for_slave_not_found_error(redis_sentinel: Sentinel, mocker): + sentinel_replicas = mocker.patch.object( + redis_sentinel.sentinels[0], "sentinel_replicas", new_callable=AsyncMock + ) + sentinel_masters = mocker.patch.object( + redis_sentinel.sentinels[0], "sentinel_masters", new_callable=AsyncMock + ) + sentinel_replicas.return_value = [] + sentinel_masters.return_value = {} + replica = redis_sentinel.replica_for("mymaster", db=9) + with pytest.RaisesGroup(ReplicaNotFoundError): + async with replica: + await replica.ping() + + +async def test_replica_round_robin(redis_sentinel: Sentinel, mocker, host_ip): + pool = SentinelConnectionPool("mymaster", redis_sentinel) + sentinel_replicas = mocker.patch.object( + redis_sentinel.sentinels[0], "sentinel_replicas", new_callable=AsyncMock + ) + sentinel_replicas.return_value = [ {"ip": "replica0", "port": 6379, "is_odown": False, "is_sdown": False}, {"ip": "replica1", "port": 6379, "is_odown": False, "is_sdown": False}, ] - pool = SentinelConnectionPool("mymaster", sentinel) - rotator = await pool.rotate_replicas() - assert set(rotator) == {("replica0", 6379), ("replica1", 6379)} + async for rotator in pool.rotate_replicas(): + assert rotator in {("replica0", 6379), ("replica1", 6379)} + sentinel_replicas.return_value = [ + {"ip": "replica0", "port": 6379, "is_odown": False, "is_sdown": False}, + {"ip": "replica1", "port": 6379, "is_odown": False, "is_sdown": True}, + ] + async for rotator in pool.rotate_replicas(): + assert rotator in {("replica0", 6379)} -async def test_autodecode(redis_sentinel_server): +async def test_autodecode(redis_sentinel_server: tuple[str, int]): sentinel = Sentinel(sentinels=[redis_sentinel_server], decode_responses=True) - assert await sentinel.primary_for("mymaster").ping() == "PONG" - assert await sentinel.primary_for("mymaster", decode_responses=False).ping() == b"PONG" + async with sentinel: + client = sentinel.primary_for("mymaster") + async with client: + assert await client.ping() == "PONG" + client = sentinel.primary_for("mymaster", decode_responses=False) + async with client: + assert await client.ping() == b"PONG" -@targets("redis_sentinel", "redis_sentinel_raw", "redis_sentinel_resp2") +@targets("redis_sentinel", "redis_sentinel_raw") class TestSentinelCommand: - async def test_primary_for(self, client, host_ip): + async def test_primary_for(self, client: Sentinel, host_ip): primary = client.primary_for("mymaster") - assert await primary.ping() - assert primary.connection_pool.primary_address == (host_ip, 6380) + async with primary: + assert await primary.ping() + assert primary.connection_pool.primary_address == (host_ip, 6380) # Use internal connection check primary = client.primary_for("mymaster", check_connection=True) - assert await primary.ping() + async with primary: + assert await primary.ping() async def test_replica_for(self, client): replica = client.replica_for("mymaster") - assert await replica.ping() + async with replica: + assert await replica.ping() async def test_ckquorum(self, client): assert await client.sentinels[0].sentinel_ckquorum("mymaster") @@ -237,10 +134,13 @@ async def test_failover(self, client, mocker): async def test_flush_config(self, client): assert await client.sentinels[0].sentinel_flushconfig() - async def test_role(self, client): + async def test_role(self, client: Sentinel): assert (await client.sentinels[0].role()).role == "sentinel" - assert (await client.primary_for("mymaster").role()).role == "master" - assert (await client.replica_for("mymaster").role()).role == "slave" + primary = client.primary_for("mymaster") + replica = client.replica_for("mymaster") + async with primary, replica: + assert (await primary.role()).role == "master" + assert (await replica.role()).role == "slave" async def test_infocache(self, client, _s): assert await client.sentinels[0].sentinel_flushconfig() @@ -259,26 +159,32 @@ async def test_sentinel_replicas(self, client): [k["is_master"] for k in (await client.sentinels[0].sentinel_replicas("mymaster"))] ) - async def test_no_replicas(self, client, mocker): + async def test_no_replicas(self, client: Sentinel, mocker): p = client.replica_for("mymaster") replica_rotate = mocker.patch.object(p.connection_pool, "rotate_replicas") - replica_rotate.return_value = [] - with pytest.raises(ReplicaNotFoundError): - await p.ping() - async def test_write_to_replica(self, client): - p = await client.replica_for("mymaster") - await p.ping() - with pytest.raises(ReadOnlyError): - await p.set("fubar", 1) + async def async_iter(items): + for item in items: + yield item - @pytest.mark.parametrize( - "client_arguments", [{"cache": coredis.cache.TrackingCache(max_size_bytes=-1)}] - ) - async def test_sentinel_cache(self, client, client_arguments, mocker, _s): - await client.primary_for("mymaster").set("fubar", 1) + replica_rotate.return_value = async_iter([]) + with pytest.RaisesGroup(ReplicaNotFoundError, allow_unwrapped=True, flatten_subgroups=True): + async with p: + await p.ping() - assert await client.primary_for("mymaster").get("fubar") == _s("1") + async def test_write_to_replica(self, client): + p = client.replica_for("mymaster") + async with p: + await p.ping() + with pytest.raises(ReadOnlyError): + await p.set("fubar", 1) + + @pytest.mark.parametrize("client_arguments", [{"cache": coredis.cache.LRUCache()}]) + async def test_sentinel_cache(self, client: Sentinel, client_arguments, mocker, _s): + primary = client.primary_for("mymaster") + async with primary: + await primary.set("fubar", 1) + assert await primary.get("fubar") == _s("1") new_primary = client.primary_for("mymaster") new_replica = client.replica_for("mymaster") @@ -286,28 +192,22 @@ async def test_sentinel_cache(self, client, client_arguments, mocker, _s): assert new_primary.cache assert new_replica.cache - await new_primary.ping() - await new_replica.ping() - - replica_spy = mocker.spy(coredis.BaseConnection, "create_request") - - assert new_primary.cache.healthy - assert new_replica.cache.healthy + async with new_primary, new_replica: + await new_primary.ping() + await new_replica.ping() - assert await new_primary.get("fubar") == _s("1") - assert await new_replica.get("fubar") == _s("1") - - assert replica_spy.call_count == 0 + assert await new_primary.get("fubar") == _s("1") + create_request_spy = mocker.spy(coredis.BaseConnection, "create_request") + assert await new_replica.get("fubar") == _s("1") + assert create_request_spy.call_count == 0 @pytest.mark.xfail - async def test_replication(self, client): - with client.primary_for("mymaster").ensure_replication(1) as primary: - await primary.set("fubar", 1) - - with pytest.raises(ReplicationError): - with client.primary_for("mymaster").ensure_replication(2) as primary: + async def test_replication(self, client: Sentinel): + primary = client.primary_for("mymaster") + async with primary: + with primary.ensure_replication(1): await primary.set("fubar", 1) - with pytest.raises(ResponseError): - with client.replica_for("mymaster").ensure_replication(2) as replica: - await replica.set("fubar", 1) + with pytest.RaisesGroup(ReplicationError, allow_unwrapped=True, flatten_subgroups=True): + with primary.ensure_replication(2): + await primary.set("fubar", 1) diff --git a/tests/test_sidecar.py b/tests/test_sidecar.py deleted file mode 100644 index 5110f8b85..000000000 --- a/tests/test_sidecar.py +++ /dev/null @@ -1,60 +0,0 @@ -from __future__ import annotations - -import asyncio - -import pytest - -from coredis._sidecar import Sidecar -from tests.conftest import targets - -pytestmark = pytest.mark.asyncio - - -@targets("redis_basic", "redis_basic_blocking", "redis_basic_raw") -class TestSidecar: - async def test_noop_sidecar(self, client): - sidecar = Sidecar(set(), health_check_interval_seconds=1) - assert sidecar.connection is None - await sidecar.start(client) - assert sidecar.connection is not None - await asyncio.sleep(0.1) - sidecar.stop() - assert sidecar.last_checkin > 0 - assert sidecar.connection is None - - async def test_pubsub_sidecar(self, client, _s): - sidecar = Sidecar({b"subscribe", b"message"}, health_check_interval_seconds=1) - assert sidecar.connection is None - await sidecar.start(client) - assert sidecar.connection is not None - await sidecar.connection.send_command(b"SUBSCRIBE", b"fubar") - await client.publish("fubar", "test") - m1 = await sidecar.messages.get() - m2 = await sidecar.messages.get() - assert m1[0] == b"subscribe" - assert m2[0] == b"message" - sidecar.stop() - - async def test_sidecar_reconnect(self, client, _s): - sidecar = Sidecar(set(), health_check_interval_seconds=1) - assert sidecar.connection is None - await sidecar.start(client) - assert sidecar.connection is not None - sidecar.connection.disconnect() - assert not sidecar.connection.is_connected - await asyncio.sleep(0.5) - assert sidecar.connection is not None - assert sidecar.connection.is_connected - sidecar.stop() - - async def test_finalization(self, client, cloner): - running_tasks = asyncio.all_tasks() - - async def scoped_client(): - clone = await cloner(client) - sidecar = Sidecar(set(), health_check_interval_seconds=1) - await sidecar.start(clone) - - await scoped_client() - await asyncio.sleep(0.1) - assert set() == asyncio.all_tasks() - running_tasks diff --git a/tests/test_stream_consumers.py b/tests/test_stream_consumers.py index d897bca16..6f8a34889 100644 --- a/tests/test_stream_consumers.py +++ b/tests/test_stream_consumers.py @@ -1,9 +1,8 @@ from __future__ import annotations -import asyncio -import threading from collections import OrderedDict +import anyio import pytest from coredis.exceptions import StreamConsumerInitializationError @@ -20,13 +19,7 @@ async def consume_entries(consumer, count, consumed=None): return consumed -@targets( - "redis_basic", - "redis_basic_blocking", - "redis_basic_raw", - "redis_cluster", - "redis_cluster_raw", -) +@targets("redis_basic", "redis_basic_raw", "redis_cluster", "redis_cluster_raw") class TestStreamConsumers: async def test_single_consumer(self, client, _s): consumer = await Consumer(client, ["a", "b"]) @@ -146,21 +139,23 @@ async def test_single_group_consumer_auto_create_group_stream(self, client, _s): ] async def test_multiple_group_consumer_auto_create_group_stream(self, client, cloner, _s): - client_2 = await cloner(client) - consumer_1 = await GroupConsumer( - client, ["a", "b"], "group-a", "consumer-1", auto_create=True - ) - consumer_2 = await GroupConsumer( - client_2, ["a", "b"], "group-a", "consumer-2", auto_create=True - ) - [await client.xadd("a", {"id": i}) for i in range(10)] - [await client.xadd("b", {"id": i}) for i in range(10, 20)] - consumed = await consume_entries(consumer_1, 20) - consumed = await consume_entries(consumer_2, 20, consumed) - assert list(range(10)) == [int(entry.field_values[_s("id")]) for entry in consumed[_s("a")]] - assert list(range(10, 20)) == [ - int(entry.field_values[_s("id")]) for entry in consumed[_s("b")] - ] + async with await cloner(client) as client_2: + consumer_1 = await GroupConsumer( + client, ["a", "b"], "group-a", "consumer-1", auto_create=True + ) + consumer_2 = await GroupConsumer( + client_2, ["a", "b"], "group-a", "consumer-2", auto_create=True + ) + [await client.xadd("a", {"id": i}) for i in range(10)] + [await client.xadd("b", {"id": i}) for i in range(10, 20)] + consumed = await consume_entries(consumer_1, 20) + consumed = await consume_entries(consumer_2, 20, consumed) + assert list(range(10)) == sorted( + int(e.field_values[_s("id")]) for e in consumed[_s("a")] + ) + assert list(range(10, 20)) == sorted( + int(e.field_values[_s("id")]) for e in consumed[_s("b")] + ) async def test_group_consumer_start_from_pending_list(self, client, _s): consumer = await GroupConsumer( @@ -234,38 +229,40 @@ async def test_group_consumer_buffered(self, client, _s): async def test_single_blocking_consumer(self, client, cloner, _s): consumer = await Consumer(client, ["a"], timeout=1000) - clone = await cloner(client) - async def _inner(): - await asyncio.sleep(0.2) - await clone.xadd("a", {"id": 1}) + async with await cloner(client) as clone: - th = threading.Thread( - target=asyncio.run_coroutine_threadsafe, - args=(_inner(), asyncio.get_running_loop()), - ) - th.start() - _, entry = await consumer.get_entry() - th.join() + async def delayed_add(): + await anyio.sleep(0.05) + await clone.xadd("a", {"id": 1}) + + async with anyio.create_task_group() as tg: + tg.start_soon(delayed_add) + result = await consumer.get_entry() + tg.cancel_scope.cancel() + + assert result is not None and result[1] is not None + _, entry = result assert entry.field_values[_s("id")] == _s(1) async def test_group_blocking_consumer(self, client, cloner, _s): consumer = await GroupConsumer( client, ["a"], "group-a", "consumer-a", auto_create=True, timeout=1000 ) - clone = await cloner(client) - async def _inner(): - await asyncio.sleep(0.2) - await clone.xadd("a", {"id": 1}) + async with await cloner(client) as clone: - th = threading.Thread( - target=asyncio.run_coroutine_threadsafe, - args=(_inner(), asyncio.get_running_loop()), - ) - th.start() - _, entry = await consumer.get_entry() - th.join() + async def delayed_add(): + await anyio.sleep(0.05) + await clone.xadd("a", {"id": 1}) + + async with anyio.create_task_group() as tg: + tg.start_soon(delayed_add) + result = await consumer.get_entry() + tg.cancel_scope.cancel() + + assert result is not None and result[1] is not None + _, entry = result assert entry.field_values[_s("id")] == _s(1) async def test_single_non_blocking_iterator(self, client, _s): @@ -280,22 +277,20 @@ async def test_single_non_blocking_iterator(self, client, _s): async def test_single_blocking_iterator(self, client, cloner, _s): consumer = await Consumer(client, ["a"], timeout=1000) - clone = await cloner(client) - async def _inner(): - await asyncio.sleep(0.2) - await clone.xadd("a", {"id": 1}) + async with await cloner(client) as clone: - th = threading.Thread( - target=asyncio.run_coroutine_threadsafe, - args=(_inner(), asyncio.get_running_loop()), - ) - th.start() - consumed = {} + async def delayed_add(): + await anyio.sleep(0.05) + await clone.xadd("a", {"id": 1}) + + consumed = {} + async with anyio.create_task_group() as tg: + tg.start_soon(delayed_add) + async for stream, entry in consumer: + consumed.setdefault(stream, []).append(entry) + tg.cancel_scope.cancel() - async for stream, entry in consumer: - consumed.setdefault(stream, []).append(entry) - th.join() assert len(consumed[_s("a")]) == 1 assert _s(1) == consumed[_s("a")][0].field_values[_s("id")] @@ -303,21 +298,19 @@ async def test_group_blocking_iterator(self, client, cloner, _s): consumer = await GroupConsumer( client, ["a"], "group-a", "consumer-a", auto_create=True, timeout=1000 ) - clone = await cloner(client) - async def _inner(): - await asyncio.sleep(0.2) - await clone.xadd("a", {"id": 1}) + async with await cloner(client) as clone: - th = threading.Thread( - target=asyncio.run_coroutine_threadsafe, - args=(_inner(), asyncio.get_running_loop()), - ) - th.start() - consumed = {} + async def delayed_add(): + await anyio.sleep(0.05) + await clone.xadd("a", {"id": 1}) + + consumed = {} + async with anyio.create_task_group() as tg: + tg.start_soon(delayed_add) + async for stream, entry in consumer: + consumed.setdefault(stream, []).append(entry) + tg.cancel_scope.cancel() - async for stream, entry in consumer: - consumed.setdefault(stream, []).append(entry) - th.join() assert len(consumed[_s("a")]) == 1 assert _s(1) == consumed[_s("a")][0].field_values[_s("id")] diff --git a/tests/test_tracking_cache.py b/tests/test_tracking_cache.py index 418350ac6..832e2a8e2 100644 --- a/tests/test_tracking_cache.py +++ b/tests/test_tracking_cache.py @@ -1,62 +1,49 @@ from __future__ import annotations -import asyncio +from contextlib import AsyncExitStack import pytest +from anyio import sleep -from coredis.cache import ClusterTrackingCache, NodeTrackingCache, TrackingCache +from coredis.cache import LRUCache +from coredis.client.basic import Redis from tests.conftest import targets class CommonExamples: - @property - def cache(self): - return TrackingCache - - async def test_single_entry_cache(self, client, cloner, _s): + async def test_single_entry_cache(self, client: Redis, cloner, _s): await client.flushall() - cache = self.cache(max_keys=1, max_size_bytes=-1) - cached = await cloner(client, cache=cache) - assert not await cached.get("fubar") - await client.set("fubar", 1) - await asyncio.sleep(0.2) - assert await cached.get("fubar") == _s("1") - await client.incr("fubar") - await asyncio.sleep(0.2) - assert await cached.get("fubar") == _s("2") - cache.reset() - assert await cached.get("fubar") == _s("2") - - @pytest.mark.nopypy - async def test_max_size(self, client, cloner, _s): - cache = self.cache(max_keys=1, max_size_bytes=1) - cached = await cloner(client, cache=cache) - await client.set("fubar", 1) - assert _s(1) == await cached.get("fubar") - assert _s(1) == await cached.get("fubar") - - @pytest.mark.pypyonly - async def test_max_size_skipped(self, client, cloner, _s): - with pytest.raises(RuntimeError): - self.cache(max_keys=1, max_size_bytes=1) + cache = LRUCache(max_keys=1) + cached: Redis = await cloner(client, cache=cache) + async with cached: + assert not await cached.get("fubar") + await client.incr("fubar") + await sleep(0.2) + assert await cached.get("fubar") == _s("1") + await client.incr("fubar") + await sleep(0.2) + assert await cached.get("fubar") == _s("2") + cache.reset() + assert await cached.get("fubar") == _s("2") async def test_eviction(self, client, cloner, _s): - cache = self.cache(max_keys=1, max_size_bytes=-1) + cache = LRUCache(max_keys=1) cached = await cloner(client, cache=cache) - assert not await cached.get("fubar") - assert not await cached.get("barbar") - assert not await cached.get("fubar") - assert not await cached.get("barbar") - await client.set("fubar", 1) - await client.set("barbar", 2) - await asyncio.sleep(0.2) - assert await cached.get("fubar") == _s("1") - assert await cached.get("barbar") == _s("2") - await client.pexpire("fubar", 1) - await client.pexpire("barbar", 1) - await asyncio.sleep(0.2) - assert not await cached.get("fubar") - assert not await cached.get("barbar") + async with cached: + assert not await cached.get("fubar") + assert not await cached.get("barbar") + assert not await cached.get("fubar") + assert not await cached.get("barbar") + await client.set("fubar", 1) + await client.set("barbar", 2) + await sleep(0.2) + assert await cached.get("fubar") == _s("1") + assert await cached.get("barbar") == _s("2") + await client.pexpire("fubar", 1) + await client.pexpire("barbar", 1) + await sleep(0.2) + assert not await cached.get("fubar") + assert not await cached.get("barbar") @pytest.mark.parametrize( "confidence, expectation", @@ -66,251 +53,147 @@ async def test_eviction(self, client, cloner, _s): (90, 25), ], ) - async def test_confidence(self, client, cloner, mocker, _s, confidence, expectation): - cache = self.cache(confidence=confidence, max_size_bytes=-1) + async def test_confidence(self, client: Redis, cloner, mocker, _s, confidence, expectation): + cache = LRUCache(confidence=confidence) cached = await cloner(client, cache=cache) - [await client.set(f"fubar{i}", i) for i in range(100)] - create_request = mocker.spy(cached.connection_pool.connection_class, "create_request") - [await cached.get(f"fubar{i}") for i in range(100)] - assert create_request.call_count == 100 - [await cached.get(f"fubar{i}") for i in range(100)] - assert create_request.call_count < 100 + expectation + async with cached: + await cached.ping() + [await client.set(f"fubar{i}", i) for i in range(100)] + create_request = mocker.spy(cached.connection_pool.connection_class, "create_request") + [await cached.get(f"fubar{i}") for i in range(100)] + assert create_request.call_count >= 100 + [await cached.get(f"fubar{i}") for i in range(100)] + assert create_request.call_count < 100 + expectation async def test_feedback(self, client, cloner, mocker, _s): - cache = self.cache(confidence=0, max_size_bytes=-1) + cache = LRUCache(confidence=0) cached = await cloner(client, cache=cache) - [await client.set(f"fubar{i}", i) for i in range(10)] + async with cached: + [await client.set(f"fubar{i}", i) for i in range(10)] - feedback = mocker.spy(cache, "feedback") - get = mocker.patch.object(cache, "get") - get.return_value = _s("11") + feedback = mocker.spy(cache, "feedback") + get = mocker.patch.object(cache, "get") + get.return_value = _s("11") - [await cached.get(f"fubar{i}") for i in range(10)] - assert feedback.call_count == 10 + [await cached.get(f"fubar{i}") for i in range(10)] + assert feedback.call_count == 10 async def test_feedback_adjust(self, client, cloner, mocker, _s): - cache = self.cache(confidence=50, dynamic_confidence=True, max_size_bytes=-1) + cache = LRUCache(confidence=50, dynamic_confidence=True) cached = await cloner(client, cache=cache) - [await client.set(f"fubar{i}", i) for i in range(100)] - [await cached.get(f"fubar{i}") for i in range(100)] + async with cached: + [await client.set(f"fubar{i}", i) for i in range(100)] + [await cached.get(f"fubar{i}") for i in range(100)] - feedback = mocker.spy(cache, "feedback") - original_get = cache.get - get = mocker.patch.object(cache, "get") - get.side_effect = lambda *_: _s("11") + feedback = mocker.spy(cache, "feedback") + original_get = cache.get + get = mocker.patch.object(cache, "get") + get.side_effect = lambda *_: _s("11") - [await cached.get(f"fubar{i}") for i in range(100)] - assert feedback.call_count > 0 - assert cache.confidence < 50 - dropped = float(cache.confidence) - mocker.resetall() - get.side_effect = original_get + [await cached.get(f"fubar{i}") for i in range(100)] + assert feedback.call_count > 0 + assert cache.confidence < 50 + dropped = float(cache.confidence) + mocker.resetall() + get.side_effect = original_get - [await cached.get(f"fubar{i}") for i in range(100)] - assert cache.confidence > dropped - cache.reset() - assert cache.confidence == 50 + [await cached.get(f"fubar{i}") for i in range(100)] + assert cache.confidence > dropped + cache.reset() + assert cache.confidence == 50 async def test_shared_cache(self, client, cloner, mocker, _s): - cache = self.cache(max_size_bytes=-1) + cache = LRUCache() cached = await cloner(client, cache=cache) clones = [await cloner(client, cache=cache) for _ in range(5)] - [await clone.ping() for clone in clones] - await client.set("fubar", "test") - await cached.get("fubar") - spy = mocker.spy(clones[0].connection_pool.connection_class, "create_request") - assert {await clone.get("fubar") for clone in clones} == {_s("test")} - assert spy.call_count == 0, spy.call_args - - await client.set("fubar", "fubar") - await asyncio.sleep(0.1) - assert {await clone.get("fubar") for clone in clones} == {_s("fubar")} - assert spy.call_count < 5, spy.call_args + async with AsyncExitStack() as stack: + await stack.enter_async_context(cached) + for c in clones: + await stack.enter_async_context(c) + [await clone.ping() for clone in clones] + await client.set("fubar", "test") + await cached.get("fubar") + spy = mocker.spy(clones[0].connection_pool.connection_class, "create_request") + assert {await clone.get("fubar") for clone in clones} == {_s("test")} + assert spy.call_count == 0, spy.call_args + + await client.set("fubar", "fubar") + await sleep(0.1) + assert {await clone.get("fubar") for clone in clones} == {_s("fubar")} + assert spy.call_count < 5, spy.call_args async def test_stats(self, client, cloner, mocker, _s): - cache = self.cache(confidence=0, max_size_bytes=-1) + cache = LRUCache(confidence=0) cached = await cloner(client, cache=cache) - await client.set("barbar", "test") - await cached.get("fubar") - await cached.get("fubar") - await client.set("fubar", "test") - await asyncio.sleep(0.01) - await cached.get("fubar") - await cached.get("fubar") - await cached.get("barbar") - await cached.get("barbar") - - get = mocker.patch.object(cache, "get") - get.side_effect = lambda *_: _s("dirty") - - await cached.get("barbar") - - assert sum(cache.stats.hits.values()) == 3 - assert sum(cache.stats.misses.values()) == 3 - assert sum(cache.stats.invalidations.values()) == 2 - assert sum(cache.stats.dirty.values()) == 1 - - assert cache.stats.hits[b"fubar"] == 2 - assert cache.stats.hits[b"barbar"] == 1 - - cache.stats.compact() - - assert sum(cache.stats.hits.values()) == 3 - assert sum(cache.stats.misses.values()) == 3 - assert sum(cache.stats.invalidations.values()) == 2 - - assert b"fubar" not in cache.stats.hits - assert b"barbar" not in cache.stats.hits - - assert cache.stats.summary == { - "hits": 3, - "misses": 3, - "invalidations": 2, - "dirty_hits": 1, - } - - cache.stats.clear() - assert cache.stats.summary == { - "hits": 0, - "misses": 0, - "invalidations": 0, - "dirty_hits": 0, - } - - -@targets( - "redis_basic", - "redis_basic_blocking", - "redis_basic_raw", -) -class TestProxyInvalidatingCache(CommonExamples): + async with cached: + await client.set("barbar", "test") + await cached.get("fubar") + await cached.get("fubar") + await client.set("fubar", "test") + await sleep(0.01) + await cached.get("fubar") + await cached.get("fubar") + await cached.get("barbar") + await cached.get("barbar") + + get = mocker.patch.object(cache, "get") + get.side_effect = lambda *_: _s("dirty") + + await cached.get("barbar") + + assert sum(cache.stats.hits.values()) == 3 + assert sum(cache.stats.misses.values()) == 3 + assert sum(cache.stats.invalidations.values()) == 2 + assert sum(cache.stats.dirty.values()) == 1 + + cache.stats.compact() + + assert sum(cache.stats.hits.values()) == 3 + assert sum(cache.stats.misses.values()) == 3 + assert sum(cache.stats.invalidations.values()) == 2 + + assert b"fubar" not in cache.stats.hits + assert b"barbar" not in cache.stats.hits + + assert cache.stats.summary == { + "hits": 3, + "misses": 3, + "invalidations": 2, + "dirty_hits": 1, + } + + cache.stats.clear() + assert cache.stats.summary == { + "hits": 0, + "misses": 0, + "invalidations": 0, + "dirty_hits": 0, + } + + +@targets("redis_basic", "redis_basic_raw") +class TestInvalidatingCache(CommonExamples): async def test_uninitialized_cache(self, client, cloner, _s): - cache = self.cache(max_keys=1, max_idle_seconds=1, max_size_bytes=-1) - assert not cache.get_client_id(await client.connection_pool.get_connection()) + cache = LRUCache(max_keys=1) assert cache.confidence == 100 - _ = await cloner(client, cache=cache) - assert cache.get_client_id(await client.connection_pool.get_connection()) > 0 - - async def test_single_entry_cache_tracker_disconnected(self, client, cloner, _s): - cache = self.cache(max_keys=1, max_size_bytes=-1) cached = await cloner(client, cache=cache) - assert not await client.get("fubar") - await client.set("fubar", 1) - await asyncio.sleep(0.2) - assert await cached.get("fubar") == _s("1") - await client.incr("fubar") - cache.instance.connection.disconnect() - await asyncio.sleep(0.2) - assert await cached.get("fubar") == _s("2") + async with cached: + assert cached.cache.get_client_id(cached) + await sleep(0.2) # can be flaky if we close immediately @targets( "redis_cluster", "redis_cluster_raw", ) -class TestClusterProxyInvalidatingCache(CommonExamples): - async def test_uninitialized_cache(self, client, cloner, _s): - cache = self.cache(max_keys=1, max_size_bytes=-1) - assert not cache.get_client_id(await client.connection_pool.get_random_connection()) - assert cache.confidence == 100 - _ = await cloner(client, cache=cache) - assert cache.get_client_id(await client.connection_pool.get_random_connection()) > 0 - - async def test_single_entry_cache_tracker_disconnected(self, client, cloner, _s): - cache = self.cache(max_keys=1, max_size_bytes=-1) - cached = await cloner(client, cache=cache) - assert not await client.get("fubar") - await client.set("fubar", 1) - await asyncio.sleep(0.2) - assert await cached.get("fubar") == _s("1") - await client.incr("fubar") - [ncache.connection.disconnect() for ncache in cache.instance.node_caches.values()] - await asyncio.sleep(0.2) - assert await cached.get("fubar") == _s("2") - - async def test_reinitialize_cluster(self, client, cloner, _s): - await client.set("fubar", 1) - cache = self.cache(max_keys=1, max_idle_seconds=1, max_size_bytes=-1) - cached = await cloner(client, cache=cache) - pre = dict(cached.cache.instance.node_caches) - assert await cached.get("fubar") == _s("1") - cached.connection_pool.disconnect() - cached.connection_pool.reset() - await asyncio.sleep(0.1) - assert await cached.get("fubar") == _s("1") - post = cached.cache.instance.node_caches - assert pre != post - - -@targets( - "redis_basic", - "redis_basic_raw", -) -class TestNodeInvalidatingCache(CommonExamples): - @property - def cache(self): - return NodeTrackingCache - - async def test_uninitialized_cache(self, client, cloner, _s): - cache = self.cache(max_keys=1, max_idle_seconds=1, max_size_bytes=-1) - assert not cache.get_client_id(await client.connection_pool.get_connection()) - assert cache.confidence == 100 - _ = await cloner(client, cache=cache) - assert cache.get_client_id(await client.connection_pool.get_connection()) > 0 - - async def test_single_entry_cache_tracker_disconnected(self, client, cloner, _s): - cache = self.cache(max_keys=1, max_size_bytes=-1) - cached = await cloner(client, cache=cache) - assert not await client.get("fubar") - await client.set("fubar", 1) - await asyncio.sleep(0.2) - assert await cached.get("fubar") == _s("1") - await client.incr("fubar") - cache.connection.disconnect() - await asyncio.sleep(0.2) - assert await cached.get("fubar") == _s("2") - - -@targets( - "redis_cluster", - "redis_cluster_blocking", - "redis_cluster_raw", -) class TestClusterInvalidatingCache(CommonExamples): - @property - def cache(self): - return ClusterTrackingCache - async def test_uninitialized_cache(self, client, cloner, _s): - cache = self.cache(max_keys=1, max_size_bytes=-1) - assert not cache.get_client_id(await client.connection_pool.get_random_connection()) + cache = LRUCache(max_keys=1) assert cache.confidence == 100 - _ = await cloner(client, cache=cache) - assert cache.get_client_id(await client.connection_pool.get_random_connection()) > 0 - - async def test_single_entry_cache_tracker_disconnected(self, client, cloner, _s): - cache = self.cache(max_keys=1, max_size_bytes=-1) - cached = await cloner(client, cache=cache) - assert not await client.get("fubar") - await client.set("fubar", 1) - await asyncio.sleep(0.2) - assert await cached.get("fubar") == _s("1") - await client.incr("fubar") - [ncache.connection.disconnect() for ncache in cache.node_caches.values()] - await asyncio.sleep(0.2) - assert await cached.get("fubar") == _s("2") - - async def test_reinitialize_cluster(self, client, cloner, _s): - await client.set("fubar", 1) - cache = self.cache(max_keys=1, max_idle_seconds=1, max_size_bytes=-1) cached = await cloner(client, cache=cache) - pre = dict(cached.cache.node_caches) - assert await cached.get("fubar") == _s("1") - cached.connection_pool.disconnect() - cached.connection_pool.reset() - await asyncio.sleep(0.1) - assert await cached.get("fubar") == _s("1") - post = cached.cache.node_caches - assert pre != post + async with cached: + assert ( + cached.cache.get_client_id(await client.connection_pool.get_random_connection()) > 0 + ) diff --git a/uv.lock b/uv.lock index 6e8755d3b..f762f86da 100644 --- a/uv.lock +++ b/uv.lock @@ -164,17 +164,16 @@ wheels = [ [[package]] name = "anyio" -version = "4.11.0" +version = "4.12.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, { name = "idna" }, - { name = "sniffio" }, { name = "typing-extensions", marker = "python_full_version < '3.13'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c6/78/7d432127c41b50bccba979505f272c16cbcadcc33645d5fa3a738110ae75/anyio-4.11.0.tar.gz", hash = "sha256:82a8d0b81e318cc5ce71a5f1f8b5c4e63619620b63141ef8c995fa0db95a57c4", size = 219094, upload-time = "2025-09-23T09:19:12.58Z" } +sdist = { url = "https://files.pythonhosted.org/packages/16/ce/8a777047513153587e5434fd752e89334ac33e379aa3497db860eeb60377/anyio-4.12.0.tar.gz", hash = "sha256:73c693b567b0c55130c104d0b43a9baf3aa6a31fc6110116509f27bf75e21ec0", size = 228266, upload-time = "2025-11-28T23:37:38.911Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/15/b3/9b1a8074496371342ec1e796a96f99c82c945a339cd81a8e73de28b4cf9e/anyio-4.11.0-py3-none-any.whl", hash = "sha256:0287e96f4d26d4149305414d4e3bc32f0dcd0862365a4bddea19d7a1ec38c4fc", size = 109097, upload-time = "2025-09-23T09:19:10.601Z" }, + { url = "https://files.pythonhosted.org/packages/7f/9c/36c5c37947ebfb8c7f22e0eb6e4d188ee2d53aa3880f3f2744fb894f0cb1/anyio-4.12.0-py3-none-any.whl", hash = "sha256:dad2376a628f98eeca4881fc56cd06affd18f659b17a747d3ff0307ced94b1bb", size = 113362, upload-time = "2025-11-28T23:36:57.897Z" }, ] [[package]] @@ -473,11 +472,11 @@ wheels = [ name = "coredis" source = { editable = "." } dependencies = [ - { name = "async-timeout" }, + { name = "anyio" }, { name = "beartype" }, { name = "deprecated" }, + { name = "exceptiongroup" }, { name = "packaging" }, - { name = "pympler" }, { name = "typing-extensions" }, ] @@ -513,9 +512,12 @@ ci = [ { name = "pytest-rerunfailures" }, { name = "pytest-reverse" }, { name = "pytest-sentry" }, + { name = "pytest-timeout" }, { name = "redis" }, { name = "ruff" }, + { name = "trio" }, { name = "types-deprecated" }, + { name = "uvloop", marker = "platform_python_implementation != 'PyPy'" }, ] dev = [ { name = "aiobotocore" }, @@ -542,7 +544,9 @@ dev = [ { name = "pytest-reverse" }, { name = "redis" }, { name = "ruff" }, + { name = "trio" }, { name = "types-deprecated" }, + { name = "uvloop", marker = "platform_python_implementation != 'PyPy'" }, ] docs = [ { name = "aiobotocore" }, @@ -582,7 +586,12 @@ docs = [ { name = "sphinx-sitemap" }, { name = "sphinxcontrib-programoutput" }, { name = "sphinxext-opengraph" }, + { name = "trio" }, { name = "types-deprecated" }, + { name = "uvloop", marker = "platform_python_implementation != 'PyPy'" }, +] +orjson = [ + { name = "orjson" }, ] test = [ { name = "aiobotocore" }, @@ -602,17 +611,19 @@ test = [ { name = "pytest-mock" }, { name = "pytest-reverse" }, { name = "redis" }, + { name = "trio" }, + { name = "uvloop", marker = "platform_python_implementation != 'PyPy'" }, ] [package.metadata] requires-dist = [ { name = "aiobotocore", marker = "extra == 'recipes'", specifier = ">=2.15.2" }, - { name = "async-timeout", specifier = ">4,<6" }, + { name = "anyio", specifier = ">=4.11.0" }, { name = "asyncache", marker = "extra == 'recipes'", specifier = ">=0.3.1" }, { name = "beartype", specifier = ">=0.20" }, { name = "deprecated", specifier = ">=1.2" }, + { name = "exceptiongroup", specifier = ">=1.3.0" }, { name = "packaging", specifier = ">=21,<26" }, - { name = "pympler", specifier = ">1,<2" }, { name = "typing-extensions", specifier = ">=4.13" }, ] provides-extras = ["recipes"] @@ -642,10 +653,13 @@ ci = [ { name = "pytest-rerunfailures" }, { name = "pytest-reverse" }, { name = "pytest-sentry" }, + { name = "pytest-timeout" }, { name = "redis" }, { name = "redis", specifier = ">=4.2.0" }, { name = "ruff" }, + { name = "trio", specifier = ">=0.31.0" }, { name = "types-deprecated" }, + { name = "uvloop", marker = "platform_python_implementation != 'PyPy'" }, ] dev = [ { name = "aiobotocore", specifier = ">=2.15.2" }, @@ -672,7 +686,9 @@ dev = [ { name = "redis" }, { name = "redis", specifier = ">=4.2.0" }, { name = "ruff" }, + { name = "trio", specifier = ">=0.31.0" }, { name = "types-deprecated" }, + { name = "uvloop", marker = "platform_python_implementation != 'PyPy'" }, ] docs = [ { name = "aiobotocore", specifier = ">=2.15.2" }, @@ -711,8 +727,11 @@ docs = [ { name = "sphinx-sitemap", specifier = "==2.8.0" }, { name = "sphinxcontrib-programoutput", specifier = "==0.18" }, { name = "sphinxext-opengraph", specifier = "==0.13.0" }, + { name = "trio", specifier = ">=0.31.0" }, { name = "types-deprecated" }, + { name = "uvloop", marker = "platform_python_implementation != 'PyPy'" }, ] +orjson = [{ name = "orjson" }] test = [ { name = "aiobotocore", specifier = ">=2.15.2" }, { name = "asyncache", specifier = ">=0.3.1" }, @@ -730,6 +749,8 @@ test = [ { name = "pytest-mock" }, { name = "pytest-reverse" }, { name = "redis" }, + { name = "trio", specifier = ">=0.31.0" }, + { name = "uvloop", marker = "platform_python_implementation != 'PyPy'" }, ] [[package]] @@ -936,7 +957,7 @@ name = "exceptiongroup" version = "1.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "typing-extensions", marker = "python_full_version < '3.11'" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749, upload-time = "2025-05-10T17:42:51.123Z" } wheels = [ @@ -1738,6 +1759,99 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/af/11/0cc63f9f321ccf63886ac203336777140011fb669e739da36d8db3c53b98/numpy-2.3.3-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:2e267c7da5bf7309670523896df97f93f6e469fb931161f483cd6882b3b1a5dc", size = 12971844, upload-time = "2025-09-09T15:58:57.359Z" }, ] +[[package]] +name = "orjson" +version = "3.11.5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/04/b8/333fdb27840f3bf04022d21b654a35f58e15407183aeb16f3b41aa053446/orjson-3.11.5.tar.gz", hash = "sha256:82393ab47b4fe44ffd0a7659fa9cfaacc717eb617c93cde83795f14af5c2e9d5", size = 5972347, upload-time = "2025-12-06T15:55:39.458Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/79/19/b22cf9dad4db20c8737041046054cbd4f38bb5a2d0e4bb60487832ce3d76/orjson-3.11.5-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:df9eadb2a6386d5ea2bfd81309c505e125cfc9ba2b1b99a97e60985b0b3665d1", size = 245719, upload-time = "2025-12-06T15:53:43.877Z" }, + { url = "https://files.pythonhosted.org/packages/03/2e/b136dd6bf30ef5143fbe76a4c142828b55ccc618be490201e9073ad954a1/orjson-3.11.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ccc70da619744467d8f1f49a8cadae5ec7bbe054e5232d95f92ed8737f8c5870", size = 132467, upload-time = "2025-12-06T15:53:45.379Z" }, + { url = "https://files.pythonhosted.org/packages/ae/fc/ae99bfc1e1887d20a0268f0e2686eb5b13d0ea7bbe01de2b566febcd2130/orjson-3.11.5-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:073aab025294c2f6fc0807201c76fdaed86f8fc4be52c440fb78fbb759a1ac09", size = 130702, upload-time = "2025-12-06T15:53:46.659Z" }, + { url = "https://files.pythonhosted.org/packages/6e/43/ef7912144097765997170aca59249725c3ab8ef6079f93f9d708dd058df5/orjson-3.11.5-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:835f26fa24ba0bb8c53ae2a9328d1706135b74ec653ed933869b74b6909e63fd", size = 135907, upload-time = "2025-12-06T15:53:48.487Z" }, + { url = "https://files.pythonhosted.org/packages/3f/da/24d50e2d7f4092ddd4d784e37a3fa41f22ce8ed97abc9edd222901a96e74/orjson-3.11.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:667c132f1f3651c14522a119e4dd631fad98761fa960c55e8e7430bb2a1ba4ac", size = 139935, upload-time = "2025-12-06T15:53:49.88Z" }, + { url = "https://files.pythonhosted.org/packages/02/4a/b4cb6fcbfff5b95a3a019a8648255a0fac9b221fbf6b6e72be8df2361feb/orjson-3.11.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:42e8961196af655bb5e63ce6c60d25e8798cd4dfbc04f4203457fa3869322c2e", size = 137541, upload-time = "2025-12-06T15:53:51.226Z" }, + { url = "https://files.pythonhosted.org/packages/a5/99/a11bd129f18c2377c27b2846a9d9be04acec981f770d711ba0aaea563984/orjson-3.11.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75412ca06e20904c19170f8a24486c4e6c7887dea591ba18a1ab572f1300ee9f", size = 139031, upload-time = "2025-12-06T15:53:52.309Z" }, + { url = "https://files.pythonhosted.org/packages/64/29/d7b77d7911574733a036bb3e8ad7053ceb2b7d6ea42208b9dbc55b23b9ed/orjson-3.11.5-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:6af8680328c69e15324b5af3ae38abbfcf9cbec37b5346ebfd52339c3d7e8a18", size = 141622, upload-time = "2025-12-06T15:53:53.606Z" }, + { url = "https://files.pythonhosted.org/packages/93/41/332db96c1de76b2feda4f453e91c27202cd092835936ce2b70828212f726/orjson-3.11.5-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:a86fe4ff4ea523eac8f4b57fdac319faf037d3c1be12405e6a7e86b3fbc4756a", size = 413800, upload-time = "2025-12-06T15:53:54.866Z" }, + { url = "https://files.pythonhosted.org/packages/76/e1/5a0d148dd1f89ad2f9651df67835b209ab7fcb1118658cf353425d7563e9/orjson-3.11.5-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e607b49b1a106ee2086633167033afbd63f76f2999e9236f638b06b112b24ea7", size = 151198, upload-time = "2025-12-06T15:53:56.383Z" }, + { url = "https://files.pythonhosted.org/packages/0d/96/8db67430d317a01ae5cf7971914f6775affdcfe99f5bff9ef3da32492ecc/orjson-3.11.5-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:7339f41c244d0eea251637727f016b3d20050636695bc78345cce9029b189401", size = 141984, upload-time = "2025-12-06T15:53:57.746Z" }, + { url = "https://files.pythonhosted.org/packages/71/49/40d21e1aa1ac569e521069228bb29c9b5a350344ccf922a0227d93c2ed44/orjson-3.11.5-cp310-cp310-win32.whl", hash = "sha256:8be318da8413cdbbce77b8c5fac8d13f6eb0f0db41b30bb598631412619572e8", size = 135272, upload-time = "2025-12-06T15:53:59.769Z" }, + { url = "https://files.pythonhosted.org/packages/c4/7e/d0e31e78be0c100e08be64f48d2850b23bcb4d4c70d114f4e43b39f6895a/orjson-3.11.5-cp310-cp310-win_amd64.whl", hash = "sha256:b9f86d69ae822cabc2a0f6c099b43e8733dda788405cba2665595b7e8dd8d167", size = 133360, upload-time = "2025-12-06T15:54:01.25Z" }, + { url = "https://files.pythonhosted.org/packages/fd/68/6b3659daec3a81aed5ab47700adb1a577c76a5452d35b91c88efee89987f/orjson-3.11.5-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:9c8494625ad60a923af6b2b0bd74107146efe9b55099e20d7740d995f338fcd8", size = 245318, upload-time = "2025-12-06T15:54:02.355Z" }, + { url = "https://files.pythonhosted.org/packages/e9/00/92db122261425f61803ccf0830699ea5567439d966cbc35856fe711bfe6b/orjson-3.11.5-cp311-cp311-macosx_15_0_arm64.whl", hash = "sha256:7bb2ce0b82bc9fd1168a513ddae7a857994b780b2945a8c51db4ab1c4b751ebc", size = 129491, upload-time = "2025-12-06T15:54:03.877Z" }, + { url = "https://files.pythonhosted.org/packages/94/4f/ffdcb18356518809d944e1e1f77589845c278a1ebbb5a8297dfefcc4b4cb/orjson-3.11.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:67394d3becd50b954c4ecd24ac90b5051ee7c903d167459f93e77fc6f5b4c968", size = 132167, upload-time = "2025-12-06T15:54:04.944Z" }, + { url = "https://files.pythonhosted.org/packages/97/c6/0a8caff96f4503f4f7dd44e40e90f4d14acf80d3b7a97cb88747bb712d3e/orjson-3.11.5-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:298d2451f375e5f17b897794bcc3e7b821c0f32b4788b9bcae47ada24d7f3cf7", size = 130516, upload-time = "2025-12-06T15:54:06.274Z" }, + { url = "https://files.pythonhosted.org/packages/4d/63/43d4dc9bd9954bff7052f700fdb501067f6fb134a003ddcea2a0bb3854ed/orjson-3.11.5-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:aa5e4244063db8e1d87e0f54c3f7522f14b2dc937e65d5241ef0076a096409fd", size = 135695, upload-time = "2025-12-06T15:54:07.702Z" }, + { url = "https://files.pythonhosted.org/packages/87/6f/27e2e76d110919cb7fcb72b26166ee676480a701bcf8fc53ac5d0edce32f/orjson-3.11.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1db2088b490761976c1b2e956d5d4e6409f3732e9d79cfa69f876c5248d1baf9", size = 139664, upload-time = "2025-12-06T15:54:08.828Z" }, + { url = "https://files.pythonhosted.org/packages/d4/f8/5966153a5f1be49b5fbb8ca619a529fde7bc71aa0a376f2bb83fed248bcd/orjson-3.11.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c2ed66358f32c24e10ceea518e16eb3549e34f33a9d51f99ce23b0251776a1ef", size = 137289, upload-time = "2025-12-06T15:54:09.898Z" }, + { url = "https://files.pythonhosted.org/packages/a7/34/8acb12ff0299385c8bbcbb19fbe40030f23f15a6de57a9c587ebf71483fb/orjson-3.11.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c2021afda46c1ed64d74b555065dbd4c2558d510d8cec5ea6a53001b3e5e82a9", size = 138784, upload-time = "2025-12-06T15:54:11.022Z" }, + { url = "https://files.pythonhosted.org/packages/ee/27/910421ea6e34a527f73d8f4ee7bdffa48357ff79c7b8d6eb6f7b82dd1176/orjson-3.11.5-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:b42ffbed9128e547a1647a3e50bc88ab28ae9daa61713962e0d3dd35e820c125", size = 141322, upload-time = "2025-12-06T15:54:12.427Z" }, + { url = "https://files.pythonhosted.org/packages/87/a3/4b703edd1a05555d4bb1753d6ce44e1a05b7a6d7c164d5b332c795c63d70/orjson-3.11.5-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:8d5f16195bb671a5dd3d1dbea758918bada8f6cc27de72bd64adfbd748770814", size = 413612, upload-time = "2025-12-06T15:54:13.858Z" }, + { url = "https://files.pythonhosted.org/packages/1b/36/034177f11d7eeea16d3d2c42a1883b0373978e08bc9dad387f5074c786d8/orjson-3.11.5-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:c0e5d9f7a0227df2927d343a6e3859bebf9208b427c79bd31949abcc2fa32fa5", size = 150993, upload-time = "2025-12-06T15:54:15.189Z" }, + { url = "https://files.pythonhosted.org/packages/44/2f/ea8b24ee046a50a7d141c0227c4496b1180b215e728e3b640684f0ea448d/orjson-3.11.5-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:23d04c4543e78f724c4dfe656b3791b5f98e4c9253e13b2636f1af5d90e4a880", size = 141774, upload-time = "2025-12-06T15:54:16.451Z" }, + { url = "https://files.pythonhosted.org/packages/8a/12/cc440554bf8200eb23348a5744a575a342497b65261cd65ef3b28332510a/orjson-3.11.5-cp311-cp311-win32.whl", hash = "sha256:c404603df4865f8e0afe981aa3c4b62b406e6d06049564d58934860b62b7f91d", size = 135109, upload-time = "2025-12-06T15:54:17.73Z" }, + { url = "https://files.pythonhosted.org/packages/a3/83/e0c5aa06ba73a6760134b169f11fb970caa1525fa4461f94d76e692299d9/orjson-3.11.5-cp311-cp311-win_amd64.whl", hash = "sha256:9645ef655735a74da4990c24ffbd6894828fbfa117bc97c1edd98c282ecb52e1", size = 133193, upload-time = "2025-12-06T15:54:19.426Z" }, + { url = "https://files.pythonhosted.org/packages/cb/35/5b77eaebc60d735e832c5b1a20b155667645d123f09d471db0a78280fb49/orjson-3.11.5-cp311-cp311-win_arm64.whl", hash = "sha256:1cbf2735722623fcdee8e712cbaaab9e372bbcb0c7924ad711b261c2eccf4a5c", size = 126830, upload-time = "2025-12-06T15:54:20.836Z" }, + { url = "https://files.pythonhosted.org/packages/ef/a4/8052a029029b096a78955eadd68ab594ce2197e24ec50e6b6d2ab3f4e33b/orjson-3.11.5-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:334e5b4bff9ad101237c2d799d9fd45737752929753bf4faf4b207335a416b7d", size = 245347, upload-time = "2025-12-06T15:54:22.061Z" }, + { url = "https://files.pythonhosted.org/packages/64/67/574a7732bd9d9d79ac620c8790b4cfe0717a3d5a6eb2b539e6e8995e24a0/orjson-3.11.5-cp312-cp312-macosx_15_0_arm64.whl", hash = "sha256:ff770589960a86eae279f5d8aa536196ebda8273a2a07db2a54e82b93bc86626", size = 129435, upload-time = "2025-12-06T15:54:23.615Z" }, + { url = "https://files.pythonhosted.org/packages/52/8d/544e77d7a29d90cf4d9eecd0ae801c688e7f3d1adfa2ebae5e1e94d38ab9/orjson-3.11.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ed24250e55efbcb0b35bed7caaec8cedf858ab2f9f2201f17b8938c618c8ca6f", size = 132074, upload-time = "2025-12-06T15:54:24.694Z" }, + { url = "https://files.pythonhosted.org/packages/6e/57/b9f5b5b6fbff9c26f77e785baf56ae8460ef74acdb3eae4931c25b8f5ba9/orjson-3.11.5-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a66d7769e98a08a12a139049aac2f0ca3adae989817f8c43337455fbc7669b85", size = 130520, upload-time = "2025-12-06T15:54:26.185Z" }, + { url = "https://files.pythonhosted.org/packages/f6/6d/d34970bf9eb33f9ec7c979a262cad86076814859e54eb9a059a52f6dc13d/orjson-3.11.5-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:86cfc555bfd5794d24c6a1903e558b50644e5e68e6471d66502ce5cb5fdef3f9", size = 136209, upload-time = "2025-12-06T15:54:27.264Z" }, + { url = "https://files.pythonhosted.org/packages/e7/39/bc373b63cc0e117a105ea12e57280f83ae52fdee426890d57412432d63b3/orjson-3.11.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a230065027bc2a025e944f9d4714976a81e7ecfa940923283bca7bbc1f10f626", size = 139837, upload-time = "2025-12-06T15:54:28.75Z" }, + { url = "https://files.pythonhosted.org/packages/cb/aa/7c4818c8d7d324da220f4f1af55c343956003aa4d1ce1857bdc1d396ba69/orjson-3.11.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b29d36b60e606df01959c4b982729c8845c69d1963f88686608be9ced96dbfaa", size = 137307, upload-time = "2025-12-06T15:54:29.856Z" }, + { url = "https://files.pythonhosted.org/packages/46/bf/0993b5a056759ba65145effe3a79dd5a939d4a070eaa5da2ee3180fbb13f/orjson-3.11.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c74099c6b230d4261fdc3169d50efc09abf38ace1a42ea2f9994b1d79153d477", size = 139020, upload-time = "2025-12-06T15:54:31.024Z" }, + { url = "https://files.pythonhosted.org/packages/65/e8/83a6c95db3039e504eda60fc388f9faedbb4f6472f5aba7084e06552d9aa/orjson-3.11.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e697d06ad57dd0c7a737771d470eedc18e68dfdefcdd3b7de7f33dfda5b6212e", size = 141099, upload-time = "2025-12-06T15:54:32.196Z" }, + { url = "https://files.pythonhosted.org/packages/b9/b4/24fdc024abfce31c2f6812973b0a693688037ece5dc64b7a60c1ce69e2f2/orjson-3.11.5-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:e08ca8a6c851e95aaecc32bc44a5aa75d0ad26af8cdac7c77e4ed93acf3d5b69", size = 413540, upload-time = "2025-12-06T15:54:33.361Z" }, + { url = "https://files.pythonhosted.org/packages/d9/37/01c0ec95d55ed0c11e4cae3e10427e479bba40c77312b63e1f9665e0737d/orjson-3.11.5-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:e8b5f96c05fce7d0218df3fdfeb962d6b8cfff7e3e20264306b46dd8b217c0f3", size = 151530, upload-time = "2025-12-06T15:54:34.6Z" }, + { url = "https://files.pythonhosted.org/packages/f9/d4/f9ebc57182705bb4bbe63f5bbe14af43722a2533135e1d2fb7affa0c355d/orjson-3.11.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ddbfdb5099b3e6ba6d6ea818f61997bb66de14b411357d24c4612cf1ebad08ca", size = 141863, upload-time = "2025-12-06T15:54:35.801Z" }, + { url = "https://files.pythonhosted.org/packages/0d/04/02102b8d19fdcb009d72d622bb5781e8f3fae1646bf3e18c53d1bc8115b5/orjson-3.11.5-cp312-cp312-win32.whl", hash = "sha256:9172578c4eb09dbfcf1657d43198de59b6cef4054de385365060ed50c458ac98", size = 135255, upload-time = "2025-12-06T15:54:37.209Z" }, + { url = "https://files.pythonhosted.org/packages/d4/fb/f05646c43d5450492cb387de5549f6de90a71001682c17882d9f66476af5/orjson-3.11.5-cp312-cp312-win_amd64.whl", hash = "sha256:2b91126e7b470ff2e75746f6f6ee32b9ab67b7a93c8ba1d15d3a0caaf16ec875", size = 133252, upload-time = "2025-12-06T15:54:38.401Z" }, + { url = "https://files.pythonhosted.org/packages/dc/a6/7b8c0b26ba18c793533ac1cd145e131e46fcf43952aa94c109b5b913c1f0/orjson-3.11.5-cp312-cp312-win_arm64.whl", hash = "sha256:acbc5fac7e06777555b0722b8ad5f574739e99ffe99467ed63da98f97f9ca0fe", size = 126777, upload-time = "2025-12-06T15:54:39.515Z" }, + { url = "https://files.pythonhosted.org/packages/10/43/61a77040ce59f1569edf38f0b9faadc90c8cf7e9bec2e0df51d0132c6bb7/orjson-3.11.5-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:3b01799262081a4c47c035dd77c1301d40f568f77cc7ec1bb7db5d63b0a01629", size = 245271, upload-time = "2025-12-06T15:54:40.878Z" }, + { url = "https://files.pythonhosted.org/packages/55/f9/0f79be617388227866d50edd2fd320cb8fb94dc1501184bb1620981a0aba/orjson-3.11.5-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:61de247948108484779f57a9f406e4c84d636fa5a59e411e6352484985e8a7c3", size = 129422, upload-time = "2025-12-06T15:54:42.403Z" }, + { url = "https://files.pythonhosted.org/packages/77/42/f1bf1549b432d4a78bfa95735b79b5dac75b65b5bb815bba86ad406ead0a/orjson-3.11.5-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:894aea2e63d4f24a7f04a1908307c738d0dce992e9249e744b8f4e8dd9197f39", size = 132060, upload-time = "2025-12-06T15:54:43.531Z" }, + { url = "https://files.pythonhosted.org/packages/25/49/825aa6b929f1a6ed244c78acd7b22c1481fd7e5fda047dc8bf4c1a807eb6/orjson-3.11.5-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ddc21521598dbe369d83d4d40338e23d4101dad21dae0e79fa20465dbace019f", size = 130391, upload-time = "2025-12-06T15:54:45.059Z" }, + { url = "https://files.pythonhosted.org/packages/42/ec/de55391858b49e16e1aa8f0bbbb7e5997b7345d8e984a2dec3746d13065b/orjson-3.11.5-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7cce16ae2f5fb2c53c3eafdd1706cb7b6530a67cc1c17abe8ec747f5cd7c0c51", size = 135964, upload-time = "2025-12-06T15:54:46.576Z" }, + { url = "https://files.pythonhosted.org/packages/1c/40/820bc63121d2d28818556a2d0a09384a9f0262407cf9fa305e091a8048df/orjson-3.11.5-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e46c762d9f0e1cfb4ccc8515de7f349abbc95b59cb5a2bd68df5973fdef913f8", size = 139817, upload-time = "2025-12-06T15:54:48.084Z" }, + { url = "https://files.pythonhosted.org/packages/09/c7/3a445ca9a84a0d59d26365fd8898ff52bdfcdcb825bcc6519830371d2364/orjson-3.11.5-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d7345c759276b798ccd6d77a87136029e71e66a8bbf2d2755cbdde1d82e78706", size = 137336, upload-time = "2025-12-06T15:54:49.426Z" }, + { url = "https://files.pythonhosted.org/packages/9a/b3/dc0d3771f2e5d1f13368f56b339c6782f955c6a20b50465a91acb79fe961/orjson-3.11.5-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75bc2e59e6a2ac1dd28901d07115abdebc4563b5b07dd612bf64260a201b1c7f", size = 138993, upload-time = "2025-12-06T15:54:50.939Z" }, + { url = "https://files.pythonhosted.org/packages/d1/a2/65267e959de6abe23444659b6e19c888f242bf7725ff927e2292776f6b89/orjson-3.11.5-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:54aae9b654554c3b4edd61896b978568c6daa16af96fa4681c9b5babd469f863", size = 141070, upload-time = "2025-12-06T15:54:52.414Z" }, + { url = "https://files.pythonhosted.org/packages/63/c9/da44a321b288727a322c6ab17e1754195708786a04f4f9d2220a5076a649/orjson-3.11.5-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:4bdd8d164a871c4ec773f9de0f6fe8769c2d6727879c37a9666ba4183b7f8228", size = 413505, upload-time = "2025-12-06T15:54:53.67Z" }, + { url = "https://files.pythonhosted.org/packages/7f/17/68dc14fa7000eefb3d4d6d7326a190c99bb65e319f02747ef3ebf2452f12/orjson-3.11.5-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:a261fef929bcf98a60713bf5e95ad067cea16ae345d9a35034e73c3990e927d2", size = 151342, upload-time = "2025-12-06T15:54:55.113Z" }, + { url = "https://files.pythonhosted.org/packages/c4/c5/ccee774b67225bed630a57478529fc026eda33d94fe4c0eac8fe58d4aa52/orjson-3.11.5-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:c028a394c766693c5c9909dec76b24f37e6a1b91999e8d0c0d5feecbe93c3e05", size = 141823, upload-time = "2025-12-06T15:54:56.331Z" }, + { url = "https://files.pythonhosted.org/packages/67/80/5d00e4155d0cd7390ae2087130637671da713959bb558db9bac5e6f6b042/orjson-3.11.5-cp313-cp313-win32.whl", hash = "sha256:2cc79aaad1dfabe1bd2d50ee09814a1253164b3da4c00a78c458d82d04b3bdef", size = 135236, upload-time = "2025-12-06T15:54:57.507Z" }, + { url = "https://files.pythonhosted.org/packages/95/fe/792cc06a84808dbdc20ac6eab6811c53091b42f8e51ecebf14b540e9cfe4/orjson-3.11.5-cp313-cp313-win_amd64.whl", hash = "sha256:ff7877d376add4e16b274e35a3f58b7f37b362abf4aa31863dadacdd20e3a583", size = 133167, upload-time = "2025-12-06T15:54:58.71Z" }, + { url = "https://files.pythonhosted.org/packages/46/2c/d158bd8b50e3b1cfdcf406a7e463f6ffe3f0d167b99634717acdaf5e299f/orjson-3.11.5-cp313-cp313-win_arm64.whl", hash = "sha256:59ac72ea775c88b163ba8d21b0177628bd015c5dd060647bbab6e22da3aad287", size = 126712, upload-time = "2025-12-06T15:54:59.892Z" }, + { url = "https://files.pythonhosted.org/packages/c2/60/77d7b839e317ead7bb225d55bb50f7ea75f47afc489c81199befc5435b50/orjson-3.11.5-cp314-cp314-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:e446a8ea0a4c366ceafc7d97067bfd55292969143b57e3c846d87fc701e797a0", size = 245252, upload-time = "2025-12-06T15:55:01.127Z" }, + { url = "https://files.pythonhosted.org/packages/f1/aa/d4639163b400f8044cef0fb9aa51b0337be0da3a27187a20d1166e742370/orjson-3.11.5-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:53deb5addae9c22bbe3739298f5f2196afa881ea75944e7720681c7080909a81", size = 129419, upload-time = "2025-12-06T15:55:02.723Z" }, + { url = "https://files.pythonhosted.org/packages/30/94/9eabf94f2e11c671111139edf5ec410d2f21e6feee717804f7e8872d883f/orjson-3.11.5-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82cd00d49d6063d2b8791da5d4f9d20539c5951f965e45ccf4e96d33505ce68f", size = 132050, upload-time = "2025-12-06T15:55:03.918Z" }, + { url = "https://files.pythonhosted.org/packages/3d/c8/ca10f5c5322f341ea9a9f1097e140be17a88f88d1cfdd29df522970d9744/orjson-3.11.5-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3fd15f9fc8c203aeceff4fda211157fad114dde66e92e24097b3647a08f4ee9e", size = 130370, upload-time = "2025-12-06T15:55:05.173Z" }, + { url = "https://files.pythonhosted.org/packages/25/d4/e96824476d361ee2edd5c6290ceb8d7edf88d81148a6ce172fc00278ca7f/orjson-3.11.5-cp314-cp314-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9df95000fbe6777bf9820ae82ab7578e8662051bb5f83d71a28992f539d2cda7", size = 136012, upload-time = "2025-12-06T15:55:06.402Z" }, + { url = "https://files.pythonhosted.org/packages/85/8e/9bc3423308c425c588903f2d103cfcfe2539e07a25d6522900645a6f257f/orjson-3.11.5-cp314-cp314-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:92a8d676748fca47ade5bc3da7430ed7767afe51b2f8100e3cd65e151c0eaceb", size = 139809, upload-time = "2025-12-06T15:55:07.656Z" }, + { url = "https://files.pythonhosted.org/packages/e9/3c/b404e94e0b02a232b957c54643ce68d0268dacb67ac33ffdee24008c8b27/orjson-3.11.5-cp314-cp314-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:aa0f513be38b40234c77975e68805506cad5d57b3dfd8fe3baa7f4f4051e15b4", size = 137332, upload-time = "2025-12-06T15:55:08.961Z" }, + { url = "https://files.pythonhosted.org/packages/51/30/cc2d69d5ce0ad9b84811cdf4a0cd5362ac27205a921da524ff42f26d65e0/orjson-3.11.5-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fa1863e75b92891f553b7922ce4ee10ed06db061e104f2b7815de80cdcb135ad", size = 138983, upload-time = "2025-12-06T15:55:10.595Z" }, + { url = "https://files.pythonhosted.org/packages/0e/87/de3223944a3e297d4707d2fe3b1ffb71437550e165eaf0ca8bbe43ccbcb1/orjson-3.11.5-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:d4be86b58e9ea262617b8ca6251a2f0d63cc132a6da4b5fcc8e0a4128782c829", size = 141069, upload-time = "2025-12-06T15:55:11.832Z" }, + { url = "https://files.pythonhosted.org/packages/65/30/81d5087ae74be33bcae3ff2d80f5ccaa4a8fedc6d39bf65a427a95b8977f/orjson-3.11.5-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:b923c1c13fa02084eb38c9c065afd860a5cff58026813319a06949c3af5732ac", size = 413491, upload-time = "2025-12-06T15:55:13.314Z" }, + { url = "https://files.pythonhosted.org/packages/d0/6f/f6058c21e2fc1efaf918986dbc2da5cd38044f1a2d4b7b91ad17c4acf786/orjson-3.11.5-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:1b6bd351202b2cd987f35a13b5e16471cf4d952b42a73c391cc537974c43ef6d", size = 151375, upload-time = "2025-12-06T15:55:14.715Z" }, + { url = "https://files.pythonhosted.org/packages/54/92/c6921f17d45e110892899a7a563a925b2273d929959ce2ad89e2525b885b/orjson-3.11.5-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:bb150d529637d541e6af06bbe3d02f5498d628b7f98267ff87647584293ab439", size = 141850, upload-time = "2025-12-06T15:55:15.94Z" }, + { url = "https://files.pythonhosted.org/packages/88/86/cdecb0140a05e1a477b81f24739da93b25070ee01ce7f7242f44a6437594/orjson-3.11.5-cp314-cp314-win32.whl", hash = "sha256:9cc1e55c884921434a84a0c3dd2699eb9f92e7b441d7f53f3941079ec6ce7499", size = 135278, upload-time = "2025-12-06T15:55:17.202Z" }, + { url = "https://files.pythonhosted.org/packages/e4/97/b638d69b1e947d24f6109216997e38922d54dcdcdb1b11c18d7efd2d3c59/orjson-3.11.5-cp314-cp314-win_amd64.whl", hash = "sha256:a4f3cb2d874e03bc7767c8f88adaa1a9a05cecea3712649c3b58589ec7317310", size = 133170, upload-time = "2025-12-06T15:55:18.468Z" }, + { url = "https://files.pythonhosted.org/packages/8f/dd/f4fff4a6fe601b4f8f3ba3aa6da8ac33d17d124491a3b804c662a70e1636/orjson-3.11.5-cp314-cp314-win_arm64.whl", hash = "sha256:38b22f476c351f9a1c43e5b07d8b5a02eb24a6ab8e75f700f7d479d4568346a5", size = 126713, upload-time = "2025-12-06T15:55:19.738Z" }, +] + +[[package]] +name = "outcome" +version = "1.3.0.post0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/98/df/77698abfac98571e65ffeb0c1fba8ffd692ab8458d617a0eed7d9a8d38f2/outcome-1.3.0.post0.tar.gz", hash = "sha256:9dcf02e65f2971b80047b377468e72a268e15c0af3cf1238e6ff14f7f91143b8", size = 21060, upload-time = "2023-10-26T04:26:04.361Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/55/8b/5ab7257531a5d830fc8000c476e63c935488d74609b50f9384a643ec0a62/outcome-1.3.0.post0-py2.py3-none-any.whl", hash = "sha256:e771c5ce06d1415e356078d3bdd68523f284b4ce5419828922b6871e65eda82b", size = 10692, upload-time = "2023-10-26T04:26:02.532Z" }, +] + [[package]] name = "packaging" version = "25.0" @@ -1902,18 +2016,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, ] -[[package]] -name = "pympler" -version = "1.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pywin32", marker = "sys_platform == 'win32'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/dd/37/c384631908029676d8e7213dd956bb686af303a80db7afbc9be36bc49495/pympler-1.1.tar.gz", hash = "sha256:1eaa867cb8992c218430f1708fdaccda53df064144d1c5656b1e6f1ee6000424", size = 179954, upload-time = "2024-06-28T19:56:06.563Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/79/4f/a6a2e2b202d7fd97eadfe90979845b8706676b41cbd3b42ba75adf329d1f/Pympler-1.1-py3-none-any.whl", hash = "sha256:5b223d6027d0619584116a0cbc28e8d2e378f7a79c1e5e024f9ff3b673c58506", size = 165766, upload-time = "2024-06-28T19:56:05.087Z" }, -] - [[package]] name = "pytest" version = "8.4.2" @@ -2037,37 +2139,27 @@ wheels = [ ] [[package]] -name = "python-dateutil" -version = "2.9.0.post0" +name = "pytest-timeout" +version = "2.4.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "six" }, + { name = "pytest" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/66/c0/0c8b6ad9f17a802ee498c46e004a0eb49bc148f2fd230864601a86dcf6db/python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3", size = 342432, upload-time = "2024-03-01T18:36:20.211Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ac/82/4c9ecabab13363e72d880f2fb504c5f750433b2b6f16e99f4ec21ada284c/pytest_timeout-2.4.0.tar.gz", hash = "sha256:7e68e90b01f9eff71332b25001f85c75495fc4e3a836701876183c4bcfd0540a", size = 17973, upload-time = "2025-05-05T19:44:34.99Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" }, + { url = "https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl", hash = "sha256:c42667e5cdadb151aeb5b26d114aff6bdf5a907f176a007a30b940d3d865b5c2", size = 14382, upload-time = "2025-05-05T19:44:33.502Z" }, ] [[package]] -name = "pywin32" -version = "311" +name = "python-dateutil" +version = "2.9.0.post0" source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "six" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/66/c0/0c8b6ad9f17a802ee498c46e004a0eb49bc148f2fd230864601a86dcf6db/python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3", size = 342432, upload-time = "2024-03-01T18:36:20.211Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/7b/40/44efbb0dfbd33aca6a6483191dae0716070ed99e2ecb0c53683f400a0b4f/pywin32-311-cp310-cp310-win32.whl", hash = "sha256:d03ff496d2a0cd4a5893504789d4a15399133fe82517455e78bad62efbb7f0a3", size = 8760432, upload-time = "2025-07-14T20:13:05.9Z" }, - { url = "https://files.pythonhosted.org/packages/5e/bf/360243b1e953bd254a82f12653974be395ba880e7ec23e3731d9f73921cc/pywin32-311-cp310-cp310-win_amd64.whl", hash = "sha256:797c2772017851984b97180b0bebe4b620bb86328e8a884bb626156295a63b3b", size = 9590103, upload-time = "2025-07-14T20:13:07.698Z" }, - { url = "https://files.pythonhosted.org/packages/57/38/d290720e6f138086fb3d5ffe0b6caa019a791dd57866940c82e4eeaf2012/pywin32-311-cp310-cp310-win_arm64.whl", hash = "sha256:0502d1facf1fed4839a9a51ccbcc63d952cf318f78ffc00a7e78528ac27d7a2b", size = 8778557, upload-time = "2025-07-14T20:13:11.11Z" }, - { url = "https://files.pythonhosted.org/packages/7c/af/449a6a91e5d6db51420875c54f6aff7c97a86a3b13a0b4f1a5c13b988de3/pywin32-311-cp311-cp311-win32.whl", hash = "sha256:184eb5e436dea364dcd3d2316d577d625c0351bf237c4e9a5fabbcfa5a58b151", size = 8697031, upload-time = "2025-07-14T20:13:13.266Z" }, - { url = "https://files.pythonhosted.org/packages/51/8f/9bb81dd5bb77d22243d33c8397f09377056d5c687aa6d4042bea7fbf8364/pywin32-311-cp311-cp311-win_amd64.whl", hash = "sha256:3ce80b34b22b17ccbd937a6e78e7225d80c52f5ab9940fe0506a1a16f3dab503", size = 9508308, upload-time = "2025-07-14T20:13:15.147Z" }, - { url = "https://files.pythonhosted.org/packages/44/7b/9c2ab54f74a138c491aba1b1cd0795ba61f144c711daea84a88b63dc0f6c/pywin32-311-cp311-cp311-win_arm64.whl", hash = "sha256:a733f1388e1a842abb67ffa8e7aad0e70ac519e09b0f6a784e65a136ec7cefd2", size = 8703930, upload-time = "2025-07-14T20:13:16.945Z" }, - { url = "https://files.pythonhosted.org/packages/e7/ab/01ea1943d4eba0f850c3c61e78e8dd59757ff815ff3ccd0a84de5f541f42/pywin32-311-cp312-cp312-win32.whl", hash = "sha256:750ec6e621af2b948540032557b10a2d43b0cee2ae9758c54154d711cc852d31", size = 8706543, upload-time = "2025-07-14T20:13:20.765Z" }, - { url = "https://files.pythonhosted.org/packages/d1/a8/a0e8d07d4d051ec7502cd58b291ec98dcc0c3fff027caad0470b72cfcc2f/pywin32-311-cp312-cp312-win_amd64.whl", hash = "sha256:b8c095edad5c211ff31c05223658e71bf7116daa0ecf3ad85f3201ea3190d067", size = 9495040, upload-time = "2025-07-14T20:13:22.543Z" }, - { url = "https://files.pythonhosted.org/packages/ba/3a/2ae996277b4b50f17d61f0603efd8253cb2d79cc7ae159468007b586396d/pywin32-311-cp312-cp312-win_arm64.whl", hash = "sha256:e286f46a9a39c4a18b319c28f59b61de793654af2f395c102b4f819e584b5852", size = 8710102, upload-time = "2025-07-14T20:13:24.682Z" }, - { url = "https://files.pythonhosted.org/packages/a5/be/3fd5de0979fcb3994bfee0d65ed8ca9506a8a1260651b86174f6a86f52b3/pywin32-311-cp313-cp313-win32.whl", hash = "sha256:f95ba5a847cba10dd8c4d8fefa9f2a6cf283b8b88ed6178fa8a6c1ab16054d0d", size = 8705700, upload-time = "2025-07-14T20:13:26.471Z" }, - { url = "https://files.pythonhosted.org/packages/e3/28/e0a1909523c6890208295a29e05c2adb2126364e289826c0a8bc7297bd5c/pywin32-311-cp313-cp313-win_amd64.whl", hash = "sha256:718a38f7e5b058e76aee1c56ddd06908116d35147e133427e59a3983f703a20d", size = 9494700, upload-time = "2025-07-14T20:13:28.243Z" }, - { url = "https://files.pythonhosted.org/packages/04/bf/90339ac0f55726dce7d794e6d79a18a91265bdf3aa70b6b9ca52f35e022a/pywin32-311-cp313-cp313-win_arm64.whl", hash = "sha256:7b4075d959648406202d92a2310cb990fea19b535c7f4a78d3f5e10b926eeb8a", size = 8709318, upload-time = "2025-07-14T20:13:30.348Z" }, - { url = "https://files.pythonhosted.org/packages/c9/31/097f2e132c4f16d99a22bfb777e0fd88bd8e1c634304e102f313af69ace5/pywin32-311-cp314-cp314-win32.whl", hash = "sha256:b7a2c10b93f8986666d0c803ee19b5990885872a7de910fc460f9b0c2fbf92ee", size = 8840714, upload-time = "2025-07-14T20:13:32.449Z" }, - { url = "https://files.pythonhosted.org/packages/90/4b/07c77d8ba0e01349358082713400435347df8426208171ce297da32c313d/pywin32-311-cp314-cp314-win_amd64.whl", hash = "sha256:3aca44c046bd2ed8c90de9cb8427f581c479e594e99b5c0bb19b29c10fd6cb87", size = 9656800, upload-time = "2025-07-14T20:13:34.312Z" }, - { url = "https://files.pythonhosted.org/packages/c0/d2/21af5c535501a7233e734b8af901574572da66fcc254cb35d0609c9080dd/pywin32-311-cp314-cp314-win_arm64.whl", hash = "sha256:a508e2d9025764a8270f93111a970e1d0fbfc33f4153b388bb649b7eec4f9b42", size = 8932540, upload-time = "2025-07-14T20:13:36.379Z" }, + { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" }, ] [[package]] @@ -2315,6 +2407,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c8/78/3565d011c61f5a43488987ee32b6f3f656e7f107ac2782dd57bdd7d91d9a/snowballstemmer-3.0.1-py3-none-any.whl", hash = "sha256:6cd7b3897da8d6c9ffb968a6781fa6532dce9c3618a4b127d920dab764a19064", size = 103274, upload-time = "2025-05-09T16:34:50.371Z" }, ] +[[package]] +name = "sortedcontainers" +version = "2.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e8/c4/ba2f8066cceb6f23394729afe52f3bf7adec04bf9ed2c820b39e19299111/sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88", size = 30594, upload-time = "2021-05-16T22:03:42.897Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl", hash = "sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0", size = 29575, upload-time = "2021-05-16T22:03:41.177Z" }, +] + [[package]] name = "soupsieve" version = "2.8" @@ -2629,6 +2730,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/bd/75/8539d011f6be8e29f339c42e633aae3cb73bffa95dd0f9adec09b9c58e85/tomlkit-0.13.3-py3-none-any.whl", hash = "sha256:c89c649d79ee40629a9fda55f8ace8c6a1b42deb912b2a8fd8d942ddadb606b0", size = 38901, upload-time = "2025-06-05T07:13:43.546Z" }, ] +[[package]] +name = "trio" +version = "0.31.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, + { name = "cffi", marker = "implementation_name != 'pypy' and os_name == 'nt'" }, + { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, + { name = "idna" }, + { name = "outcome" }, + { name = "sniffio" }, + { name = "sortedcontainers" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/76/8f/c6e36dd11201e2a565977d8b13f0b027ba4593c1a80bed5185489178e257/trio-0.31.0.tar.gz", hash = "sha256:f71d551ccaa79d0cb73017a33ef3264fde8335728eb4c6391451fe5d253a9d5b", size = 605825, upload-time = "2025-09-09T15:17:15.242Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/31/5b/94237a3485620dbff9741df02ff6d8acaa5fdec67d81ab3f62e4d8511bf7/trio-0.31.0-py3-none-any.whl", hash = "sha256:b5d14cd6293d79298b49c3485ffd9c07e3ce03a6da8c7dfbe0cb3dd7dc9a4774", size = 512679, upload-time = "2025-09-09T15:17:13.821Z" }, +] + [[package]] name = "trove-classifiers" version = "2025.9.11.17" @@ -2729,6 +2848,50 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/85/cd/584a2ceb5532af99dd09e50919e3615ba99aa127e9850eafe5f31ddfdb9a/uvicorn-0.37.0-py3-none-any.whl", hash = "sha256:913b2b88672343739927ce381ff9e2ad62541f9f8289664fa1d1d3803fa2ce6c", size = 67976, upload-time = "2025-09-23T13:33:45.842Z" }, ] +[[package]] +name = "uvloop" +version = "0.22.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/06/f0/18d39dbd1971d6d62c4629cc7fa67f74821b0dc1f5a77af43719de7936a7/uvloop-0.22.1.tar.gz", hash = "sha256:6c84bae345b9147082b17371e3dd5d42775bddce91f885499017f4607fdaf39f", size = 2443250, upload-time = "2025-10-16T22:17:19.342Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/eb/14/ecceb239b65adaaf7fde510aa8bd534075695d1e5f8dadfa32b5723d9cfb/uvloop-0.22.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ef6f0d4cc8a9fa1f6a910230cd53545d9a14479311e87e3cb225495952eb672c", size = 1343335, upload-time = "2025-10-16T22:16:11.43Z" }, + { url = "https://files.pythonhosted.org/packages/ba/ae/6f6f9af7f590b319c94532b9567409ba11f4fa71af1148cab1bf48a07048/uvloop-0.22.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7cd375a12b71d33d46af85a3343b35d98e8116134ba404bd657b3b1d15988792", size = 742903, upload-time = "2025-10-16T22:16:12.979Z" }, + { url = "https://files.pythonhosted.org/packages/09/bd/3667151ad0702282a1f4d5d29288fce8a13c8b6858bf0978c219cd52b231/uvloop-0.22.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ac33ed96229b7790eb729702751c0e93ac5bc3bcf52ae9eccbff30da09194b86", size = 3648499, upload-time = "2025-10-16T22:16:14.451Z" }, + { url = "https://files.pythonhosted.org/packages/b3/f6/21657bb3beb5f8c57ce8be3b83f653dd7933c2fd00545ed1b092d464799a/uvloop-0.22.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:481c990a7abe2c6f4fc3d98781cc9426ebd7f03a9aaa7eb03d3bfc68ac2a46bd", size = 3700133, upload-time = "2025-10-16T22:16:16.272Z" }, + { url = "https://files.pythonhosted.org/packages/09/e0/604f61d004ded805f24974c87ddd8374ef675644f476f01f1df90e4cdf72/uvloop-0.22.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a592b043a47ad17911add5fbd087c76716d7c9ccc1d64ec9249ceafd735f03c2", size = 3512681, upload-time = "2025-10-16T22:16:18.07Z" }, + { url = "https://files.pythonhosted.org/packages/bb/ce/8491fd370b0230deb5eac69c7aae35b3be527e25a911c0acdffb922dc1cd/uvloop-0.22.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:1489cf791aa7b6e8c8be1c5a080bae3a672791fcb4e9e12249b05862a2ca9cec", size = 3615261, upload-time = "2025-10-16T22:16:19.596Z" }, + { url = "https://files.pythonhosted.org/packages/c7/d5/69900f7883235562f1f50d8184bb7dd84a2fb61e9ec63f3782546fdbd057/uvloop-0.22.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:c60ebcd36f7b240b30788554b6f0782454826a0ed765d8430652621b5de674b9", size = 1352420, upload-time = "2025-10-16T22:16:21.187Z" }, + { url = "https://files.pythonhosted.org/packages/a8/73/c4e271b3bce59724e291465cc936c37758886a4868787da0278b3b56b905/uvloop-0.22.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3b7f102bf3cb1995cfeaee9321105e8f5da76fdb104cdad8986f85461a1b7b77", size = 748677, upload-time = "2025-10-16T22:16:22.558Z" }, + { url = "https://files.pythonhosted.org/packages/86/94/9fb7fad2f824d25f8ecac0d70b94d0d48107ad5ece03769a9c543444f78a/uvloop-0.22.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:53c85520781d84a4b8b230e24a5af5b0778efdb39142b424990ff1ef7c48ba21", size = 3753819, upload-time = "2025-10-16T22:16:23.903Z" }, + { url = "https://files.pythonhosted.org/packages/74/4f/256aca690709e9b008b7108bc85fba619a2bc37c6d80743d18abad16ee09/uvloop-0.22.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:56a2d1fae65fd82197cb8c53c367310b3eabe1bbb9fb5a04d28e3e3520e4f702", size = 3804529, upload-time = "2025-10-16T22:16:25.246Z" }, + { url = "https://files.pythonhosted.org/packages/7f/74/03c05ae4737e871923d21a76fe28b6aad57f5c03b6e6bfcfa5ad616013e4/uvloop-0.22.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:40631b049d5972c6755b06d0bfe8233b1bd9a8a6392d9d1c45c10b6f9e9b2733", size = 3621267, upload-time = "2025-10-16T22:16:26.819Z" }, + { url = "https://files.pythonhosted.org/packages/75/be/f8e590fe61d18b4a92070905497aec4c0e64ae1761498cad09023f3f4b3e/uvloop-0.22.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:535cc37b3a04f6cd2c1ef65fa1d370c9a35b6695df735fcff5427323f2cd5473", size = 3723105, upload-time = "2025-10-16T22:16:28.252Z" }, + { url = "https://files.pythonhosted.org/packages/3d/ff/7f72e8170be527b4977b033239a83a68d5c881cc4775fca255c677f7ac5d/uvloop-0.22.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:fe94b4564e865d968414598eea1a6de60adba0c040ba4ed05ac1300de402cd42", size = 1359936, upload-time = "2025-10-16T22:16:29.436Z" }, + { url = "https://files.pythonhosted.org/packages/c3/c6/e5d433f88fd54d81ef4be58b2b7b0cea13c442454a1db703a1eea0db1a59/uvloop-0.22.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:51eb9bd88391483410daad430813d982010f9c9c89512321f5b60e2cddbdddd6", size = 752769, upload-time = "2025-10-16T22:16:30.493Z" }, + { url = "https://files.pythonhosted.org/packages/24/68/a6ac446820273e71aa762fa21cdcc09861edd3536ff47c5cd3b7afb10eeb/uvloop-0.22.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:700e674a166ca5778255e0e1dc4e9d79ab2acc57b9171b79e65feba7184b3370", size = 4317413, upload-time = "2025-10-16T22:16:31.644Z" }, + { url = "https://files.pythonhosted.org/packages/5f/6f/e62b4dfc7ad6518e7eff2516f680d02a0f6eb62c0c212e152ca708a0085e/uvloop-0.22.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7b5b1ac819a3f946d3b2ee07f09149578ae76066d70b44df3fa990add49a82e4", size = 4426307, upload-time = "2025-10-16T22:16:32.917Z" }, + { url = "https://files.pythonhosted.org/packages/90/60/97362554ac21e20e81bcef1150cb2a7e4ffdaf8ea1e5b2e8bf7a053caa18/uvloop-0.22.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e047cc068570bac9866237739607d1313b9253c3051ad84738cbb095be0537b2", size = 4131970, upload-time = "2025-10-16T22:16:34.015Z" }, + { url = "https://files.pythonhosted.org/packages/99/39/6b3f7d234ba3964c428a6e40006340f53ba37993f46ed6e111c6e9141d18/uvloop-0.22.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:512fec6815e2dd45161054592441ef76c830eddaad55c8aa30952e6fe1ed07c0", size = 4296343, upload-time = "2025-10-16T22:16:35.149Z" }, + { url = "https://files.pythonhosted.org/packages/89/8c/182a2a593195bfd39842ea68ebc084e20c850806117213f5a299dfc513d9/uvloop-0.22.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:561577354eb94200d75aca23fbde86ee11be36b00e52a4eaf8f50fb0c86b7705", size = 1358611, upload-time = "2025-10-16T22:16:36.833Z" }, + { url = "https://files.pythonhosted.org/packages/d2/14/e301ee96a6dc95224b6f1162cd3312f6d1217be3907b79173b06785f2fe7/uvloop-0.22.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:1cdf5192ab3e674ca26da2eada35b288d2fa49fdd0f357a19f0e7c4e7d5077c8", size = 751811, upload-time = "2025-10-16T22:16:38.275Z" }, + { url = "https://files.pythonhosted.org/packages/b7/02/654426ce265ac19e2980bfd9ea6590ca96a56f10c76e63801a2df01c0486/uvloop-0.22.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6e2ea3d6190a2968f4a14a23019d3b16870dd2190cd69c8180f7c632d21de68d", size = 4288562, upload-time = "2025-10-16T22:16:39.375Z" }, + { url = "https://files.pythonhosted.org/packages/15/c0/0be24758891ef825f2065cd5db8741aaddabe3e248ee6acc5e8a80f04005/uvloop-0.22.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0530a5fbad9c9e4ee3f2b33b148c6a64d47bbad8000ea63704fa8260f4cf728e", size = 4366890, upload-time = "2025-10-16T22:16:40.547Z" }, + { url = "https://files.pythonhosted.org/packages/d2/53/8369e5219a5855869bcee5f4d317f6da0e2c669aecf0ef7d371e3d084449/uvloop-0.22.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:bc5ef13bbc10b5335792360623cc378d52d7e62c2de64660616478c32cd0598e", size = 4119472, upload-time = "2025-10-16T22:16:41.694Z" }, + { url = "https://files.pythonhosted.org/packages/f8/ba/d69adbe699b768f6b29a5eec7b47dd610bd17a69de51b251126a801369ea/uvloop-0.22.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:1f38ec5e3f18c8a10ded09742f7fb8de0108796eb673f30ce7762ce1b8550cad", size = 4239051, upload-time = "2025-10-16T22:16:43.224Z" }, + { url = "https://files.pythonhosted.org/packages/90/cd/b62bdeaa429758aee8de8b00ac0dd26593a9de93d302bff3d21439e9791d/uvloop-0.22.1-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:3879b88423ec7e97cd4eba2a443aa26ed4e59b45e6b76aabf13fe2f27023a142", size = 1362067, upload-time = "2025-10-16T22:16:44.503Z" }, + { url = "https://files.pythonhosted.org/packages/0d/f8/a132124dfda0777e489ca86732e85e69afcd1ff7686647000050ba670689/uvloop-0.22.1-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:4baa86acedf1d62115c1dc6ad1e17134476688f08c6efd8a2ab076e815665c74", size = 752423, upload-time = "2025-10-16T22:16:45.968Z" }, + { url = "https://files.pythonhosted.org/packages/a3/94/94af78c156f88da4b3a733773ad5ba0b164393e357cc4bd0ab2e2677a7d6/uvloop-0.22.1-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:297c27d8003520596236bdb2335e6b3f649480bd09e00d1e3a99144b691d2a35", size = 4272437, upload-time = "2025-10-16T22:16:47.451Z" }, + { url = "https://files.pythonhosted.org/packages/b5/35/60249e9fd07b32c665192cec7af29e06c7cd96fa1d08b84f012a56a0b38e/uvloop-0.22.1-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c1955d5a1dd43198244d47664a5858082a3239766a839b2102a269aaff7a4e25", size = 4292101, upload-time = "2025-10-16T22:16:49.318Z" }, + { url = "https://files.pythonhosted.org/packages/02/62/67d382dfcb25d0a98ce73c11ed1a6fba5037a1a1d533dcbb7cab033a2636/uvloop-0.22.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:b31dc2fccbd42adc73bc4e7cdbae4fc5086cf378979e53ca5d0301838c5682c6", size = 4114158, upload-time = "2025-10-16T22:16:50.517Z" }, + { url = "https://files.pythonhosted.org/packages/f0/7a/f1171b4a882a5d13c8b7576f348acfe6074d72eaf52cccef752f748d4a9f/uvloop-0.22.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:93f617675b2d03af4e72a5333ef89450dfaa5321303ede6e67ba9c9d26878079", size = 4177360, upload-time = "2025-10-16T22:16:52.646Z" }, + { url = "https://files.pythonhosted.org/packages/79/7b/b01414f31546caf0919da80ad57cbfe24c56b151d12af68cee1b04922ca8/uvloop-0.22.1-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:37554f70528f60cad66945b885eb01f1bb514f132d92b6eeed1c90fd54ed6289", size = 1454790, upload-time = "2025-10-16T22:16:54.355Z" }, + { url = "https://files.pythonhosted.org/packages/d4/31/0bb232318dd838cad3fa8fb0c68c8b40e1145b32025581975e18b11fab40/uvloop-0.22.1-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:b76324e2dc033a0b2f435f33eb88ff9913c156ef78e153fb210e03c13da746b3", size = 796783, upload-time = "2025-10-16T22:16:55.906Z" }, + { url = "https://files.pythonhosted.org/packages/42/38/c9b09f3271a7a723a5de69f8e237ab8e7803183131bc57c890db0b6bb872/uvloop-0.22.1-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:badb4d8e58ee08dad957002027830d5c3b06aea446a6a3744483c2b3b745345c", size = 4647548, upload-time = "2025-10-16T22:16:57.008Z" }, + { url = "https://files.pythonhosted.org/packages/c1/37/945b4ca0ac27e3dc4952642d4c900edd030b3da6c9634875af6e13ae80e5/uvloop-0.22.1-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b91328c72635f6f9e0282e4a57da7470c7350ab1c9f48546c0f2866205349d21", size = 4467065, upload-time = "2025-10-16T22:16:58.206Z" }, + { url = "https://files.pythonhosted.org/packages/97/cc/48d232f33d60e2e2e0b42f4e73455b146b76ebe216487e862700457fbf3c/uvloop-0.22.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:daf620c2995d193449393d6c62131b3fbd40a63bf7b307a1527856ace637fe88", size = 4328384, upload-time = "2025-10-16T22:16:59.36Z" }, + { url = "https://files.pythonhosted.org/packages/e4/16/c1fd27e9549f3c4baf1dc9c20c456cd2f822dbf8de9f463824b0c0357e06/uvloop-0.22.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:6cde23eeda1a25c75b2e07d39970f3374105d5eafbaab2a4482be82f272d5a5e", size = 4296730, upload-time = "2025-10-16T22:17:00.744Z" }, +] + [[package]] name = "virtualenv" version = "20.34.0"