From 579d032fd76cf38f4cb243fe7b806546dbf67cc3 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Fri, 27 Jun 2025 14:03:39 +0300 Subject: [PATCH 01/28] Handling of topology update push notifications for Standalone Redis client. --- redis/_parsers/base.py | 88 ++++++- redis/_parsers/hiredis.py | 27 +- redis/_parsers/resp3.py | 16 +- redis/client.py | 86 +++++- redis/connection.py | 510 +++++++++++++++++++++++++++++++----- redis/maintenance_events.py | 349 ++++++++++++++++++++++++ 6 files changed, 980 insertions(+), 96 deletions(-) create mode 100644 redis/maintenance_events.py diff --git a/redis/_parsers/base.py b/redis/_parsers/base.py index 69d7b585dd..a0f6af4ac2 100644 --- a/redis/_parsers/base.py +++ b/redis/_parsers/base.py @@ -3,6 +3,12 @@ from asyncio import IncompleteReadError, StreamReader, TimeoutError from typing import Callable, List, Optional, Protocol, Union +from redis.maintenance_events import ( + NodeMigratedEvent, + NodeMigratingEvent, + NodeMovingEvent, +) + if sys.version_info.major >= 3 and sys.version_info.minor >= 11: from asyncio import timeout as async_timeout else: @@ -158,7 +164,19 @@ async def read_response( raise NotImplementedError() -_INVALIDATION_MESSAGE = [b"invalidate", "invalidate"] +_INVALIDATION_MESSAGE = (b"invalidate", "invalidate") +_MOVING_MESSAGE = (b"MOVING", "MOVING") +_MIGRATING_MESSAGE = (b"MIGRATING", "MIGRATING") +_MIGRATED_MESSAGE = (b"MIGRATED", "MIGRATED") +_FAILING_OVER_MESSAGE = (b"FAILING_OVER", "FAILING_OVER") +_FAILED_OVER_MESSAGE = (b"FAILED_OVER", "FAILED_OVER") + +_MAINTENANCE_MESSAGES = ( + *_MIGRATING_MESSAGE, + *_MIGRATED_MESSAGE, + *_FAILING_OVER_MESSAGE, + *_FAILED_OVER_MESSAGE, +) class PushNotificationsParser(Protocol): @@ -166,16 +184,41 @@ class PushNotificationsParser(Protocol): pubsub_push_handler_func: Callable invalidation_push_handler_func: Optional[Callable] = None + node_moving_push_handler_func: Optional[Callable] = None + maintenance_push_handler_func: Optional[Callable] = None def handle_pubsub_push_response(self, response): """Handle pubsub push responses""" raise NotImplementedError() def handle_push_response(self, response, **kwargs): - if response[0] not in _INVALIDATION_MESSAGE: + msg_type = response[0] + if msg_type not in ( + *_INVALIDATION_MESSAGE, + *_MAINTENANCE_MESSAGES, + *_MOVING_MESSAGE, + ): return self.pubsub_push_handler_func(response) - if self.invalidation_push_handler_func: + if msg_type in _INVALIDATION_MESSAGE and self.invalidation_push_handler_func: return self.invalidation_push_handler_func(response) + if msg_type in _MOVING_MESSAGE and self.node_moving_push_handler_func: + if msg_type in _MOVING_MESSAGE: + host, port = response[2].split(":") + ttl = response[1] + notification = NodeMovingEvent(host, port, ttl) + return self.node_moving_push_handler_func(notification) + if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func: + if msg_type in _MIGRATING_MESSAGE: + ttl = response[1] + notification = NodeMigratingEvent(ttl) + elif msg_type in _MIGRATED_MESSAGE: + notification = NodeMigratedEvent() + else: + notification = None + if notification is not None: + return self.maintenance_push_handler_func(notification) + else: + return None def set_pubsub_push_handler(self, pubsub_push_handler_func): self.pubsub_push_handler_func = pubsub_push_handler_func @@ -183,12 +226,20 @@ def set_pubsub_push_handler(self, pubsub_push_handler_func): def set_invalidation_push_handler(self, invalidation_push_handler_func): self.invalidation_push_handler_func = invalidation_push_handler_func + def set_node_moving_push_handler(self, node_moving_push_handler_func): + self.node_moving_push_handler_func = node_moving_push_handler_func + + def set_maintenance_push_handler(self, maintenance_push_handler_func): + self.maintenance_push_handler_func = maintenance_push_handler_func + class AsyncPushNotificationsParser(Protocol): """Protocol defining async RESP3-specific parsing functionality""" pubsub_push_handler_func: Callable invalidation_push_handler_func: Optional[Callable] = None + node_moving_push_handler_func: Optional[Callable] = None + maintenance_push_handler_func: Optional[Callable] = None async def handle_pubsub_push_response(self, response): """Handle pubsub push responses asynchronously""" @@ -196,10 +247,31 @@ async def handle_pubsub_push_response(self, response): async def handle_push_response(self, response, **kwargs): """Handle push responses asynchronously""" - if response[0] not in _INVALIDATION_MESSAGE: + msg_type = response[0] + if msg_type not in ( + *_INVALIDATION_MESSAGE, + *_MAINTENANCE_MESSAGES, + *_MOVING_MESSAGE, + ): return await self.pubsub_push_handler_func(response) - if self.invalidation_push_handler_func: + if msg_type in _INVALIDATION_MESSAGE and self.invalidation_push_handler_func: return await self.invalidation_push_handler_func(response) + if msg_type in _MOVING_MESSAGE and self.node_moving_push_handler_func: + # push notification from enterprise cluster for node moving + host, port = response[2].split(":") + ttl = response[1] + id = 1 # TODO: get unique id from push notification + notification = NodeMovingEvent(id, host, port, ttl) + return await self.node_moving_push_handler_func(notification) + if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func: + if msg_type in _MIGRATING_MESSAGE: + ttl = response[1] + id = 1 # TODO: get unique id from push notification + notification = NodeMigratingEvent(id, ttl) + elif msg_type in _MIGRATED_MESSAGE: + id = 1 # TODO: get unique id from push notification + notification = NodeMigratedEvent(id) + return await self.maintenance_push_handler_func(notification) def set_pubsub_push_handler(self, pubsub_push_handler_func): """Set the pubsub push handler function""" @@ -209,6 +281,12 @@ def set_invalidation_push_handler(self, invalidation_push_handler_func): """Set the invalidation push handler function""" self.invalidation_push_handler_func = invalidation_push_handler_func + def set_node_moving_push_handler_func(self, node_moving_push_handler_func): + self.node_moving_push_handler_func = node_moving_push_handler_func + + def set_maintenance_push_handler(self, maintenance_push_handler_func): + self.maintenance_push_handler_func = maintenance_push_handler_func + class _AsyncRESPBase(AsyncBaseParser): """Base class for async resp parsing""" diff --git a/redis/_parsers/hiredis.py b/redis/_parsers/hiredis.py index 521a58b26c..e9df314a8c 100644 --- a/redis/_parsers/hiredis.py +++ b/redis/_parsers/hiredis.py @@ -47,6 +47,8 @@ def __init__(self, socket_read_size): self.socket_read_size = socket_read_size self._buffer = bytearray(socket_read_size) self.pubsub_push_handler_func = self.handle_pubsub_push_response + self.node_moving_push_handler_func = None + self.maintenance_push_handler_func = None self.invalidation_push_handler_func = None self._hiredis_PushNotificationType = None @@ -141,13 +143,15 @@ def read_response(self, disable_decoding=False, push_request=False): response, self._hiredis_PushNotificationType ): response = self.handle_push_response(response) - if not push_request: - return self.read_response( - disable_decoding=disable_decoding, push_request=push_request - ) - else: + + # if this is a push request return the push response + if push_request: return response - return response + + return self.read_response( + disable_decoding=disable_decoding, + push_request=push_request, + ) if disable_decoding: response = self._reader.gets(False) @@ -169,12 +173,13 @@ def read_response(self, disable_decoding=False, push_request=False): response, self._hiredis_PushNotificationType ): response = self.handle_push_response(response) - if not push_request: - return self.read_response( - disable_decoding=disable_decoding, push_request=push_request - ) - else: + if push_request: return response + return self.read_response( + disable_decoding=disable_decoding, + push_request=push_request, + ) + elif ( isinstance(response, list) and response diff --git a/redis/_parsers/resp3.py b/redis/_parsers/resp3.py index 42c6652e31..72957b464c 100644 --- a/redis/_parsers/resp3.py +++ b/redis/_parsers/resp3.py @@ -18,6 +18,8 @@ class _RESP3Parser(_RESPBase, PushNotificationsParser): def __init__(self, socket_read_size): super().__init__(socket_read_size) self.pubsub_push_handler_func = self.handle_pubsub_push_response + self.node_moving_push_handler_func = None + self.maintenance_push_handler_func = None self.invalidation_push_handler_func = None def handle_pubsub_push_response(self, response): @@ -117,17 +119,21 @@ def _read_response(self, disable_decoding=False, push_request=False): for _ in range(int(response)) ] response = self.handle_push_response(response) - if not push_request: - return self._read_response( - disable_decoding=disable_decoding, push_request=push_request - ) - else: + + # if this is a push request return the push response + if push_request: return response + + return self._read_response( + disable_decoding=disable_decoding, + push_request=push_request, + ) else: raise InvalidResponse(f"Protocol Error: {raw!r}") if isinstance(response, bytes) and disable_decoding is False: response = self.encoder.decode(response) + return response diff --git a/redis/client.py b/redis/client.py index 0e05b6f542..0ec36c52d9 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1,4 +1,5 @@ import copy +import logging import re import threading import time @@ -56,6 +57,10 @@ WatchError, ) from redis.lock import Lock +from redis.maintenance_events import ( + MaintenanceEventPoolHandler, + MaintenanceEventsConfig, +) from redis.retry import Retry from redis.utils import ( _set_info_logger, @@ -244,6 +249,7 @@ def __init__( cache: Optional[CacheInterface] = None, cache_config: Optional[CacheConfig] = None, event_dispatcher: Optional[EventDispatcher] = None, + maintenance_events_config: Optional[MaintenanceEventsConfig] = None, ) -> None: """ Initialize a new Redis client. @@ -368,6 +374,23 @@ def __init__( ]: raise RedisError("Client caching is only supported with RESP version 3") + if maintenance_events_config and self.connection_pool.get_protocol() not in [ + 3, + "3", + ]: + raise RedisError( + "Push handlers on connection are only supported with RESP version 3" + ) + if maintenance_events_config and maintenance_events_config.enabled: + self.maintenance_events_pool_handler = MaintenanceEventPoolHandler( + self.connection_pool, maintenance_events_config + ) + self.connection_pool.set_maintenance_events_pool_handler( + self.maintenance_events_pool_handler + ) + else: + self.maintenance_events_pool_handler = None + self.single_connection_lock = threading.RLock() self.connection = None self._single_connection_client = single_connection_client @@ -565,8 +588,15 @@ def monitor(self): return Monitor(self.connection_pool) def client(self): + maintenance_events_config = ( + None + if self.maintenance_events_pool_handler is None + else self.maintenance_events_pool_handler.config + ) return self.__class__( - connection_pool=self.connection_pool, single_connection_client=True + connection_pool=self.connection_pool, + single_connection_client=True, + maintenance_events_config=maintenance_events_config, ) def __enter__(self): @@ -635,7 +665,14 @@ def _execute_command(self, *args, **options): ), lambda _: self._close_connection(conn), ) + finally: + if conn and conn.should_reconnect(): + logging.debug( + f"***** Redis reconnect before exit _execute_command --> notification for {conn._sock.getpeername()}" + ) + self._close_connection(conn) + conn.connect() if self._single_connection_client: self.single_connection_lock.release() if not self.connection: @@ -686,11 +723,7 @@ def __init__(self, connection_pool): self.connection = self.connection_pool.get_connection() def __enter__(self): - self.connection.send_command("MONITOR") - # check that monitor returns 'OK', but don't return it to user - response = self.connection.read_response() - if not bool_ok(response): - raise RedisError(f"MONITOR failed: {response}") + self._start_monitor() return self def __exit__(self, *args): @@ -700,8 +733,13 @@ def __exit__(self, *args): def next_command(self): """Parse the response from a monitor command""" response = self.connection.read_response() + + if response is None: + return None + if isinstance(response, bytes): response = self.connection.encoder.decode(response, force=True) + command_time, command_data = response.split(" ", 1) m = self.monitor_re.match(command_data) db_id, client_info, command = m.groups() @@ -737,6 +775,14 @@ def listen(self): while True: yield self.next_command() + def _start_monitor(self): + self.connection.send_command("MONITOR") + # check that monitor returns 'OK', but don't return it to user + response = self.connection.read_response() + + if not bool_ok(response): + raise RedisError(f"MONITOR failed: {response}") + class PubSub: """ @@ -881,7 +927,7 @@ def clean_health_check_responses(self) -> None: """ ttl = 10 conn = self.connection - while self.health_check_response_counter > 0 and ttl > 0: + while conn and self.health_check_response_counter > 0 and ttl > 0: if self._execute(conn, conn.can_read, timeout=conn.socket_timeout): response = self._execute(conn, conn.read_response) if self.is_health_check_response(response): @@ -911,10 +957,18 @@ def _execute(self, conn, command, *args, **kwargs): called by the # connection to resubscribe us to any channels and patterns we were previously listening to """ - return conn.retry.call_with_retry( + + response = conn.retry.call_with_retry( lambda: command(*args, **kwargs), lambda _: self._reconnect(conn), ) + if conn.should_reconnect(): + logging.debug( + f"***** PubSub --> Reconnect on notification for {conn._sock.getpeername()}" + ) + self._reconnect(conn) + + return response def parse_response(self, block=True, timeout=0): """Parse the response from a publish/subscribe command""" @@ -1148,6 +1202,7 @@ def handle_message(self, response, ignore_subscribe_messages=False): return None if isinstance(response, bytes): response = [b"pong", response] if response != b"PONG" else [b"pong", b""] + message_type = str_if_bytes(response[0]) if message_type == "pmessage": message = { @@ -1351,6 +1406,7 @@ def reset(self) -> None: # clean up the other instance attributes self.watching = False self.explicit_transaction = False + # we can safely return the connection to the pool here since we're # sure we're no longer WATCHing anything if self.connection: @@ -1510,6 +1566,7 @@ def _execute_transaction( if command_name in self.response_callbacks: r = self.response_callbacks[command_name](r, **options) data.append(r) + return data def _execute_pipeline(self, connection, commands, raise_on_error): @@ -1517,16 +1574,17 @@ def _execute_pipeline(self, connection, commands, raise_on_error): all_cmds = connection.pack_commands([args for args, _ in commands]) connection.send_packed_command(all_cmds) - response = [] + responses = [] for args, options in commands: try: - response.append(self.parse_response(connection, args[0], **options)) + responses.append(self.parse_response(connection, args[0], **options)) except ResponseError as e: - response.append(e) + responses.append(e) if raise_on_error: - self.raise_first_error(commands, response) - return response + self.raise_first_error(commands, responses) + + return responses def raise_first_error(self, commands, response): for i, r in enumerate(response): @@ -1611,6 +1669,8 @@ def execute(self, raise_on_error: bool = True) -> List[Any]: lambda error: self._disconnect_raise_on_watching(conn, error), ) finally: + # in reset() the connection is diconnected before returned to the pool if + # it is marked for reconnect. self.reset() def discard(self): diff --git a/redis/connection.py b/redis/connection.py index 47cb589569..f55e1b455c 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1,4 +1,5 @@ import copy +import logging import os import socket import sys @@ -19,10 +20,11 @@ CacheInterface, CacheKey, ) +from redis.typing import Number from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser from .auth.token import TokenInterface -from .backoff import NoBackoff +from .backoff import ExponentialWithJitterBackoff from .credentials import CredentialProvider, UsernamePasswordCredentialProvider from .event import AfterConnectionReleasedEvent, EventDispatcher from .exceptions import ( @@ -36,6 +38,11 @@ ResponseError, TimeoutError, ) +from .maintenance_events import ( + MaintenanceEventConnectionHandler, + MaintenanceEventPoolHandler, + MaintenanceEventsConfig, +) from .retry import Retry from .utils import ( CRYPTOGRAPHY_AVAILABLE, @@ -159,6 +166,10 @@ def deregister_connect_callback(self, callback): def set_parser(self, parser_class): pass + @abstractmethod + def set_maintenance_event_pool_handler(self, maintenance_event_pool_handler): + pass + @abstractmethod def get_protocol(self): pass @@ -222,6 +233,26 @@ def set_re_auth_token(self, token: TokenInterface): def re_auth(self): pass + @abstractmethod + def mark_for_reconnect(self): + pass + + @abstractmethod + def should_reconnect(self): + pass + + @abstractmethod + def update_current_socket_timeout(self, relax_timeout: Optional[float] = None): + pass + + @abstractmethod + def update_tmp_settings( + self, + tmp_host_address: Optional[str] = None, + tmp_relax_timeout: Optional[float] = None, + ): + pass + class AbstractConnection(ConnectionInterface): "Manages communication to and from a Redis server" @@ -250,6 +281,10 @@ def __init__( protocol: Optional[int] = 2, command_packer: Optional[Callable[[], None]] = None, event_dispatcher: Optional[EventDispatcher] = None, + maintenance_events_pool_handler: Optional[MaintenanceEventPoolHandler] = None, + maintenance_events_config: Optional[MaintenanceEventsConfig] = None, + tmp_host_address: Optional[str] = None, + tmp_relax_timeout: Optional[float] = -1, ): """ Initialize a new Connection. @@ -288,16 +323,15 @@ def __init__( # Add TimeoutError to the errors list to retry on retry_on_error.append(TimeoutError) self.retry_on_error = retry_on_error - if retry or retry_on_error: - if retry is None: - self.retry = Retry(NoBackoff(), 1) - else: - # deep-copy the Retry object as it is mutable - self.retry = copy.deepcopy(retry) - # Update the retry's supported errors with the specified errors - self.retry.update_supported_errors(retry_on_error) + if retry is None: + self.retry = Retry( + backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3 + ) else: - self.retry = Retry(NoBackoff(), 0) + # deep-copy the Retry object as it is mutable + self.retry = copy.deepcopy(retry) + if retry_on_error: + self.retry.update_supported_errors(retry_on_error) self.health_check_interval = health_check_interval self.next_health_check = 0 self.redis_connect_func = redis_connect_func @@ -305,7 +339,6 @@ def __init__( self.handshake_metadata = None self._sock = None self._socket_read_size = socket_read_size - self.set_parser(parser_class) self._connect_callbacks = [] self._buffer_cutoff = 6000 self._re_auth_token: Optional[TokenInterface] = None @@ -320,7 +353,26 @@ def __init__( raise ConnectionError("protocol must be either 2 or 3") # p = DEFAULT_RESP_VERSION self.protocol = p + if self.protocol == 3 and parser_class == DefaultParser: + parser_class = _RESP3Parser + self.set_parser(parser_class) + + if maintenance_events_config and maintenance_events_config.enabled: + if maintenance_events_pool_handler: + self._parser.set_node_moving_push_handler( + maintenance_events_pool_handler.handle_event + ) + self._maintenance_event_connection_handler = ( + MaintenanceEventConnectionHandler(self, maintenance_events_config) + ) + self._parser.set_maintenance_push_handler( + self._maintenance_event_connection_handler.handle_event + ) + self._command_packer = self._construct_command_packer(command_packer) + self._should_reconnect = False + self.tmp_host_address = tmp_host_address + self.tmp_relax_timeout = tmp_relax_timeout def __repr__(self): repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()]) @@ -375,6 +427,11 @@ def set_parser(self, parser_class): """ self._parser = parser_class(socket_read_size=self._socket_read_size) + def set_maintenance_event_pool_handler( + self, maintenance_event_pool_handler: MaintenanceEventPoolHandler + ): + self._parser.set_node_moving_push_handler(maintenance_event_pool_handler) + def connect(self): "Connects to the Redis server if not already connected" self.connect_check_health(check_health=True) @@ -549,6 +606,8 @@ def disconnect(self, *args): conn_sock = self._sock self._sock = None + # reset the reconnect flag + self._should_reconnect = False if conn_sock is None: return @@ -626,6 +685,7 @@ def can_read(self, timeout=0): try: return self._parser.can_read(timeout) + except OSError as e: self.disconnect() raise ConnectionError(f"Error while reading from {host_error}: {e.args}") @@ -732,6 +792,35 @@ def re_auth(self): self.read_response() self._re_auth_token = None + def mark_for_reconnect(self): + self._should_reconnect = True + + def should_reconnect(self): + return self._should_reconnect + + def update_current_socket_timeout(self, relax_timeout: Optional[float] = None): + if self._sock: + timeout = relax_timeout if relax_timeout != -1 else self.socket_timeout + logging.debug( + f"***** Connection --> Updating timeout for {self._sock.getpeername()}" + f" to timeout {timeout}; relax_timeout: {relax_timeout}" + ) + self._sock.settimeout(timeout) + self._parser._buffer.socket_timeout = timeout + + def update_tmp_settings( + self, + tmp_host_address: Optional[str | object] = SENTINEL, + tmp_relax_timeout: Optional[float | object] = SENTINEL, + ): + """ + The value of SENTINEL is used to indicate that the property should not be updated. + """ + if tmp_host_address is not SENTINEL: + self.tmp_host_address = tmp_host_address + if tmp_relax_timeout is not SENTINEL: + self.tmp_relax_timeout = tmp_relax_timeout + class Connection(AbstractConnection): "Manages TCP communication to and from a Redis server" @@ -764,8 +853,14 @@ def _connect(self): # ipv4/ipv6, but we want to set options prior to calling # socket.connect() err = None + if self.tmp_host_address is not None: + logging.debug( + f"***** Connection --> Using tmp_host_address: {self.tmp_host_address}" + ) + host = self.tmp_host_address or self.host + for res in socket.getaddrinfo( - self.host, self.port, self.socket_type, socket.SOCK_STREAM + host, self.port, self.socket_type, socket.SOCK_STREAM ): family, socktype, proto, canonname, socket_address = res sock = None @@ -781,13 +876,32 @@ def _connect(self): sock.setsockopt(socket.IPPROTO_TCP, k, v) # set the socket_connect_timeout before we connect - sock.settimeout(self.socket_connect_timeout) + if self.tmp_relax_timeout != -1: + logging.debug( + f"***** Connection connect --> Using relax_timeout: {self.tmp_relax_timeout}" + ) + sock.settimeout(self.tmp_relax_timeout) + else: + logging.debug( + f"***** Connection connect --> Using default socket_connect_timeout: {self.socket_connect_timeout}" + ) + sock.settimeout(self.socket_connect_timeout) # connect sock.connect(socket_address) # set the socket_timeout now that we're connected - sock.settimeout(self.socket_timeout) + if self.tmp_relax_timeout != -1: + logging.debug( + f"***** Connection --> Using relax_timeout: {self.tmp_relax_timeout}" + ) + sock.settimeout(self.tmp_relax_timeout) + else: + logging.debug( + f"***** Connection --> Using default socket_timeout: {self.socket_timeout}" + ) + sock.settimeout(self.socket_timeout) + logging.debug(f"Connected to {sock.getpeername()}") return sock except OSError as _: @@ -1415,6 +1529,14 @@ def __init__( connection_kwargs.pop("cache", None) connection_kwargs.pop("cache_config", None) + if connection_kwargs.get( + "maintenance_events_pool_handler" + ) or connection_kwargs.get("maintenance_events_config"): + if connection_kwargs.get("protocol") not in [3, "3"]: + raise RedisError( + "Push handlers on connection are only supported with RESP version 3" + ) + self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None) if self._event_dispatcher is None: self._event_dispatcher = EventDispatcher() @@ -1449,6 +1571,46 @@ def get_protocol(self): """ return self.connection_kwargs.get("protocol", None) + def maintenance_events_pool_handler_enabled(self): + """ + Returns: + True if the maintenance events pool handler is enabled, False otherwise. + """ + maintenance_events_config = self.connection_kwargs.get( + "maintenance_events_config", False + ) + + return maintenance_events_config and maintenance_events_config.enabled + + def set_maintenance_events_pool_handler( + self, maintenance_events_pool_handler: MaintenanceEventPoolHandler + ): + self.connection_kwargs.update( + { + "maintenance_events_pool_handler": maintenance_events_pool_handler, + "maintenance_events_config": maintenance_events_pool_handler.config, + } + ) + + self._update_maintenance_events_configs_for_connections( + maintenance_events_pool_handler + ) + + def _update_maintenance_events_configs_for_connections( + self, maintenance_events_pool_handler + ): + with self._lock: + for conn in self._available_connections: + conn.set_maintenance_events_pool_handler( + maintenance_events_pool_handler + ) + conn.maintenance_events_config = maintenance_events_pool_handler.config + for conn in self._in_use_connections: + conn.set_maintenance_events_pool_handler( + maintenance_events_pool_handler + ) + conn.maintenance_events_config = maintenance_events_pool_handler.config + def reset(self) -> None: self._created_connections = 0 self._available_connections = [] @@ -1536,7 +1698,11 @@ def get_connection(self, command_name=None, *keys, **options) -> "Connection": # pool before all data has been read or the socket has been # closed. either way, reconnect and verify everything is good. try: - if connection.can_read() and self.cache is None: + if ( + connection.can_read() + and self.cache is None + and not self.maintenance_events_pool_handler_enabled() + ): raise ConnectionError("Connection has data") except (ConnectionError, TimeoutError, OSError): connection.disconnect() @@ -1548,7 +1714,6 @@ def get_connection(self, command_name=None, *keys, **options) -> "Connection": # leak it self.release(connection) raise - return connection def get_encoder(self) -> Encoder: @@ -1570,7 +1735,6 @@ def make_connection(self) -> "ConnectionInterface": return CacheProxyConnection( self.connection_class(**self.connection_kwargs), self.cache, self._lock ) - return self.connection_class(**self.connection_kwargs) def release(self, connection: "Connection") -> None: @@ -1585,6 +1749,11 @@ def release(self, connection: "Connection") -> None: return if self.owns_connection(connection): + if connection.should_reconnect(): + logging.debug( + f"***** Pool--> disconnecting in release {connection._sock.getpeername()}" + ) + connection.disconnect() self._available_connections.append(connection) self._event_dispatcher.dispatch( AfterConnectionReleasedEvent(connection) @@ -1646,6 +1815,154 @@ def re_auth_callback(self, token: TokenInterface): for conn in self._in_use_connections: conn.set_re_auth_token(token) + def update_connection_kwargs_with_tmp_settings( + self, + tmp_host_address: Optional[str] = None, + tmp_relax_timeout: Optional[float] = None, + ): + """ + Update the connection kwargs with the temporary host address and the + relax timeout(if enabled). + This is used when a cluster node is rebind to a different address. + + When this method is called the pool will already be locked, so getting the pool lock inside is not needed. + This new address will be used to create new connections until the old node is decomissioned. + + :param tmp_host_address: The temporary host address to use for the connection. + :param tmp_relax_timeout: The relax timeout to use for the connection. + If -1 is provided - the relax timeout is disabled, so the tmp property is not set + """ + self.connection_kwargs.update({"tmp_host_address": tmp_host_address}) + self.connection_kwargs.update({"tmp_relax_timeout": tmp_relax_timeout}) + + def update_connections_tmp_settings( + self, + tmp_host_address: Optional[str] = None, + tmp_relax_timeout: Optional[float] = None, + ): + """ + Update the tmp settings for all connections in the pool. + This is used when a cluster node is rebind to a different address. + + When this method is called the pool will already be locked, so getting the pool lock inside is not needed. + + :param tmp_host_address: The temporary host address to use for the connection. + :param tmp_relax_timeout: The relax timeout to use for the connection. + """ + with self._lock: + for conn in self._available_connections: + self._update_connection_tmp_settings( + conn, tmp_host_address, tmp_relax_timeout + ) + for conn in self._in_use_connections: + self._update_connection_tmp_settings( + conn, tmp_host_address, tmp_relax_timeout + ) + + def update_active_connections_for_reconnect( + self, + tmp_host_address: Optional[str] = None, + tmp_relax_timeout: Optional[float] = None, + ): + """ + Mark all active connections for reconnect. + This is used when a cluster node is migrated to a different address. + + When this method is called the pool will already be locked, so getting the pool lock inside is not needed. + + :param tmp_host_address: The temporary host address to use for the connection. + """ + for conn in self._in_use_connections: + self._update_connection_for_reconnect( + conn, tmp_host_address, tmp_relax_timeout + ) + + def disconnect_and_reconfigure_free_connections( + self, + tmp_host_address: Optional[str] = None, + tmp_relax_timeout: Optional[float] = None, + ): + """ + Disconnect all free/available connections. + This is used when a cluster node is migrated to a different address. + + When this method is called the pool will already be locked, so getting the pool lock inside is not needed. + + :param tmp_host_address: The temporary host address to use for the connection. + :param tmp_relax_timeout: The relax timeout to use for the connection. + """ + + for conn in self._available_connections: + self._disconnect_and_update_connection_for_reconnect( + conn, tmp_host_address, tmp_relax_timeout + ) + + def update_connections_current_timeout( + self, + relax_timeout: Optional[float], + include_available_connections: bool = False, + ): + """ + Update the timeout either for all connections in the pool or just for the ones in use. + This is used when a cluster node is migrated to a different address. + + When this method is called the pool will already be locked, so getting the pool lock inside is not needed. + + :param relax_timeout: The relax timeout to use for the connection. + If -1 is provided - the relax timeout is disabled. + :param include_available_connections: Whether to include available connections in the update. + """ + logging.debug(f"***** Pool --> Updating timeouts. New value: {relax_timeout}") + start_time = time.time() + + for conn in self._in_use_connections: + self._update_connection_timeout(conn, relax_timeout) + + if include_available_connections: + for conn in self._available_connections: + self._update_connection_timeout(conn, relax_timeout) + + execution_time_us = (time.time() - start_time) * 1000000 + logging.error( + f"###### TIMEOUTS execution time: {execution_time_us:.0f} microseconds" + ) + + def _update_connection_for_reconnect( + self, + connection: "Connection", + tmp_host_address: Optional[str] = None, + tmp_relax_timeout: Optional[float] = None, + ): + connection.mark_for_reconnect() + self._update_connection_tmp_settings( + connection, tmp_host_address, tmp_relax_timeout + ) + + def _disconnect_and_update_connection_for_reconnect( + self, + connection: "Connection", + tmp_host_address: Optional[str] = None, + tmp_relax_timeout: Optional[float] = None, + ): + connection.disconnect() + self._update_connection_tmp_settings( + connection, tmp_host_address, tmp_relax_timeout + ) + + def _update_connection_tmp_settings( + self, + connection: "Connection", + tmp_host_address: Optional[str] = None, + tmp_relax_timeout: Optional[float] = None, + ): + connection.tmp_host_address = tmp_host_address + connection.tmp_relax_timeout = tmp_relax_timeout + + def _update_connection_timeout( + self, connection: "Connection", relax_timeout: Optional[Number] + ): + connection.update_current_socket_timeout(relax_timeout) + async def _mock(self, error: RedisError): """ Dummy functions, needs to be passed as error callback to retry object. @@ -1707,16 +2024,17 @@ def __init__( def reset(self): # Create and fill up a thread safe queue with ``None`` values. - self.pool = self.queue_class(self.max_connections) - while True: - try: - self.pool.put_nowait(None) - except Full: - break + with self._lock: + self.pool = self.queue_class(self.max_connections) + while True: + try: + self.pool.put_nowait(None) + except Full: + break - # Keep a list of actual connection instances so that we can - # disconnect them later. - self._connections = [] + # Keep a list of actual connection instances so that we can + # disconnect them later. + self._connections = [] # this must be the last operation in this method. while reset() is # called when holding _fork_lock, other threads in this process @@ -1731,14 +2049,18 @@ def reset(self): def make_connection(self): "Make a fresh connection." - if self.cache is not None: - connection = CacheProxyConnection( - self.connection_class(**self.connection_kwargs), self.cache, self._lock - ) - else: - connection = self.connection_class(**self.connection_kwargs) - self._connections.append(connection) - return connection + with self._lock: + if self.cache is not None: + connection = CacheProxyConnection( + self.connection_class(**self.connection_kwargs), + self.cache, + self._lock, + ) + else: + connection = self.connection_class(**self.connection_kwargs) + + self._connections.append(connection) + return connection @deprecated_args( args_to_warn=["*"], @@ -1763,17 +2085,18 @@ def get_connection(self, command_name=None, *keys, **options): # Try and get a connection from the pool. If one isn't available within # self.timeout then raise a ``ConnectionError``. connection = None - try: - connection = self.pool.get(block=True, timeout=self.timeout) - except Empty: - # Note that this is not caught by the redis client and will be - # raised unless handled by application code. If you want never to - raise ConnectionError("No connection available.") - - # If the ``connection`` is actually ``None`` then that's a cue to make - # a new connection to add to the pool. - if connection is None: - connection = self.make_connection() + with self._lock: + try: + connection = self.pool.get(block=True, timeout=self.timeout) + except Empty: + # Note that this is not caught by the redis client and will be + # raised unless handled by application code. If you want never to + raise ConnectionError("No connection available.") + + # If the ``connection`` is actually ``None`` then that's a cue to make + # a new connection to add to the pool. + if connection is None: + connection = self.make_connection() try: # ensure this connection is connected to Redis @@ -1801,25 +2124,88 @@ def release(self, connection): "Releases the connection back to the pool." # Make sure we haven't changed process. self._checkpid() - if not self.owns_connection(connection): - # pool doesn't own this connection. do not add it back - # to the pool. instead add a None value which is a placeholder - # that will cause the pool to recreate the connection if - # its needed. - connection.disconnect() - self.pool.put_nowait(None) - return - # Put the connection back into the pool. - try: - self.pool.put_nowait(connection) - except Full: - # perhaps the pool has been reset() after a fork? regardless, - # we don't want this connection - pass + with self._lock: + if not self.owns_connection(connection): + # pool doesn't own this connection. do not add it back + # to the pool. instead add a None value which is a placeholder + # that will cause the pool to recreate the connection if + # its needed. + connection.disconnect() + self.pool.put_nowait(None) + return + if connection.should_reconnect(): + logging.debug( + f"***** Blocking Pool--> disconnecting in release {connection._sock.getpeername()}" + ) + connection.disconnect() + # Put the connection back into the pool. + try: + self.pool.put_nowait(connection) + except Full: + # perhaps the pool has been reset() after a fork? regardless, + # we don't want this connection + pass def disconnect(self): "Disconnects all connections in the pool." self._checkpid() - for connection in self._connections: - connection.disconnect() + with self._lock: + for connection in self._connections: + connection.disconnect() + + def update_active_connections_for_reconnect( + self, + tmp_host_address: Optional[str] = None, + tmp_relax_timeout: Optional[float] = None, + ): + with self._lock: + connections_in_queue = {conn for conn in self.pool.queue if conn} + for conn in self._connections: + if conn not in connections_in_queue: + if tmp_relax_timeout != -1: + conn.update_socket_timeout(tmp_relax_timeout) + self._update_connection_for_reconnect( + conn, tmp_host_address, tmp_relax_timeout + ) + + def disconnect_and_reconfigure_free_connections( + self, + tmp_host_address: Optional[str] = None, + tmp_relax_timeout: Optional[Number] = None, + ): + with self._lock: + existing_connections = self.pool.queue + + for conn in existing_connections: + if conn: + self._disconnect_and_update_connection_for_reconnect( + conn, tmp_host_address, tmp_relax_timeout + ) + + def update_connections_current_timeout(self, relax_timeout: Optional[float] = None): + logging.debug( + f"***** Blocking Pool --> Updating timeouts. relax_timeout: {relax_timeout}" + ) + + with self._lock: + for conn in tuple(self._connections): + self._update_connection_timeout(conn, relax_timeout) + + def update_connections_tmp_settings( + self, + tmp_host_address: Optional[str] = None, + tmp_relax_timeout: Optional[float] = None, + ): + with self._lock: + for conn in tuple(self._connections): + self._update_connection_tmp_settings( + conn, tmp_host_address, tmp_relax_timeout + ) + + def _update_maintenance_events_config_for_connections( + self, maintenance_events_config + ): + with self._lock: + for conn in tuple(self._connections): + conn.maintenance_events_config = maintenance_events_config diff --git a/redis/maintenance_events.py b/redis/maintenance_events.py new file mode 100644 index 0000000000..bbc519d0cc --- /dev/null +++ b/redis/maintenance_events.py @@ -0,0 +1,349 @@ +import logging +import threading +import time +from typing import TYPE_CHECKING + +from redis.typing import Number + +if TYPE_CHECKING: + from redis.connection import ConnectionInterface, ConnectionPool + + +class MaintenanceEvent: + """ + Base class for maintenance events sent through push messages by Redis server. + + This class provides common TTL (Time-To-Live) functionality for all + maintenance events. + + Attributes: + ttl (int): Time-to-live in seconds for this notification + creation_time (float): Timestamp when the notification was created/read + """ + + def __init__(self, ttl: int): + """ + Initialize a new MaintenanceEvent with TTL functionality. + + Args: + ttl (int): Time-to-live in seconds for this notification + """ + self.ttl = ttl + self.creation_time = int(time.time()) + self.expire_at = self.creation_time + self.ttl + + def is_expired(self) -> bool: + """ + Check if this event has expired based on its TTL + and creation time. + + Returns: + bool: True if the event has expired, False otherwise + """ + return int(time.time()) > (self.creation_time + self.ttl) + + +class NodeMovingEvent(MaintenanceEvent): + """ + This event is received when a node is replaced with a new node + during cluster rebalancing or maintenance operations. + """ + + def __init__(self, new_node_host: str, new_node_port: int, ttl: int): + """ + Initialize a new NodeMovingEvent. + + Args: + new_node_host (str): Hostname or IP address of the new replacement node + new_node_port (int): Port number of the new replacement node + ttl (int): Time-to-live in seconds for this notification + """ + super().__init__(ttl) + self.new_node_host = new_node_host + self.new_node_port = new_node_port + + def __repr__(self) -> str: + expiry_time = self.expire_at + remaining = max(0, expiry_time - time.time()) + + return ( + f"{self.__class__.__name__}(" + f"new_node_host='{self.new_node_host}', " + f"new_node_port={self.new_node_port}, " + f"ttl={self.ttl}, " + f"creation_time={self.creation_time}, " + f"expires_at={expiry_time}, " + f"remaining={remaining:.1f}s, " + f"expired={self.is_expired()}" + f")" + ) + + def __eq__(self, other) -> bool: + """ + Two NodeMovingEvent events are considered equal if they have the same + new_node_host and new_node_port. + """ + if not isinstance(other, NodeMovingEvent): + return False + return ( + self.new_node_host == other.new_node_host + and self.new_node_port == other.new_node_port + ) + + def __hash__(self) -> int: + """ + Return a hash value for the event to allow + instances to be used in sets and as dictionary keys. + + Returns: + int: Hash value based on new_node_host and new_node_port + """ + return hash((self.__class__, self.new_node_host, self.new_node_port)) + + +class NodeMigratingEvent(MaintenanceEvent): + """ + Event for when a Redis cluster node is in the process of migrating slots. + + This event is received when a node starts migrating its slots to another node + during cluster rebalancing or maintenance operations. + + Args: + ttl (int): Time-to-live in seconds for this notification + """ + + def __init__(self, ttl: int): + super().__init__(ttl) + + def __repr__(self) -> str: + expiry_time = self.creation_time + self.ttl + remaining = max(0, expiry_time - time.time()) + return ( + f"{self.__class__.__name__}(" + f"ttl={self.ttl}, " + f"creation_time={self.creation_time}, " + f"expires_at={expiry_time}, " + f"remaining={remaining:.1f}s, " + f"expired={self.is_expired()}" + f")" + ) + + +class NodeMigratedEvent(MaintenanceEvent): + """ + Event for when a Redis cluster node has completed migrating slots. + + This event is received when a node has finished migrating all its slots + to other nodes during cluster rebalancing or maintenance operations. + + Args: + ttl (int): Time-to-live in seconds for this notification + """ + + DEFAULT_TTL = 5 + + def __init__(self): + super().__init__(NodeMigratedEvent.DEFAULT_TTL) + + def __repr__(self) -> str: + expiry_time = self.creation_time + self.ttl + remaining = max(0, expiry_time - time.time()) + return ( + f"{self.__class__.__name__}(" + f"ttl={self.ttl}, " + f"creation_time={self.creation_time}, " + f"expires_at={expiry_time}, " + f"remaining={remaining:.1f}s, " + f"expired={self.is_expired()}" + f")" + ) + + +class MaintenanceEventsConfig: + """ + Configuration class for maintenance events handling behaviour. Events are received through + push notifications. + + This class defines how the Redis client should react to different push notifications + such as node moving, migrations, etc. in a Redis cluster. + + """ + + def __init__( + self, + enabled: bool = False, + proactive_reconnect: bool = True, + relax_timeout: Number = 20, + ): + """ + Initialize a new MaintenanceEventsConfig. + + Args: + enabled (bool): Whether to enable maintenance events handling. + Defaults to False. + proactive_reconnect (bool): Whether to proactively reconnect when a node is replaced. + Defaults to True. + relax_timeout (Number): The relax timeout to use for the connection during maintenance. + If -1 is provided - the relax timeout is disabled. Defaults to 20. + + """ + self.enabled = enabled + self.relax_timeout = relax_timeout + self.proactive_reconnect = proactive_reconnect + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"enabled={self.enabled}, " + f"proactive_reconnect={self.proactive_reconnect}, " + f"relax_timeout={self.relax_timeout}, " + f")" + ) + + def is_relax_timeouts_enabled(self) -> bool: + """ + Check if the relax_timeout is enabled. The '-1' value is used to disable the relax_timeout. + If relax_timeout is set to None, it will make the operation blocking + and waiting until any response is received. + + Returns: + True if the relax_timeout is enabled, False otherwise. + """ + return self.relax_timeout != -1 + + +class MaintenanceEventPoolHandler: + def __init__(self, pool: "ConnectionPool", config: MaintenanceEventsConfig) -> None: + self.pool = pool + self.config = config + self._processed_events = set() + self._lock = threading.RLock() + + def remove_expired_notifications(self): + with self._lock: + for notification in tuple(self._processed_events): + if notification.is_expired(): + self._processed_events.remove(notification) + + def handle_event(self, notification: MaintenanceEvent): + self.remove_expired_notifications() + + if isinstance(notification, NodeMovingEvent): + return self.handle_node_moving_event(notification) + else: + logging.error(f"Unhandled notification type: {notification}") + + def handle_node_moved_event(self): + with self._lock: + self.pool.update_connection_kwargs_with_tmp_settings( + tmp_host_address=None, + tmp_relax_timeout=-1, + ) + with self.pool._lock: + if self.config.is_relax_timeouts_enabled(): + # reset the timeout for existing connections + self.pool.update_connections_current_timeout( + relax_timeout=-1, include_available_connections=True + ) + logging.debug("***** MOVING END--> TIMEOUTS RESET") + + self.pool.update_connections_tmp_settings( + tmp_host_address=None, tmp_relax_timeout=-1 + ) + logging.debug("***** MOVING END--> TMP SETTINGS ADDRESS RESET") + + def handle_node_moving_event(self, event: NodeMovingEvent): + if ( + not self.config.proactive_reconnect + and not self.config.is_relax_timeouts_enabled() + ): + return + with self._lock: + if event in self._processed_events: + # nothing to do in the connection pool handling + # the event has already been handled or is expired + # just return + logging.debug("***** MOVING --> SKIPPED DONE") + return + + logging.info(f"***** MOVING --> {event}") + logging.info(f"***** MOVING --> set: {self._processed_events}") + start_time = time.time() + + with self.pool._lock: + if ( + self.config.proactive_reconnect + or self.config.is_relax_timeouts_enabled() + ): + # edit the config for new connections until the notification expires + self.pool.update_connection_kwargs_with_tmp_settings( + tmp_host_address=event.new_node_host, + tmp_relax_timeout=self.config.relax_timeout, + ) + if self.config.is_relax_timeouts_enabled(): + # extend the timeout for all connections that are currently in use + self.pool.update_connections_current_timeout( + self.config.relax_timeout + ) + if self.config.proactive_reconnect: + # take care for the active connections in the pool + # mark them for reconnect after they complete the current command + self.pool.update_active_connections_for_reconnect( + tmp_host_address=event.new_node_host, + tmp_relax_timeout=self.config.relax_timeout, + ) + + # take care for the inactive connections in the pool + # delete them and create new ones + start_time_2 = time.time() + self.pool.disconnect_and_reconfigure_free_connections( + tmp_host_address=event.new_node_host, + tmp_relax_timeout=self.config.relax_timeout, + ) + execution_time_us = (time.time() - start_time_2) * 1000000 + logging.error( + f"###### MOVING disconnects execution time: {execution_time_us:.0f} microseconds" + ) + + threading.Timer(event.ttl, self.handle_node_moved_event).start() + + self._processed_events.add(event) + execution_time_us = (time.time() - start_time) * 1000000 + logging.error( + f"###### MOVING total execution time: {execution_time_us:.0f} microseconds" + ) + + +class MaintenanceEventConnectionHandler: + def __init__( + self, connection: "ConnectionInterface", config: MaintenanceEventsConfig + ) -> None: + self.connection = connection + self.config = config + + def handle_event(self, event: MaintenanceEvent): + if isinstance(event, NodeMigratingEvent): + return self.handle_migrating_event(event) + elif isinstance(event, NodeMigratedEvent): + return self.handle_migration_completed_event(event) + else: + logging.error(f"Unhandled event type: {event}") + + def handle_migrating_event(self, notification: NodeMigratingEvent): + if not self.config.is_relax_timeouts_enabled(): + return + + logging.info(f"***** MIGRATING --> {notification}") + # extend the timeout for all created connections + self.connection.update_current_socket_timeout(self.config.relax_timeout) + self.connection.update_tmp_settings(tmp_relax_timeout=self.config.relax_timeout) + + def handle_migration_completed_event(self, notification: "NodeMigratedEvent"): + if not self.config.is_relax_timeouts_enabled(): + return + + logging.info(f"***** MIGRATED --> {notification}") + # Node migration completed - reset the connection + # timeouts by providing -1 as the relax timeout + self.connection.update_current_socket_timeout(-1) + self.connection.update_tmp_settings(tmp_relax_timeout=-1) From 8d27a8685e6bbbac1ceb1088c8a7211c19f2d8bc Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Fri, 11 Jul 2025 11:07:50 +0300 Subject: [PATCH 02/28] Adding sequence id to the maintenance push notifications. Adding unit tests for maintenance_events.py file --- redis/_parsers/base.py | 15 +- redis/maintenance_events.py | 136 ++++++-- tests/test_maintenance_events.py | 543 +++++++++++++++++++++++++++++++ 3 files changed, 665 insertions(+), 29 deletions(-) create mode 100644 tests/test_maintenance_events.py diff --git a/redis/_parsers/base.py b/redis/_parsers/base.py index a0f6af4ac2..aa5a6b0f12 100644 --- a/redis/_parsers/base.py +++ b/redis/_parsers/base.py @@ -205,14 +205,17 @@ def handle_push_response(self, response, **kwargs): if msg_type in _MOVING_MESSAGE: host, port = response[2].split(":") ttl = response[1] - notification = NodeMovingEvent(host, port, ttl) + id = 1 # Hardcoded value for sync parser + notification = NodeMovingEvent(id, host, port, ttl) return self.node_moving_push_handler_func(notification) if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func: if msg_type in _MIGRATING_MESSAGE: ttl = response[1] - notification = NodeMigratingEvent(ttl) + id = 2 # Hardcoded value for sync parser + notification = NodeMigratingEvent(id, ttl) elif msg_type in _MIGRATED_MESSAGE: - notification = NodeMigratedEvent() + id = 3 # Hardcoded value for sync parser + notification = NodeMigratedEvent(id) else: notification = None if notification is not None: @@ -260,16 +263,16 @@ async def handle_push_response(self, response, **kwargs): # push notification from enterprise cluster for node moving host, port = response[2].split(":") ttl = response[1] - id = 1 # TODO: get unique id from push notification + id = 1 # Hardcoded value for async parser notification = NodeMovingEvent(id, host, port, ttl) return await self.node_moving_push_handler_func(notification) if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func: if msg_type in _MIGRATING_MESSAGE: ttl = response[1] - id = 1 # TODO: get unique id from push notification + id = 2 # Hardcoded value for async parser notification = NodeMigratingEvent(id, ttl) elif msg_type in _MIGRATED_MESSAGE: - id = 1 # TODO: get unique id from push notification + id = 3 # Hardcoded value for async parser notification = NodeMigratedEvent(id) return await self.maintenance_push_handler_func(notification) diff --git a/redis/maintenance_events.py b/redis/maintenance_events.py index bbc519d0cc..d818a846b8 100644 --- a/redis/maintenance_events.py +++ b/redis/maintenance_events.py @@ -1,7 +1,8 @@ import logging import threading import time -from typing import TYPE_CHECKING +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Optional from redis.typing import Number @@ -9,27 +10,30 @@ from redis.connection import ConnectionInterface, ConnectionPool -class MaintenanceEvent: +class MaintenanceEvent(ABC): """ Base class for maintenance events sent through push messages by Redis server. - This class provides common TTL (Time-To-Live) functionality for all - maintenance events. + This class provides common functionality for all maintenance events including + unique identification and TTL (Time-To-Live) functionality. Attributes: + id (int): Unique identifier for this event ttl (int): Time-to-live in seconds for this notification creation_time (float): Timestamp when the notification was created/read """ - def __init__(self, ttl: int): + def __init__(self, id: int, ttl: int): """ - Initialize a new MaintenanceEvent with TTL functionality. + Initialize a new MaintenanceEvent with unique ID and TTL functionality. Args: + id (int): Unique identifier for this event ttl (int): Time-to-live in seconds for this notification """ + self.id = id self.ttl = ttl - self.creation_time = int(time.time()) + self.creation_time = time.monotonic() self.expire_at = self.creation_time + self.ttl def is_expired(self) -> bool: @@ -40,7 +44,49 @@ def is_expired(self) -> bool: Returns: bool: True if the event has expired, False otherwise """ - return int(time.time()) > (self.creation_time + self.ttl) + return time.monotonic() > (self.creation_time + self.ttl) + + @abstractmethod + def __repr__(self) -> str: + """ + Return a string representation of the maintenance event. + + This method must be implemented by all concrete subclasses. + + Returns: + str: String representation of the event + """ + pass + + @abstractmethod + def __eq__(self, other) -> bool: + """ + Compare two maintenance events for equality. + + This method must be implemented by all concrete subclasses. + Events are typically considered equal if they have the same id + and are of the same type. + + Args: + other: The other object to compare with + + Returns: + bool: True if the events are equal, False otherwise + """ + pass + + @abstractmethod + def __hash__(self) -> int: + """ + Return a hash value for the maintenance event. + + This method must be implemented by all concrete subclasses to allow + instances to be used in sets and as dictionary keys. + + Returns: + int: Hash value for the event + """ + pass class NodeMovingEvent(MaintenanceEvent): @@ -49,25 +95,27 @@ class NodeMovingEvent(MaintenanceEvent): during cluster rebalancing or maintenance operations. """ - def __init__(self, new_node_host: str, new_node_port: int, ttl: int): + def __init__(self, id: int, new_node_host: str, new_node_port: int, ttl: int): """ Initialize a new NodeMovingEvent. Args: + id (int): Unique identifier for this event new_node_host (str): Hostname or IP address of the new replacement node new_node_port (int): Port number of the new replacement node ttl (int): Time-to-live in seconds for this notification """ - super().__init__(ttl) + super().__init__(id, ttl) self.new_node_host = new_node_host self.new_node_port = new_node_port def __repr__(self) -> str: expiry_time = self.expire_at - remaining = max(0, expiry_time - time.time()) + remaining = max(0, expiry_time - time.monotonic()) return ( f"{self.__class__.__name__}(" + f"id={self.id}, " f"new_node_host='{self.new_node_host}', " f"new_node_port={self.new_node_port}, " f"ttl={self.ttl}, " @@ -81,12 +129,13 @@ def __repr__(self) -> str: def __eq__(self, other) -> bool: """ Two NodeMovingEvent events are considered equal if they have the same - new_node_host and new_node_port. + id, new_node_host, and new_node_port. """ if not isinstance(other, NodeMovingEvent): return False return ( - self.new_node_host == other.new_node_host + self.id == other.id + and self.new_node_host == other.new_node_host and self.new_node_port == other.new_node_port ) @@ -96,9 +145,9 @@ def __hash__(self) -> int: instances to be used in sets and as dictionary keys. Returns: - int: Hash value based on new_node_host and new_node_port + int: Hash value based on event type, id, new_node_host, and new_node_port """ - return hash((self.__class__, self.new_node_host, self.new_node_port)) + return hash((self.__class__, self.id, self.new_node_host, self.new_node_port)) class NodeMigratingEvent(MaintenanceEvent): @@ -109,17 +158,19 @@ class NodeMigratingEvent(MaintenanceEvent): during cluster rebalancing or maintenance operations. Args: + id (int): Unique identifier for this event ttl (int): Time-to-live in seconds for this notification """ - def __init__(self, ttl: int): - super().__init__(ttl) + def __init__(self, id: int, ttl: int): + super().__init__(id, ttl) def __repr__(self) -> str: expiry_time = self.creation_time + self.ttl - remaining = max(0, expiry_time - time.time()) + remaining = max(0, expiry_time - time.monotonic()) return ( f"{self.__class__.__name__}(" + f"id={self.id}, " f"ttl={self.ttl}, " f"creation_time={self.creation_time}, " f"expires_at={expiry_time}, " @@ -128,6 +179,25 @@ def __repr__(self) -> str: f")" ) + def __eq__(self, other) -> bool: + """ + Two NodeMigratingEvent events are considered equal if they have the same + id and are of the same type. + """ + if not isinstance(other, NodeMigratingEvent): + return False + return self.id == other.id and type(self) is type(other) + + def __hash__(self) -> int: + """ + Return a hash value for the event to allow + instances to be used in sets and as dictionary keys. + + Returns: + int: Hash value based on event type and id + """ + return hash((self.__class__, self.id)) + class NodeMigratedEvent(MaintenanceEvent): """ @@ -137,19 +207,20 @@ class NodeMigratedEvent(MaintenanceEvent): to other nodes during cluster rebalancing or maintenance operations. Args: - ttl (int): Time-to-live in seconds for this notification + id (int): Unique identifier for this event """ DEFAULT_TTL = 5 - def __init__(self): - super().__init__(NodeMigratedEvent.DEFAULT_TTL) + def __init__(self, id: int): + super().__init__(id, NodeMigratedEvent.DEFAULT_TTL) def __repr__(self) -> str: expiry_time = self.creation_time + self.ttl - remaining = max(0, expiry_time - time.time()) + remaining = max(0, expiry_time - time.monotonic()) return ( f"{self.__class__.__name__}(" + f"id={self.id}, " f"ttl={self.ttl}, " f"creation_time={self.creation_time}, " f"expires_at={expiry_time}, " @@ -158,6 +229,25 @@ def __repr__(self) -> str: f")" ) + def __eq__(self, other) -> bool: + """ + Two NodeMigratedEvent events are considered equal if they have the same + id and are of the same type. + """ + if not isinstance(other, NodeMigratedEvent): + return False + return self.id == other.id and type(self) is type(other) + + def __hash__(self) -> int: + """ + Return a hash value for the event to allow + instances to be used in sets and as dictionary keys. + + Returns: + int: Hash value based on event type and id + """ + return hash((self.__class__, self.id)) + class MaintenanceEventsConfig: """ @@ -173,7 +263,7 @@ def __init__( self, enabled: bool = False, proactive_reconnect: bool = True, - relax_timeout: Number = 20, + relax_timeout: Optional[Number] = 20, ): """ Initialize a new MaintenanceEventsConfig. diff --git a/tests/test_maintenance_events.py b/tests/test_maintenance_events.py new file mode 100644 index 0000000000..69a6014fe1 --- /dev/null +++ b/tests/test_maintenance_events.py @@ -0,0 +1,543 @@ +import threading +from unittest.mock import Mock, patch + +from redis.maintenance_events import ( + MaintenanceEvent, + NodeMovingEvent, + NodeMigratingEvent, + NodeMigratedEvent, + MaintenanceEventsConfig, + MaintenanceEventPoolHandler, + MaintenanceEventConnectionHandler, +) + + +class TestMaintenanceEvent: + """Test the base MaintenanceEvent class functionality through concrete subclasses.""" + + def test_abstract_class_cannot_be_instantiated(self): + """Test that MaintenanceEvent cannot be instantiated directly.""" + import pytest + + with patch("time.monotonic", return_value=1000): + with pytest.raises(TypeError): + MaintenanceEvent(id=1, ttl=10) # type: ignore + + def test_init_through_subclass(self): + """Test MaintenanceEvent initialization through concrete subclass.""" + with patch("time.monotonic", return_value=1000): + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + assert event.id == 1 + assert event.ttl == 10 + assert event.creation_time == 1000 + assert event.expire_at == 1010 + + def test_is_expired_false(self): + """Test is_expired returns False for non-expired event.""" + with patch("time.monotonic", return_value=1000): + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + with patch("time.monotonic", return_value=1005): # 5 seconds later + assert not event.is_expired() + + def test_is_expired_true(self): + """Test is_expired returns True for expired event.""" + with patch("time.monotonic", return_value=1000): + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + with patch("time.monotonic", return_value=1015): # 15 seconds later + assert event.is_expired() + + def test_is_expired_exact_boundary(self): + """Test is_expired at exact expiration boundary.""" + with patch("time.monotonic", return_value=1000): + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + with patch("time.monotonic", return_value=1010): # Exactly at expiration + assert not event.is_expired() + + with patch("time.monotonic", return_value=1011): # 1 second past expiration + assert event.is_expired() + + +class TestNodeMovingEvent: + """Test the NodeMovingEvent class.""" + + def test_init(self): + """Test NodeMovingEvent initialization.""" + with patch("time.monotonic", return_value=1000): + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + assert event.id == 1 + assert event.new_node_host == "localhost" + assert event.new_node_port == 6379 + assert event.ttl == 10 + assert event.creation_time == 1000 + + def test_repr(self): + """Test NodeMovingEvent string representation.""" + with patch("time.monotonic", return_value=1000): + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + with patch("time.monotonic", return_value=1005): # 5 seconds later + repr_str = repr(event) + assert "NodeMovingEvent" in repr_str + assert "id=1" in repr_str + assert "new_node_host='localhost'" in repr_str + assert "new_node_port=6379" in repr_str + assert "ttl=10" in repr_str + assert "remaining=5.0s" in repr_str + assert "expired=False" in repr_str + + def test_equality_same_id_host_port(self): + """Test equality for events with same id, host, and port.""" + event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=20 + ) # Different TTL + assert event1 == event2 + + def test_equality_same_id_different_host(self): + """Test inequality for events with same id but different host.""" + event1 = NodeMovingEvent( + id=1, new_node_host="host1", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=1, new_node_host="host2", new_node_port=6379, ttl=10 + ) + assert event1 != event2 + + def test_equality_same_id_different_port(self): + """Test inequality for events with same id but different port.""" + event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6380, ttl=10 + ) + assert event1 != event2 + + def test_equality_different_id(self): + """Test inequality for events with different id.""" + event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=2, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + assert event1 != event2 + + def test_equality_different_type(self): + """Test inequality for events of different types.""" + event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMigratingEvent(id=1, ttl=10) + assert event1 != event2 + + def test_hash_same_id_host_port(self): + """Test hash consistency for events with same id, host, and port.""" + event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=20 + ) # Different TTL + assert hash(event1) == hash(event2) + + def test_hash_different_host(self): + """Test hash difference for events with different host.""" + event1 = NodeMovingEvent( + id=1, new_node_host="host1", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=1, new_node_host="host2", new_node_port=6379, ttl=10 + ) + assert hash(event1) != hash(event2) + + def test_hash_different_port(self): + """Test hash difference for events with different port.""" + event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6380, ttl=10 + ) + assert hash(event1) != hash(event2) + + def test_hash_different_id(self): + """Test hash difference for events with different id.""" + event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=2, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + assert hash(event1) != hash(event2) + + def test_set_functionality(self): + """Test that events can be used in sets correctly.""" + event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=20 + ) # Same id, host, port - should be considered the same + event3 = NodeMovingEvent( + id=1, new_node_host="host2", new_node_port=6380, ttl=10 + ) # Same id but different host/port - should be different + event4 = NodeMovingEvent( + id=2, new_node_host="localhost", new_node_port=6379, ttl=10 + ) # Different id - should be different + + event_set = {event1, event2, event3, event4} + assert len(event_set) == 3 # event1 and event2 should be considered the same + + +class TestNodeMigratingEvent: + """Test the NodeMigratingEvent class.""" + + def test_init(self): + """Test NodeMigratingEvent initialization.""" + with patch("time.monotonic", return_value=1000): + event = NodeMigratingEvent(id=1, ttl=5) + assert event.id == 1 + assert event.ttl == 5 + assert event.creation_time == 1000 + + def test_repr(self): + """Test NodeMigratingEvent string representation.""" + with patch("time.monotonic", return_value=1000): + event = NodeMigratingEvent(id=1, ttl=5) + + with patch("time.monotonic", return_value=1002): # 2 seconds later + repr_str = repr(event) + assert "NodeMigratingEvent" in repr_str + assert "id=1" in repr_str + assert "ttl=5" in repr_str + assert "remaining=3.0s" in repr_str + assert "expired=False" in repr_str + + def test_equality_and_hash(self): + """Test equality and hash for NodeMigratingEvent.""" + event1 = NodeMigratingEvent(id=1, ttl=5) + event2 = NodeMigratingEvent(id=1, ttl=10) # Same id, different ttl + event3 = NodeMigratingEvent(id=2, ttl=5) # Different id + + assert event1 == event2 + assert event1 != event3 + assert hash(event1) == hash(event2) + assert hash(event1) != hash(event3) + + +class TestNodeMigratedEvent: + """Test the NodeMigratedEvent class.""" + + def test_init(self): + """Test NodeMigratedEvent initialization.""" + with patch("time.monotonic", return_value=1000): + event = NodeMigratedEvent(id=1) + assert event.id == 1 + assert event.ttl == NodeMigratedEvent.DEFAULT_TTL + assert event.creation_time == 1000 + + def test_default_ttl(self): + """Test that DEFAULT_TTL is used correctly.""" + assert NodeMigratedEvent.DEFAULT_TTL == 5 + event = NodeMigratedEvent(id=1) + assert event.ttl == 5 + + def test_repr(self): + """Test NodeMigratedEvent string representation.""" + with patch("time.monotonic", return_value=1000): + event = NodeMigratedEvent(id=1) + + with patch("time.monotonic", return_value=1001): # 1 second later + repr_str = repr(event) + assert "NodeMigratedEvent" in repr_str + assert "id=1" in repr_str + assert "ttl=5" in repr_str + assert "remaining=4.0s" in repr_str + assert "expired=False" in repr_str + + def test_equality_and_hash(self): + """Test equality and hash for NodeMigratedEvent.""" + event1 = NodeMigratedEvent(id=1) + event2 = NodeMigratedEvent(id=1) # Same id + event3 = NodeMigratedEvent(id=2) # Different id + + assert event1 == event2 + assert event1 != event3 + assert hash(event1) == hash(event2) + assert hash(event1) != hash(event3) + + +class TestMaintenanceEventsConfig: + """Test the MaintenanceEventsConfig class.""" + + def test_init_defaults(self): + """Test MaintenanceEventsConfig initialization with defaults.""" + config = MaintenanceEventsConfig() + assert config.enabled is False + assert config.proactive_reconnect is True + assert config.relax_timeout == 20 + + def test_init_custom_values(self): + """Test MaintenanceEventsConfig initialization with custom values.""" + config = MaintenanceEventsConfig( + enabled=True, proactive_reconnect=False, relax_timeout=30 + ) + assert config.enabled is True + assert config.proactive_reconnect is False + assert config.relax_timeout == 30 + + def test_repr(self): + """Test MaintenanceEventsConfig string representation.""" + config = MaintenanceEventsConfig( + enabled=True, proactive_reconnect=False, relax_timeout=30 + ) + repr_str = repr(config) + assert "MaintenanceEventsConfig" in repr_str + assert "enabled=True" in repr_str + assert "proactive_reconnect=False" in repr_str + assert "relax_timeout=30" in repr_str + + def test_is_relax_timeouts_enabled_true(self): + """Test is_relax_timeouts_enabled returns True for positive timeout.""" + config = MaintenanceEventsConfig(relax_timeout=20) + assert config.is_relax_timeouts_enabled() is True + + def test_is_relax_timeouts_enabled_false(self): + """Test is_relax_timeouts_enabled returns False for -1 timeout.""" + config = MaintenanceEventsConfig(relax_timeout=-1) + assert config.is_relax_timeouts_enabled() is False + + def test_is_relax_timeouts_enabled_zero(self): + """Test is_relax_timeouts_enabled returns True for zero timeout.""" + config = MaintenanceEventsConfig(relax_timeout=0) + assert config.is_relax_timeouts_enabled() is True + + def test_is_relax_timeouts_enabled_none(self): + """Test is_relax_timeouts_enabled returns True for None timeout.""" + config = MaintenanceEventsConfig(relax_timeout=None) + assert config.is_relax_timeouts_enabled() is True + + def test_relax_timeout_none_is_saved_as_none(self): + """Test that None value for relax_timeout is saved as None.""" + config = MaintenanceEventsConfig(relax_timeout=None) + assert config.relax_timeout is None + + +class TestMaintenanceEventPoolHandler: + """Test the MaintenanceEventPoolHandler class.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_pool = Mock() + self.config = MaintenanceEventsConfig( + enabled=True, proactive_reconnect=True, relax_timeout=20 + ) + self.handler = MaintenanceEventPoolHandler(self.mock_pool, self.config) + + def test_init(self): + """Test MaintenanceEventPoolHandler initialization.""" + assert self.handler.pool == self.mock_pool + assert self.handler.config == self.config + assert isinstance(self.handler._processed_events, set) + assert isinstance(self.handler._lock, type(threading.RLock())) + + def test_remove_expired_notifications(self): + """Test removal of expired notifications.""" + with patch("time.monotonic", return_value=1000): + event1 = NodeMovingEvent( + id=1, new_node_host="host1", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=2, new_node_host="host2", new_node_port=6380, ttl=5 + ) + self.handler._processed_events.add(event1) + self.handler._processed_events.add(event2) + + # Move time forward but not enough to expire event2 (expires at 1005) + with patch("time.monotonic", return_value=1003): + self.handler.remove_expired_notifications() + assert event1 in self.handler._processed_events + assert event2 in self.handler._processed_events # Not expired yet + + # Move time forward to expire event2 but not event1 + with patch("time.monotonic", return_value=1006): + self.handler.remove_expired_notifications() + assert event1 in self.handler._processed_events + assert event2 not in self.handler._processed_events # Now expired + + def test_handle_event_node_moving(self): + """Test handling of NodeMovingEvent.""" + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + with patch.object(self.handler, "handle_node_moving_event") as mock_handle: + self.handler.handle_event(event) + mock_handle.assert_called_once_with(event) + + def test_handle_event_unknown_type(self): + """Test handling of unknown event type.""" + event = NodeMigratingEvent(id=1, ttl=5) # Not handled by pool handler + + result = self.handler.handle_event(event) + assert result is None + + def test_handle_node_moving_event_disabled_config(self): + """Test node moving event handling when both features are disabled.""" + config = MaintenanceEventsConfig(proactive_reconnect=False, relax_timeout=-1) + handler = MaintenanceEventPoolHandler(self.mock_pool, config) + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + result = handler.handle_node_moving_event(event) + assert result is None + assert event not in handler._processed_events + + def test_handle_node_moving_event_already_processed(self): + """Test node moving event handling when event already processed.""" + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + self.handler._processed_events.add(event) + + result = self.handler.handle_node_moving_event(event) + assert result is None + + def test_handle_node_moving_event_success(self): + """Test successful node moving event handling.""" + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + with ( + patch("threading.Timer") as mock_timer, + patch("time.monotonic", return_value=1000), + ): + self.handler.handle_node_moving_event(event) + + # Verify timer was started + mock_timer.assert_called_once_with( + event.ttl, self.handler.handle_node_moved_event + ) + mock_timer.return_value.start.assert_called_once() + + # Verify event was added to processed set + assert event in self.handler._processed_events + + # Verify pool methods were called + self.mock_pool.update_connection_kwargs_with_tmp_settings.assert_called_once() + + def test_handle_node_moved_event(self): + """Test handling of node moved event (cleanup).""" + self.handler.handle_node_moved_event() + + # Verify cleanup methods were called + self.mock_pool.update_connection_kwargs_with_tmp_settings.assert_called_once_with( + tmp_host_address=None, + tmp_relax_timeout=-1, + ) + + +class TestMaintenanceEventConnectionHandler: + """Test the MaintenanceEventConnectionHandler class.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_connection = Mock() + self.config = MaintenanceEventsConfig(enabled=True, relax_timeout=20) + self.handler = MaintenanceEventConnectionHandler( + self.mock_connection, self.config + ) + + def test_init(self): + """Test MaintenanceEventConnectionHandler initialization.""" + assert self.handler.connection == self.mock_connection + assert self.handler.config == self.config + + def test_handle_event_migrating(self): + """Test handling of NodeMigratingEvent.""" + event = NodeMigratingEvent(id=1, ttl=5) + + with patch.object(self.handler, "handle_migrating_event") as mock_handle: + self.handler.handle_event(event) + mock_handle.assert_called_once_with(event) + + def test_handle_event_migrated(self): + """Test handling of NodeMigratedEvent.""" + event = NodeMigratedEvent(id=1) + + with patch.object( + self.handler, "handle_migration_completed_event" + ) as mock_handle: + self.handler.handle_event(event) + mock_handle.assert_called_once_with(event) + + def test_handle_event_unknown_type(self): + """Test handling of unknown event type.""" + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + result = self.handler.handle_event(event) + assert result is None + + def test_handle_migrating_event_disabled(self): + """Test migrating event handling when relax timeouts are disabled.""" + config = MaintenanceEventsConfig(relax_timeout=-1) + handler = MaintenanceEventConnectionHandler(self.mock_connection, config) + event = NodeMigratingEvent(id=1, ttl=5) + + result = handler.handle_migrating_event(event) + assert result is None + self.mock_connection.update_current_socket_timeout.assert_not_called() + + def test_handle_migrating_event_success(self): + """Test successful migrating event handling.""" + event = NodeMigratingEvent(id=1, ttl=5) + + self.handler.handle_migrating_event(event) + + self.mock_connection.update_current_socket_timeout.assert_called_once_with(20) + self.mock_connection.update_tmp_settings.assert_called_once_with( + tmp_relax_timeout=20 + ) + + def test_handle_migration_completed_event_disabled(self): + """Test migration completed event handling when relax timeouts are disabled.""" + config = MaintenanceEventsConfig(relax_timeout=-1) + handler = MaintenanceEventConnectionHandler(self.mock_connection, config) + event = NodeMigratedEvent(id=1) + + result = handler.handle_migration_completed_event(event) + assert result is None + self.mock_connection.update_current_socket_timeout.assert_not_called() + + def test_handle_migration_completed_event_success(self): + """Test successful migration completed event handling.""" + event = NodeMigratedEvent(id=1) + + self.handler.handle_migration_completed_event(event) + + self.mock_connection.update_current_socket_timeout.assert_called_once_with(-1) + self.mock_connection.update_tmp_settings.assert_called_once_with( + tmp_relax_timeout=-1 + ) From 32a16f019771d75fe9e4178d572e0c3bb0fd74c7 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Fri, 11 Jul 2025 18:47:39 +0300 Subject: [PATCH 03/28] Adding integration-like tests for migrating/migrated events handling --- redis/connection.py | 14 + tests/test_maintenance_events_handling.py | 696 ++++++++++++++++++++++ 2 files changed, 710 insertions(+) create mode 100644 tests/test_maintenance_events_handling.py diff --git a/redis/connection.py b/redis/connection.py index f55e1b455c..7755472085 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -2209,3 +2209,17 @@ def _update_maintenance_events_config_for_connections( with self._lock: for conn in tuple(self._connections): conn.maintenance_events_config = maintenance_events_config + + def _update_maintenance_events_configs_for_connections( + self, maintenance_events_pool_handler + ): + """Override base class method to work with BlockingConnectionPool's structure.""" + with self._lock: + for conn in tuple(self._connections): + if conn: # conn can be None in BlockingConnectionPool + conn.set_maintenance_event_pool_handler( + maintenance_events_pool_handler + ) + conn.maintenance_events_config = ( + maintenance_events_pool_handler.config + ) diff --git a/tests/test_maintenance_events_handling.py b/tests/test_maintenance_events_handling.py new file mode 100644 index 0000000000..6413e24b6e --- /dev/null +++ b/tests/test_maintenance_events_handling.py @@ -0,0 +1,696 @@ +import socket +import threading +import select +from unittest.mock import Mock, patch +import pytest + +from redis import Redis +from redis.connection import ConnectionPool, BlockingConnectionPool +from redis.maintenance_events import ( + MaintenanceEventsConfig, + NodeMigratingEvent, + NodeMigratedEvent, + MaintenanceEventConnectionHandler, + MaintenanceEventPoolHandler, +) + + +class MockSocket: + """Mock socket that simulates Redis protocol responses.""" + + def __init__(self): + self.connected = False + self.sent_data = [] + self.response_queue = [] + self.closed = False + self.command_count = 0 + self.pending_responses = [] + self.current_response_index = 0 + # Track socket timeout changes for maintenance events validation + self.timeout = None + self.thread_timeouts = {} # Track last applied timeout per thread + + def connect(self, address): + """Simulate socket connection.""" + self.connected = True + + def send(self, data): + """Simulate sending data to Redis.""" + if self.closed: + raise ConnectionError("Socket is closed") + self.sent_data.append(data) + + # Analyze the command and prepare appropriate response + if b"HELLO" in data: + response = b"%7\r\n$6\r\nserver\r\n$5\r\nredis\r\n$7\r\nversion\r\n$5\r\n7.0.0\r\n$5\r\nproto\r\n:3\r\n$2\r\nid\r\n:1\r\n$4\r\nmode\r\n$10\r\nstandalone\r\n$4\r\nrole\r\n$6\r\nmaster\r\n$7\r\nmodules\r\n*0\r\n" + self.pending_responses.append(response) + elif b"SET" in data: + response = b"+OK\r\n" + + # Check if this is a key that should trigger a push message + if b"key_receive_migrating_" in data: + # MIGRATING push message before SET key_receive_migrating_X response + # Format: >2\r\n$9\r\nMIGRATING\r\n:10\r\n (2 elements: MIGRATING, ttl) + migrating_push = ">2\r\n$9\r\nMIGRATING\r\n:10\r\n" + response = migrating_push.encode() + response + elif b"key_receive_migrated_" in data: + # MIGRATED push message before SET key_receive_migrated_X response + # Format: >1\r\n$8\r\nMIGRATED\r\n (1 element: MIGRATED) + migrated_push = ">1\r\n$8\r\nMIGRATED\r\n" + response = migrated_push.encode() + response + + self.pending_responses.append(response) + elif b"GET" in data: + # Extract key and provide appropriate response + if b"hello" in data: + response = b"$5\r\nworld\r\n" + self.pending_responses.append(response) + # Handle thread-specific keys for integration test first (more specific) + elif b"key1_0" in data: + self.pending_responses.append(b"$8\r\nvalue1_0\r\n") + elif b"key_receive_migrating_0" in data: + self.pending_responses.append(b"$8\r\nvalue2_0\r\n") + elif b"key1_1" in data: + self.pending_responses.append(b"$8\r\nvalue1_1\r\n") + elif b"key_receive_migrating_1" in data: + self.pending_responses.append(b"$8\r\nvalue2_1\r\n") + elif b"key1_2" in data: + self.pending_responses.append(b"$8\r\nvalue1_2\r\n") + elif b"key_receive_migrating_2" in data: + self.pending_responses.append(b"$8\r\nvalue2_2\r\n") + # Generic keys (less specific, should come after thread-specific) + elif b"key0" in data: + self.pending_responses.append(b"$6\r\nvalue0\r\n") + elif b"key1" in data: + self.pending_responses.append(b"$6\r\nvalue1\r\n") + elif b"key2" in data: + self.pending_responses.append(b"$6\r\nvalue2\r\n") + else: + self.pending_responses.append(b"$-1\r\n") # NULL response + else: + self.pending_responses.append(b"+OK\r\n") # Default response + + self.command_count += 1 + return len(data) + + def sendall(self, data): + """Simulate sending all data to Redis.""" + return self.send(data) + + def recv(self, bufsize): + """Simulate receiving data from Redis.""" + if self.closed: + raise ConnectionError("Socket is closed") + if self.response_queue: + response = self.response_queue.pop(0) + return response[:bufsize] # Respect buffer size + + # Use pending responses that were prepared when commands were sent + if self.pending_responses: + response = self.pending_responses.pop(0) + return response[:bufsize] # Respect buffer size + else: + # No data available - this should block or raise an exception + # For can_read checks, we should indicate no data is available + import errno + + raise BlockingIOError(errno.EAGAIN, "Resource temporarily unavailable") + + def fileno(self): + """Return a fake file descriptor for select/poll operations.""" + return 1 # Fake file descriptor + + def close(self): + """Simulate closing the socket.""" + self.closed = True + self.connected = False + + def settimeout(self, timeout): + """Simulate setting socket timeout and track changes per thread.""" + self.timeout = timeout + + # Track last applied timeout per thread + thread_id = threading.current_thread().ident + self.thread_timeouts[thread_id] = timeout + + def setsockopt(self, level, optname, value): + """Simulate setting socket options.""" + pass + + def getpeername(self): + """Simulate getting peer name.""" + return ("127.0.0.1", 6379) + + def getsockname(self): + """Simulate getting socket name.""" + return ("127.0.0.1", 12345) + + def shutdown(self, how): + """Simulate socket shutdown.""" + pass + + +class TestMaintenanceEventsHandling: + """Integration tests for maintenance events handling with real connection pool.""" + + def setup_method(self): + """Set up test fixtures with mocked sockets.""" + self.mock_sockets = [] + self.original_socket = socket.socket + + # Mock socket creation to return our mock sockets + def mock_socket_factory(*args, **kwargs): + mock_sock = MockSocket() + self.mock_sockets.append(mock_sock) + return mock_sock + + self.socket_patcher = patch("socket.socket", side_effect=mock_socket_factory) + self.socket_patcher.start() + + # Mock select.select to simulate data availability for reading + def mock_select(rlist, wlist, xlist, timeout=0): + # Check if any of the sockets in rlist have data available + ready_sockets = [] + for sock in rlist: + if hasattr(sock, "connected") and sock.connected and not sock.closed: + # Only return socket as ready if it actually has data to read + if ( + hasattr(sock, "pending_responses") and sock.pending_responses + ) or (hasattr(sock, "response_queue") and sock.response_queue): + ready_sockets.append(sock) + # Don't return socket as ready just because it received commands + # Only when there are actual responses available + return (ready_sockets, [], []) + + self.select_patcher = patch("select.select", side_effect=mock_select) + self.select_patcher.start() + + # Create maintenance events config + self.config = MaintenanceEventsConfig( + enabled=True, proactive_reconnect=True, relax_timeout=30 + ) + + # Create connection pool with maintenance events (requires RESP3) + self.pool = ConnectionPool( + host="localhost", + port=6379, + max_connections=10, # Increased for multi-threaded tests + protocol=3, # Required for maintenance events + maintenance_events_config=self.config, + ) + + # Create Redis client + self.redis_client = Redis(connection_pool=self.pool) + + def teardown_method(self): + """Clean up test fixtures.""" + self.socket_patcher.stop() + self.select_patcher.stop() + if hasattr(self.pool, "disconnect"): + self.pool.disconnect() + + def _validate_current_timeout_for_thread(self, thread_id, expected_timeout): + """Helper method to validate the current timeout for the calling thread.""" + current_thread_id = threading.current_thread().ident + actual_timeout = None + for sock in self.mock_sockets: + if current_thread_id in sock.thread_timeouts: + actual_timeout = sock.thread_timeouts[current_thread_id] + break + + assert actual_timeout == expected_timeout, ( + f"Thread {thread_id}: Expected timeout ({expected_timeout}), " + f"but found timeout: {actual_timeout} for thread {current_thread_id}. " + f"All thread timeouts: {[sock.thread_timeouts for sock in self.mock_sockets]}" + ) + + def test_connection_pool_creation_with_maintenance_events(self): + """Test that connection pool is created with maintenance events configuration.""" + assert ( + self.pool.connection_kwargs.get("maintenance_events_config") == self.config + ) + # Pool should have maintenance events enabled + assert self.pool.maintenance_events_pool_handler_enabled() is True + + # Create and set a pool handler + pool_handler = MaintenanceEventPoolHandler(self.pool, self.config) + self.pool.set_maintenance_events_pool_handler(pool_handler) + + # Validate that the handler is properly set on the pool + assert ( + self.pool.connection_kwargs.get("maintenance_events_pool_handler") + == pool_handler + ) + assert ( + self.pool.connection_kwargs.get("maintenance_events_config") + == pool_handler.config + ) + + # Verify that the pool handler has the correct configuration + assert pool_handler.pool == self.pool + assert pool_handler.config == self.config + + def test_blocking_connection_pool_creation_with_maintenance_events(self): + """Test that BlockingConnectionPool is created with maintenance events configuration.""" + # Create blocking connection pool with maintenance events (requires RESP3) + blocking_pool = BlockingConnectionPool( + host="localhost", + port=6379, + max_connections=3, + protocol=3, # Required for maintenance events + maintenance_events_config=self.config, + ) + + try: + assert ( + blocking_pool.connection_kwargs.get("maintenance_events_config") + == self.config + ) + # Pool should have maintenance events enabled + assert blocking_pool.maintenance_events_pool_handler_enabled() is True + + # Create and set a pool handler + pool_handler = MaintenanceEventPoolHandler(blocking_pool, self.config) + blocking_pool.set_maintenance_events_pool_handler(pool_handler) + + # Validate that the handler is properly set on the blocking pool + assert ( + blocking_pool.connection_kwargs.get("maintenance_events_pool_handler") + == pool_handler + ) + assert ( + blocking_pool.connection_kwargs.get("maintenance_events_config") + == pool_handler.config + ) + + # Verify that the pool handler has the correct configuration + assert pool_handler.pool == blocking_pool + assert pool_handler.config == self.config + + finally: + if hasattr(blocking_pool, "disconnect"): + blocking_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_redis_operations_with_mock_sockets(self, pool_class): + """ + Test basic Redis operations work with mocked sockets and proper response parsing. + Basically with test - the mocked socket is validated. + """ + # Create a pool of the specified type with maintenance events + test_pool = pool_class( + host="localhost", + port=6379, + max_connections=5, + protocol=3, # Required for maintenance events + maintenance_events_config=self.config, + ) + + try: + # Create Redis client with the test pool + test_redis_client = Redis(connection_pool=test_pool) + + # Perform Redis operations that should work with our improved mock responses + result_set = test_redis_client.set("hello", "world") + result_get = test_redis_client.get("hello") + + # Verify operations completed successfully + assert result_set is True + assert result_get == b"world" + + # Verify socket interactions + assert len(self.mock_sockets) >= 1 + assert self.mock_sockets[0].connected + assert len(self.mock_sockets[0].sent_data) >= 2 # HELLO, SET, GET commands + + # Verify that the connection has maintenance event handler + connection = test_pool.get_connection() + assert hasattr(connection, "_maintenance_event_connection_handler") + test_pool.release(connection) + + finally: + if hasattr(test_pool, "disconnect"): + test_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_multiple_connections_in_pool(self, pool_class): + """Test that multiple connections can be created and used for Redis operations in multiple threads.""" + # Create a pool of the specified type with maintenance events + test_pool = pool_class( + host="localhost", + port=6379, + max_connections=5, + protocol=3, # Required for maintenance events + maintenance_events_config=self.config, + ) + + try: + # Create Redis client with the test pool + test_redis_client = Redis(connection_pool=test_pool) + + # Results storage for thread operations + results = [] + errors = [] + + def redis_operation(key_suffix): + """Perform Redis operations in a thread.""" + try: + # SET operation + set_result = test_redis_client.set( + f"key{key_suffix}", f"value{key_suffix}" + ) + # GET operation + get_result = test_redis_client.get(f"key{key_suffix}") + results.append((set_result, get_result)) + except Exception as e: + errors.append(e) + + # Run operations in multiple threads to force multiple connections + threads = [] + for i in range(3): + thread = threading.Thread(target=redis_operation, args=(i,)) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # Verify no errors occurred + assert len(errors) == 0, f"Errors occurred: {errors}" + + # Verify all operations completed successfully + assert len(results) == 3 + for set_result, get_result in results: + assert set_result is True + assert get_result in [b"value0", b"value1", b"value2"] + + # Verify that multiple connections were created with mock sockets + # With threading, both pool types should create multiple sockets for concurrent access + assert len(self.mock_sockets) >= 2, ( + f"Expected multiple sockets due to threading, got {len(self.mock_sockets)}" + ) + + # Verify each connection has maintenance event handler + connection = test_pool.get_connection() + assert hasattr(connection, "_maintenance_event_connection_handler") + test_pool.release(connection) + + finally: + if hasattr(test_pool, "disconnect"): + test_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_migration_related_events_handling_integration(self, pool_class): + """ + Test full integration of migration-related events (MIGRATING/MIGRATED) handling with multiple threads and commands. + + This test validates the complete migration lifecycle: + 1. Creates 3 concurrent threads, each executing 5 Redis commands + 2. Injects MIGRATING push message before command 2 (SET key_receive_migrating_X) + 3. Validates socket timeout is updated to relaxed value (30s) after MIGRATING + 4. Executes commands 3-4 while timeout remains relaxed + 5. Injects MIGRATED push message before command 5 (SET key_receive_migrated_X) + 6. Validates socket timeout is restored after MIGRATED + 7. Tests both ConnectionPool and BlockingConnectionPool implementations + 8. Uses proper RESP3 push message format for realistic protocol simulation + """ + # Create a pool of the specified type with maintenance events + test_pool = pool_class( + host="localhost", + port=6379, + max_connections=10, # Increased for multi-threaded tests + protocol=3, # Required for maintenance events + maintenance_events_config=self.config, + ) + + try: + # Create Redis client with the test pool + test_redis_client = Redis(connection_pool=test_pool) + + # Results storage for thread operations + results = [] + errors = [] + + def redis_operations_with_maintenance_events(thread_id): + """Perform Redis operations with maintenance events in a thread.""" + try: + # Command 1: Initial command + result1 = test_redis_client.set( + f"key1_{thread_id}", f"value1_{thread_id}" + ) + + # Validate Command 1 result + assert result1 is True, ( + f"Thread {thread_id}: Command 1 (SET key1) failed" + ) + + # Command 2: This SET command will receive MIGRATING push message before response + result2 = test_redis_client.set( + f"key_receive_migrating_{thread_id}", f"value2_{thread_id}" + ) + + # Validate Command 2 result + assert result2 is True, ( + f"Thread {thread_id}: Command 2 (SET key2) failed" + ) + + # Step 4: Validate timeout was updated to relaxed value after MIGRATING + self._validate_current_timeout_for_thread(thread_id, 30) + + # Command 3: Another command while timeout is still relaxed + result3 = test_redis_client.get(f"key1_{thread_id}") + + # Validate Command 3 result + expected_value3 = f"value1_{thread_id}".encode() + assert result3 == expected_value3, ( + f"Thread {thread_id}: Command 3 (GET key1) failed. " + f"Expected {expected_value3}, got {result3}" + ) + + # Command 4: Execute command (step 5) + result4 = test_redis_client.get( + f"key_receive_migrating_{thread_id}" + ) + + # Validate Command 4 result + expected_value4 = f"value2_{thread_id}".encode() + assert result4 == expected_value4, ( + f"Thread {thread_id}: Command 4 (GET key_receive_migrating) failed. " + f"Expected {expected_value4}, got {result4}" + ) + + # Step 6: Validate socket timeout is still relaxed during commands 3-4 + self._validate_current_timeout_for_thread(thread_id, 30) + + # Command 5: This SET command will receive + # MIGRATED push message before actual response + result5 = test_redis_client.set( + f"key_receive_migrated_{thread_id}", f"value3_{thread_id}" + ) + + # Validate Command 5 result + assert result5 is True, ( + f"Thread {thread_id}: Command 5 (SET key_receive_migrated) failed" + ) + + # Step 8: Validate socket timeout is reversed back to original after MIGRATED + self._validate_current_timeout_for_thread(thread_id, None) + + results.append( + { + "thread_id": thread_id, + "success": True, + } + ) + + except Exception as e: + errors.append(f"Thread {thread_id}: {e}") + + # Run operations in multiple threads (step 1) + threads = [] + for i in range(3): + thread = threading.Thread( + target=redis_operations_with_maintenance_events, + args=(i,), + name=str(i), + ) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # Verify all threads completed successfully + successful_threads = len(results) + assert successful_threads == 3, ( + f"Expected 3 successful threads, got {successful_threads}. " + f"Errors: {errors}" + ) + + # Verify maintenance events were processed correctly across all threads + # Note: Different pool types may create different numbers of sockets + # The key is that we have at least 1 socket and all threads succeeded + assert len(self.mock_sockets) >= 1, ( + f"Expected at least 1 socket for operations, got {len(self.mock_sockets)}" + ) + + finally: + if hasattr(test_pool, "disconnect"): + test_pool.disconnect() + + def test_migrating_event_with_disabled_relax_timeout(self): + # TODO Not yet reviewed and validated - just vipecoded + """Test migrating event handling when relax timeout is disabled.""" + # Create config with disabled relax timeout + disabled_config = MaintenanceEventsConfig( + enabled=True, + relax_timeout=-1, # Disabled + ) + + # Create new pool with disabled config + disabled_pool = ConnectionPool( + host="localhost", + port=6379, + protocol=3, # Required for maintenance events + maintenance_events_config=disabled_config, + ) + + try: + # Get a connection + connection = disabled_pool.get_connection() + + # Mock the connection's timeout update methods + connection.update_current_socket_timeout = Mock() + connection.update_tmp_settings = Mock() + + # Create and handle migrating event + migrating_event = NodeMigratingEvent(id=1, ttl=10) + result = connection._maintenance_event_connection_handler.handle_event( + migrating_event + ) + + # Verify that no timeout updates were made (relax is disabled) + assert result is None + connection.update_current_socket_timeout.assert_not_called() + connection.update_tmp_settings.assert_not_called() + + finally: + if hasattr(disabled_pool, "disconnect"): + disabled_pool.disconnect() + + def test_pool_handler_with_migrating_event(self): + # TODO Not yet reviewed and validated - just vipecoded + """Test that pool handler correctly handles migrating events.""" + # Create and set a pool handler + pool_handler = MaintenanceEventPoolHandler(self.pool, self.config) + + # Create a migrating event (not handled by pool handler) + migrating_event = NodeMigratingEvent(id=1, ttl=5) + + # Pool handler should return None for migrating events (not its responsibility) + result = pool_handler.handle_event(migrating_event) + assert result is None + + def test_connection_timeout_restoration_after_event(self): + # TODO Not yet reviewed and validated - just vipecoded + """Test that connection timeout is properly restored after maintenance event.""" + # Establish connection + self.redis_client.set("test", "value") + + connection = self.pool.get_connection() + + # Mock timeout methods + connection.update_current_socket_timeout = Mock() + connection.update_tmp_settings = Mock() + + # Simulate migrating event + migrating_event = NodeMigratingEvent(id=1, ttl=5) + connection._maintenance_event_connection_handler.handle_migrating_event( + migrating_event + ) + + # Verify relax timeout was applied + connection.update_current_socket_timeout.assert_called_with(30) + connection.update_tmp_settings.assert_called_with(tmp_relax_timeout=30) + + # Reset mocks + connection.update_current_socket_timeout.reset_mock() + connection.update_tmp_settings.reset_mock() + + # Simulate migration completed event + from redis.maintenance_events import NodeMigratedEvent + + migrated_event = NodeMigratedEvent(id=1) + connection._maintenance_event_connection_handler.handle_migration_completed_event( + migrated_event + ) + + # Verify timeout was restored + connection.update_current_socket_timeout.assert_called_with( + -1 + ) # Restore original + connection.update_tmp_settings.assert_called_with(tmp_relax_timeout=-1) + + self.pool.release(connection) + + def test_socket_error_handling_during_operations(self): + # TODO Not yet reviewed and validated - just vipecoded + """Test that socket errors are properly handled during Redis operations.""" + # Create a connection first to ensure we have a mock socket + connection = self.pool.get_connection() + + # Set up a socket that will fail + if self.mock_sockets: + self.mock_sockets[0].closed = True + + # Attempt Redis operation that should fail due to closed socket + with pytest.raises( + (ConnectionError, OSError, Exception) + ): # Should raise connection-related exception + # Try to use the connection with a closed socket + connection.send_command("PING") + + # Release the connection + self.pool.release(connection) + + def test_maintenance_events_with_concurrent_operations(self): + # TODO Not yet reviewed and validated - just vipecoded + """Test maintenance events handling with concurrent Redis operations.""" + + # Perform concurrent operations + def redis_operation(key_suffix): + try: + return self.redis_client.set( + f"concurrent_key_{key_suffix}", f"value_{key_suffix}" + ) + except Exception: + return False + + # Simulate concurrent operations + threads = [] + results = [] + + for i in range(3): + thread = threading.Thread( + target=lambda i=i: results.append(redis_operation(i)) + ) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # During concurrent operations, simulate a maintenance event + if self.pool.connection_kwargs.get("maintenance_events_config"): + migrating_event = NodeMigratingEvent(id=1, ttl=5) + # Create a pool handler to test event handling + pool_handler = MaintenanceEventPoolHandler(self.pool, self.config) + result = pool_handler.handle_event(migrating_event) + assert result is None # Pool handler doesn't handle migrating events + + # Verify that some operations completed successfully + # (Some might fail due to mock socket limitations, but that's expected) + assert len(results) == 3 From 8bfdf131f18c166b14fc26aee0f62b94533800fd Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Fri, 11 Jul 2025 18:49:22 +0300 Subject: [PATCH 04/28] Removed unused imports --- tests/test_maintenance_events_handling.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/test_maintenance_events_handling.py b/tests/test_maintenance_events_handling.py index 6413e24b6e..80590840a0 100644 --- a/tests/test_maintenance_events_handling.py +++ b/tests/test_maintenance_events_handling.py @@ -1,6 +1,5 @@ import socket import threading -import select from unittest.mock import Mock, patch import pytest @@ -9,8 +8,6 @@ from redis.maintenance_events import ( MaintenanceEventsConfig, NodeMigratingEvent, - NodeMigratedEvent, - MaintenanceEventConnectionHandler, MaintenanceEventPoolHandler, ) From 33d7295e1cc80f34877862c6e63cde1c4eb95c68 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Fri, 11 Jul 2025 19:06:37 +0300 Subject: [PATCH 05/28] Revert changing of the default retry object initialization for connection pool - this should be a separate PR --- redis/connection.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index 7755472085..9a434848ca 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -24,7 +24,7 @@ from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser from .auth.token import TokenInterface -from .backoff import ExponentialWithJitterBackoff +from .backoff import NoBackoff from .credentials import CredentialProvider, UsernamePasswordCredentialProvider from .event import AfterConnectionReleasedEvent, EventDispatcher from .exceptions import ( @@ -323,15 +323,16 @@ def __init__( # Add TimeoutError to the errors list to retry on retry_on_error.append(TimeoutError) self.retry_on_error = retry_on_error - if retry is None: - self.retry = Retry( - backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3 - ) - else: - # deep-copy the Retry object as it is mutable - self.retry = copy.deepcopy(retry) - if retry_on_error: + if retry or retry_on_error: + if retry is None: + self.retry = Retry(NoBackoff(), 1) + else: + # deep-copy the Retry object as it is mutable + self.retry = copy.deepcopy(retry) + # Update the retry's supported errors with the specified errors self.retry.update_supported_errors(retry_on_error) + else: + self.retry = Retry(NoBackoff(), 0) self.health_check_interval = health_check_interval self.next_health_check = 0 self.redis_connect_func = redis_connect_func From 346097feab59e195fc0a3c1c9427f5f92e616161 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Mon, 14 Jul 2025 15:41:47 +0300 Subject: [PATCH 06/28] Complete migrating/migrated integration-like tests --- tests/test_maintenance_events_handling.py | 444 ++++++++++------------ 1 file changed, 195 insertions(+), 249 deletions(-) diff --git a/tests/test_maintenance_events_handling.py b/tests/test_maintenance_events_handling.py index 80590840a0..6a687da1b0 100644 --- a/tests/test_maintenance_events_handling.py +++ b/tests/test_maintenance_events_handling.py @@ -17,6 +17,7 @@ class MockSocket: def __init__(self): self.connected = False + self.address = None self.sent_data = [] self.response_queue = [] self.closed = False @@ -30,6 +31,7 @@ def __init__(self): def connect(self, address): """Simulate socket connection.""" self.connected = True + self.address = address def send(self, data): """Simulate sending data to Redis.""" @@ -187,24 +189,40 @@ def mock_select(rlist, wlist, xlist, timeout=0): enabled=True, proactive_reconnect=True, relax_timeout=30 ) - # Create connection pool with maintenance events (requires RESP3) - self.pool = ConnectionPool( - host="localhost", - port=6379, - max_connections=10, # Increased for multi-threaded tests - protocol=3, # Required for maintenance events - maintenance_events_config=self.config, - ) - - # Create Redis client - self.redis_client = Redis(connection_pool=self.pool) - def teardown_method(self): """Clean up test fixtures.""" self.socket_patcher.stop() self.select_patcher.stop() - if hasattr(self.pool, "disconnect"): - self.pool.disconnect() + + def _get_client( + self, pool_class, max_connections=10, maintenance_events_config=None + ): + """Helper method to create a pool and Redis client with maintenance events configuration. + + Args: + pool_class: The connection pool class (ConnectionPool or BlockingConnectionPool) + max_connections: Maximum number of connections in the pool (default: 10) + maintenance_events_config: Optional MaintenanceEventsConfig to use. If not provided, + uses self.config from setup_method (default: None) + + Returns: + tuple: (test_pool, test_redis_client) + """ + config = ( + maintenance_events_config + if maintenance_events_config is not None + else self.config + ) + + test_pool = pool_class( + host="localhost", + port=6379, + max_connections=max_connections, + protocol=3, # Required for maintenance events + maintenance_events_config=config, + ) + test_redis_client = Redis(connection_pool=test_pool) + return test_pool, test_redis_client def _validate_current_timeout_for_thread(self, thread_id, expected_timeout): """Helper method to validate the current timeout for the calling thread.""" @@ -221,72 +239,42 @@ def _validate_current_timeout_for_thread(self, thread_id, expected_timeout): f"All thread timeouts: {[sock.thread_timeouts for sock in self.mock_sockets]}" ) - def test_connection_pool_creation_with_maintenance_events(self): - """Test that connection pool is created with maintenance events configuration.""" - assert ( - self.pool.connection_kwargs.get("maintenance_events_config") == self.config - ) - # Pool should have maintenance events enabled - assert self.pool.maintenance_events_pool_handler_enabled() is True - - # Create and set a pool handler - pool_handler = MaintenanceEventPoolHandler(self.pool, self.config) - self.pool.set_maintenance_events_pool_handler(pool_handler) - - # Validate that the handler is properly set on the pool - assert ( - self.pool.connection_kwargs.get("maintenance_events_pool_handler") - == pool_handler - ) - assert ( - self.pool.connection_kwargs.get("maintenance_events_config") - == pool_handler.config - ) - - # Verify that the pool handler has the correct configuration - assert pool_handler.pool == self.pool - assert pool_handler.config == self.config - - def test_blocking_connection_pool_creation_with_maintenance_events(self): - """Test that BlockingConnectionPool is created with maintenance events configuration.""" - # Create blocking connection pool with maintenance events (requires RESP3) - blocking_pool = BlockingConnectionPool( - host="localhost", - port=6379, - max_connections=3, - protocol=3, # Required for maintenance events - maintenance_events_config=self.config, - ) + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_connection_pool_creation_with_maintenance_events(self, pool_class): + """Test that connection pools are created with maintenance events configuration.""" + # Create a pool and Redis client with maintenance events + max_connections = 3 if pool_class == BlockingConnectionPool else 10 + test_pool, _ = self._get_client(pool_class, max_connections=max_connections) try: assert ( - blocking_pool.connection_kwargs.get("maintenance_events_config") + test_pool.connection_kwargs.get("maintenance_events_config") == self.config ) # Pool should have maintenance events enabled - assert blocking_pool.maintenance_events_pool_handler_enabled() is True + assert test_pool.maintenance_events_pool_handler_enabled() is True # Create and set a pool handler - pool_handler = MaintenanceEventPoolHandler(blocking_pool, self.config) - blocking_pool.set_maintenance_events_pool_handler(pool_handler) + pool_handler = MaintenanceEventPoolHandler(test_pool, self.config) + test_pool.set_maintenance_events_pool_handler(pool_handler) - # Validate that the handler is properly set on the blocking pool + # Validate that the handler is properly set on the pool assert ( - blocking_pool.connection_kwargs.get("maintenance_events_pool_handler") + test_pool.connection_kwargs.get("maintenance_events_pool_handler") == pool_handler ) assert ( - blocking_pool.connection_kwargs.get("maintenance_events_config") + test_pool.connection_kwargs.get("maintenance_events_config") == pool_handler.config ) # Verify that the pool handler has the correct configuration - assert pool_handler.pool == blocking_pool + assert pool_handler.pool == test_pool assert pool_handler.config == self.config finally: - if hasattr(blocking_pool, "disconnect"): - blocking_pool.disconnect() + if hasattr(test_pool, "disconnect"): + test_pool.disconnect() @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) def test_redis_operations_with_mock_sockets(self, pool_class): @@ -294,19 +282,10 @@ def test_redis_operations_with_mock_sockets(self, pool_class): Test basic Redis operations work with mocked sockets and proper response parsing. Basically with test - the mocked socket is validated. """ - # Create a pool of the specified type with maintenance events - test_pool = pool_class( - host="localhost", - port=6379, - max_connections=5, - protocol=3, # Required for maintenance events - maintenance_events_config=self.config, - ) + # Create a pool and Redis client with maintenance events + test_pool, test_redis_client = self._get_client(pool_class, max_connections=5) try: - # Create Redis client with the test pool - test_redis_client = Redis(connection_pool=test_pool) - # Perform Redis operations that should work with our improved mock responses result_set = test_redis_client.set("hello", "world") result_get = test_redis_client.get("hello") @@ -332,19 +311,10 @@ def test_redis_operations_with_mock_sockets(self, pool_class): @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) def test_multiple_connections_in_pool(self, pool_class): """Test that multiple connections can be created and used for Redis operations in multiple threads.""" - # Create a pool of the specified type with maintenance events - test_pool = pool_class( - host="localhost", - port=6379, - max_connections=5, - protocol=3, # Required for maintenance events - maintenance_events_config=self.config, - ) + # Create a pool and Redis client with maintenance events + test_pool, test_redis_client = self._get_client(pool_class, max_connections=5) try: - # Create Redis client with the test pool - test_redis_client = Redis(connection_pool=test_pool) - # Results storage for thread operations results = [] errors = [] @@ -397,6 +367,44 @@ def redis_operation(key_suffix): if hasattr(test_pool, "disconnect"): test_pool.disconnect() + def test_pool_handler_with_migrating_event(self): + """Test that pool handler correctly handles migrating events.""" + # Create a pool and Redis client with maintenance events + test_pool, _ = self._get_client(ConnectionPool) + + try: + # Create and set a pool handler + pool_handler = MaintenanceEventPoolHandler(test_pool, self.config) + + # Create a migrating event (not handled by pool handler) + migrating_event = NodeMigratingEvent(id=1, ttl=5) + + # Mock the required functions + with ( + patch.object( + pool_handler, "remove_expired_notifications" + ) as mock_remove_expired, + patch.object( + pool_handler, "handle_node_moving_event" + ) as mock_handle_moving, + patch("redis.maintenance_events.logging.error") as mock_logging_error, + ): + # Pool handler should return None for migrating events (not its responsibility) + pool_handler.handle_event(migrating_event) + + # Validate that remove_expired_notifications has been called once + mock_remove_expired.assert_called_once() + + # Validate that handle_node_moving_event hasn't been called + mock_handle_moving.assert_not_called() + + # Validate that logging.error has been called once + mock_logging_error.assert_called_once() + + finally: + if hasattr(test_pool, "disconnect"): + test_pool.disconnect() + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) def test_migration_related_events_handling_integration(self, pool_class): """ @@ -412,19 +420,10 @@ def test_migration_related_events_handling_integration(self, pool_class): 7. Tests both ConnectionPool and BlockingConnectionPool implementations 8. Uses proper RESP3 push message format for realistic protocol simulation """ - # Create a pool of the specified type with maintenance events - test_pool = pool_class( - host="localhost", - port=6379, - max_connections=10, # Increased for multi-threaded tests - protocol=3, # Required for maintenance events - maintenance_events_config=self.config, - ) + # Create a pool and Redis client with maintenance events + test_pool, test_redis_client = self._get_client(pool_class, max_connections=10) try: - # Create Redis client with the test pool - test_redis_client = Redis(connection_pool=test_pool) - # Results storage for thread operations results = [] errors = [] @@ -433,63 +432,60 @@ def redis_operations_with_maintenance_events(thread_id): """Perform Redis operations with maintenance events in a thread.""" try: # Command 1: Initial command - result1 = test_redis_client.set( - f"key1_{thread_id}", f"value1_{thread_id}" - ) + key1 = f"key1_{thread_id}" + value1 = f"value1_{thread_id}" + result1 = test_redis_client.set(key1, value1) # Validate Command 1 result - assert result1 is True, ( - f"Thread {thread_id}: Command 1 (SET key1) failed" - ) + erros_msg = f"Thread {thread_id}: Command 1 (SET key1) failed" + assert result1 is True, erros_msg # Command 2: This SET command will receive MIGRATING push message before response - result2 = test_redis_client.set( - f"key_receive_migrating_{thread_id}", f"value2_{thread_id}" - ) + key_migrating = f"key_receive_migrating_{thread_id}" + value_migrating = f"value2_{thread_id}" + result2 = test_redis_client.set(key_migrating, value_migrating) # Validate Command 2 result - assert result2 is True, ( - f"Thread {thread_id}: Command 2 (SET key2) failed" - ) + erros_msg = f"Thread {thread_id}: Command 2 (SET key_receive_migrating) failed" + assert result2 is True, erros_msg # Step 4: Validate timeout was updated to relaxed value after MIGRATING self._validate_current_timeout_for_thread(thread_id, 30) # Command 3: Another command while timeout is still relaxed - result3 = test_redis_client.get(f"key1_{thread_id}") + result3 = test_redis_client.get(key1) # Validate Command 3 result - expected_value3 = f"value1_{thread_id}".encode() - assert result3 == expected_value3, ( + expected_value3 = value1.encode() + errors_msg = ( f"Thread {thread_id}: Command 3 (GET key1) failed. " f"Expected {expected_value3}, got {result3}" ) + assert result3 == expected_value3, errors_msg # Command 4: Execute command (step 5) - result4 = test_redis_client.get( - f"key_receive_migrating_{thread_id}" - ) + result4 = test_redis_client.get(key_migrating) # Validate Command 4 result - expected_value4 = f"value2_{thread_id}".encode() - assert result4 == expected_value4, ( + expected_value4 = value_migrating.encode() + errors_msg = ( f"Thread {thread_id}: Command 4 (GET key_receive_migrating) failed. " f"Expected {expected_value4}, got {result4}" ) + assert result4 == expected_value4, errors_msg # Step 6: Validate socket timeout is still relaxed during commands 3-4 self._validate_current_timeout_for_thread(thread_id, 30) # Command 5: This SET command will receive # MIGRATED push message before actual response - result5 = test_redis_client.set( - f"key_receive_migrated_{thread_id}", f"value3_{thread_id}" - ) + key_migrated = f"key_receive_migrated_{thread_id}" + value_migrated = f"value3_{thread_id}" + result5 = test_redis_client.set(key_migrated, value_migrated) # Validate Command 5 result - assert result5 is True, ( - f"Thread {thread_id}: Command 5 (SET key_receive_migrated) failed" - ) + errors_msg = f"Thread {thread_id}: Command 5 (SET key_receive_migrated) failed" + assert result5 is True, errors_msg # Step 8: Validate socket timeout is reversed back to original after MIGRATED self._validate_current_timeout_for_thread(thread_id, None) @@ -537,157 +533,107 @@ def redis_operations_with_maintenance_events(thread_id): if hasattr(test_pool, "disconnect"): test_pool.disconnect() - def test_migrating_event_with_disabled_relax_timeout(self): - # TODO Not yet reviewed and validated - just vipecoded - """Test migrating event handling when relax timeout is disabled.""" + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_migrating_event_with_disabled_relax_timeout(self, pool_class): + """ + Test migrating event handling when relax timeout is disabled. + + This test validates that when relax_timeout is disabled (-1): + 1. MIGRATING events are received and processed + 2. No timeout updates are applied to connections + 3. Socket timeouts remain unchanged during migration events + 4. Tests both ConnectionPool and BlockingConnectionPool implementations + """ # Create config with disabled relax timeout disabled_config = MaintenanceEventsConfig( enabled=True, - relax_timeout=-1, # Disabled + relax_timeout=-1, # This means the relax timeout is Disabled ) - # Create new pool with disabled config - disabled_pool = ConnectionPool( - host="localhost", - port=6379, - protocol=3, # Required for maintenance events - maintenance_events_config=disabled_config, + # Create a pool and Redis client with disabled relax timeout config + test_pool, test_redis_client = self._get_client( + pool_class, max_connections=5, maintenance_events_config=disabled_config ) try: - # Get a connection - connection = disabled_pool.get_connection() - - # Mock the connection's timeout update methods - connection.update_current_socket_timeout = Mock() - connection.update_tmp_settings = Mock() - - # Create and handle migrating event - migrating_event = NodeMigratingEvent(id=1, ttl=10) - result = connection._maintenance_event_connection_handler.handle_event( - migrating_event - ) - - # Verify that no timeout updates were made (relax is disabled) - assert result is None - connection.update_current_socket_timeout.assert_not_called() - connection.update_tmp_settings.assert_not_called() - - finally: - if hasattr(disabled_pool, "disconnect"): - disabled_pool.disconnect() - - def test_pool_handler_with_migrating_event(self): - # TODO Not yet reviewed and validated - just vipecoded - """Test that pool handler correctly handles migrating events.""" - # Create and set a pool handler - pool_handler = MaintenanceEventPoolHandler(self.pool, self.config) - - # Create a migrating event (not handled by pool handler) - migrating_event = NodeMigratingEvent(id=1, ttl=5) + # Results storage for thread operations + results = [] + errors = [] - # Pool handler should return None for migrating events (not its responsibility) - result = pool_handler.handle_event(migrating_event) - assert result is None + def redis_operations_with_disabled_relax(thread_id): + """Perform Redis operations with disabled relax timeout in a thread.""" + try: + # Command 1: Initial command + key1 = f"key1_{thread_id}" + value1 = f"value1_{thread_id}" + result1 = test_redis_client.set(key1, value1) - def test_connection_timeout_restoration_after_event(self): - # TODO Not yet reviewed and validated - just vipecoded - """Test that connection timeout is properly restored after maintenance event.""" - # Establish connection - self.redis_client.set("test", "value") + # Validate Command 1 result + errors_msg = f"Thread {thread_id}: Command 1 (SET key1) failed" + assert result1 is True, errors_msg - connection = self.pool.get_connection() + # Command 2: This SET command will receive MIGRATING push message before response + key_migrating = f"key_receive_migrating_{thread_id}" + value_migrating = f"value2_{thread_id}" + result2 = test_redis_client.set(key_migrating, value_migrating) - # Mock timeout methods - connection.update_current_socket_timeout = Mock() - connection.update_tmp_settings = Mock() + # Validate Command 2 result + errors_msg = f"Thread {thread_id}: Command 2 (SET key_receive_migrating) failed" + assert result2 is True, errors_msg - # Simulate migrating event - migrating_event = NodeMigratingEvent(id=1, ttl=5) - connection._maintenance_event_connection_handler.handle_migrating_event( - migrating_event - ) + # Validate timeout was NOT updated (relax is disabled) + # Should remain at default timeout (None), not relaxed to 30s + self._validate_current_timeout_for_thread(thread_id, None) - # Verify relax timeout was applied - connection.update_current_socket_timeout.assert_called_with(30) - connection.update_tmp_settings.assert_called_with(tmp_relax_timeout=30) + # Command 3: Another command to verify timeout remains unchanged + result3 = test_redis_client.get(key1) - # Reset mocks - connection.update_current_socket_timeout.reset_mock() - connection.update_tmp_settings.reset_mock() + # Validate Command 3 result + expected_value3 = value1.encode() + errors_msg = ( + f"Thread {thread_id}: Command 3 (GET key1) failed. " + f"Expected: {expected_value3}, Got: {result3}" + ) + assert result3 == expected_value3, errors_msg - # Simulate migration completed event - from redis.maintenance_events import NodeMigratedEvent + results.append( + { + "thread_id": thread_id, + "success": True, + } + ) - migrated_event = NodeMigratedEvent(id=1) - connection._maintenance_event_connection_handler.handle_migration_completed_event( - migrated_event - ) + except Exception as e: + errors.append(f"Thread {thread_id}: {str(e)}") - # Verify timeout was restored - connection.update_current_socket_timeout.assert_called_with( - -1 - ) # Restore original - connection.update_tmp_settings.assert_called_with(tmp_relax_timeout=-1) - - self.pool.release(connection) - - def test_socket_error_handling_during_operations(self): - # TODO Not yet reviewed and validated - just vipecoded - """Test that socket errors are properly handled during Redis operations.""" - # Create a connection first to ensure we have a mock socket - connection = self.pool.get_connection() - - # Set up a socket that will fail - if self.mock_sockets: - self.mock_sockets[0].closed = True - - # Attempt Redis operation that should fail due to closed socket - with pytest.raises( - (ConnectionError, OSError, Exception) - ): # Should raise connection-related exception - # Try to use the connection with a closed socket - connection.send_command("PING") - - # Release the connection - self.pool.release(connection) - - def test_maintenance_events_with_concurrent_operations(self): - # TODO Not yet reviewed and validated - just vipecoded - """Test maintenance events handling with concurrent Redis operations.""" - - # Perform concurrent operations - def redis_operation(key_suffix): - try: - return self.redis_client.set( - f"concurrent_key_{key_suffix}", f"value_{key_suffix}" + # Run operations in multiple threads to test concurrent behavior + threads = [] + for i in range(3): + thread = threading.Thread( + target=redis_operations_with_disabled_relax, args=(i,) ) - except Exception: - return False + threads.append(thread) + thread.start() - # Simulate concurrent operations - threads = [] - results = [] + # Wait for all threads to complete + for thread in threads: + thread.join() + + # Verify no errors occurred + assert len(errors) == 0, f"Errors occurred: {errors}" - for i in range(3): - thread = threading.Thread( - target=lambda i=i: results.append(redis_operation(i)) + # Verify all operations completed successfully + assert len(results) == 3, ( + f"Expected 3 successful threads, got {len(results)}" ) - threads.append(thread) - thread.start() - # Wait for all threads to complete - for thread in threads: - thread.join() + # Verify maintenance events were processed correctly across all threads + # Note: Different pool types may create different numbers of sockets + # The key is that we have at least 1 socket and all threads succeeded + assert len(self.mock_sockets) >= 1, ( + f"Expected at least 1 socket for operations, got {len(self.mock_sockets)}" + ) - # During concurrent operations, simulate a maintenance event - if self.pool.connection_kwargs.get("maintenance_events_config"): - migrating_event = NodeMigratingEvent(id=1, ttl=5) - # Create a pool handler to test event handling - pool_handler = MaintenanceEventPoolHandler(self.pool, self.config) - result = pool_handler.handle_event(migrating_event) - assert result is None # Pool handler doesn't handle migrating events - - # Verify that some operations completed successfully - # (Some might fail due to mock socket limitations, but that's expected) - assert len(results) == 3 + finally: + if hasattr(test_pool, "disconnect"): + test_pool.disconnect() From f3a9a71a21e39f08b175fe7a43d7ab5ddda94260 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Tue, 15 Jul 2025 17:15:02 +0300 Subject: [PATCH 07/28] Adding moving integration-like tests --- redis/_parsers/base.py | 9 +- redis/connection.py | 26 +- redis/maintenance_events.py | 39 +- tests/test_maintenance_events_handling.py | 604 ++++++++++++++++++---- 4 files changed, 557 insertions(+), 121 deletions(-) diff --git a/redis/_parsers/base.py b/redis/_parsers/base.py index aa5a6b0f12..f2670e43b0 100644 --- a/redis/_parsers/base.py +++ b/redis/_parsers/base.py @@ -129,9 +129,10 @@ def __del__(self): def on_connect(self, connection): "Called when the socket connects" self._sock = connection._sock - self._buffer = SocketBuffer( - self._sock, self.socket_read_size, connection.socket_timeout - ) + timeout = connection.socket_timeout + if connection.tmp_relax_timeout != -1: + timeout = connection.tmp_relax_timeout + self._buffer = SocketBuffer(self._sock, self.socket_read_size, timeout) self.encoder = connection.encoder def on_disconnect(self): @@ -203,7 +204,7 @@ def handle_push_response(self, response, **kwargs): return self.invalidation_push_handler_func(response) if msg_type in _MOVING_MESSAGE and self.node_moving_push_handler_func: if msg_type in _MOVING_MESSAGE: - host, port = response[2].split(":") + host, port = response[2].decode().split(":") ttl = response[1] id = 1 # Hardcoded value for sync parser notification = NodeMovingEvent(id, host, port, ttl) diff --git a/redis/connection.py b/redis/connection.py index 9a434848ca..81a80d0903 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -807,6 +807,10 @@ def update_current_socket_timeout(self, relax_timeout: Optional[float] = None): f" to timeout {timeout}; relax_timeout: {relax_timeout}" ) self._sock.settimeout(timeout) + self.update_parser_buffer_timeout(timeout) + + def update_parser_buffer_timeout(self, timeout: Optional[float] = None): + if self._parser and self._parser._buffer: self._parser._buffer.socket_timeout = timeout def update_tmp_settings( @@ -1901,7 +1905,7 @@ def disconnect_and_reconfigure_free_connections( def update_connections_current_timeout( self, relax_timeout: Optional[float], - include_available_connections: bool = False, + include_free_connections: bool = False, ): """ Update the timeout either for all connections in the pool or just for the ones in use. @@ -1919,7 +1923,7 @@ def update_connections_current_timeout( for conn in self._in_use_connections: self._update_connection_timeout(conn, relax_timeout) - if include_available_connections: + if include_free_connections: for conn in self._available_connections: self._update_connection_timeout(conn, relax_timeout) @@ -2164,8 +2168,6 @@ def update_active_connections_for_reconnect( connections_in_queue = {conn for conn in self.pool.queue if conn} for conn in self._connections: if conn not in connections_in_queue: - if tmp_relax_timeout != -1: - conn.update_socket_timeout(tmp_relax_timeout) self._update_connection_for_reconnect( conn, tmp_host_address, tmp_relax_timeout ) @@ -2184,14 +2186,24 @@ def disconnect_and_reconfigure_free_connections( conn, tmp_host_address, tmp_relax_timeout ) - def update_connections_current_timeout(self, relax_timeout: Optional[float] = None): + def update_connections_current_timeout( + self, + relax_timeout: Optional[float] = None, + include_free_connections: bool = False, + ): logging.debug( f"***** Blocking Pool --> Updating timeouts. relax_timeout: {relax_timeout}" ) with self._lock: - for conn in tuple(self._connections): - self._update_connection_timeout(conn, relax_timeout) + if include_free_connections: + for conn in tuple(self._connections): + self._update_connection_timeout(conn, relax_timeout) + else: + connections_in_queue = {conn for conn in self.pool.queue if conn} + for conn in self._connections: + if conn not in connections_in_queue: + self._update_connection_timeout(conn, relax_timeout) def update_connections_tmp_settings( self, diff --git a/redis/maintenance_events.py b/redis/maintenance_events.py index d818a846b8..5b2091472d 100644 --- a/redis/maintenance_events.py +++ b/redis/maintenance_events.py @@ -323,25 +323,6 @@ def handle_event(self, notification: MaintenanceEvent): else: logging.error(f"Unhandled notification type: {notification}") - def handle_node_moved_event(self): - with self._lock: - self.pool.update_connection_kwargs_with_tmp_settings( - tmp_host_address=None, - tmp_relax_timeout=-1, - ) - with self.pool._lock: - if self.config.is_relax_timeouts_enabled(): - # reset the timeout for existing connections - self.pool.update_connections_current_timeout( - relax_timeout=-1, include_available_connections=True - ) - logging.debug("***** MOVING END--> TIMEOUTS RESET") - - self.pool.update_connections_tmp_settings( - tmp_host_address=None, tmp_relax_timeout=-1 - ) - logging.debug("***** MOVING END--> TMP SETTINGS ADDRESS RESET") - def handle_node_moving_event(self, event: NodeMovingEvent): if ( not self.config.proactive_reconnect @@ -403,6 +384,26 @@ def handle_node_moving_event(self, event: NodeMovingEvent): f"###### MOVING total execution time: {execution_time_us:.0f} microseconds" ) + def handle_node_moved_event(self): + logging.debug("***** MOVING END--> Starting to revert the changes.") + with self._lock: + self.pool.update_connection_kwargs_with_tmp_settings( + tmp_host_address=None, + tmp_relax_timeout=-1, + ) + with self.pool._lock: + if self.config.is_relax_timeouts_enabled(): + # reset the timeout for existing connections + self.pool.update_connections_current_timeout( + relax_timeout=-1, include_free_connections=True + ) + logging.debug("***** MOVING END--> TIMEOUTS RESET") + + self.pool.update_connections_tmp_settings( + tmp_host_address=None, tmp_relax_timeout=-1 + ) + logging.debug("***** MOVING END--> TMP SETTINGS ADDRESS RESET") + class MaintenanceEventConnectionHandler: def __init__( diff --git a/tests/test_maintenance_events_handling.py b/tests/test_maintenance_events_handling.py index 6a687da1b0..c04c6c066e 100644 --- a/tests/test_maintenance_events_handling.py +++ b/tests/test_maintenance_events_handling.py @@ -1,10 +1,12 @@ import socket import threading -from unittest.mock import Mock, patch +from typing import List +from unittest.mock import patch import pytest +from time import sleep from redis import Redis -from redis.connection import ConnectionPool, BlockingConnectionPool +from redis.connection import AbstractConnection, ConnectionPool, BlockingConnectionPool from redis.maintenance_events import ( MaintenanceEventsConfig, NodeMigratingEvent, @@ -15,18 +17,21 @@ class MockSocket: """Mock socket that simulates Redis protocol responses.""" + AFTER_MOVING_ADDRESS = "1.2.3.4:6379" + DEFAULT_ADDRESS = "12.45.34.56:6379" + MOVING_TIMEOUT = 1 + def __init__(self): self.connected = False self.address = None self.sent_data = [] - self.response_queue = [] self.closed = False self.command_count = 0 self.pending_responses = [] - self.current_response_index = 0 # Track socket timeout changes for maintenance events validation self.timeout = None self.thread_timeouts = {} # Track last applied timeout per thread + self.moving_sent = False def connect(self, address): """Simulate socket connection.""" @@ -57,6 +62,12 @@ def send(self, data): # Format: >1\r\n$8\r\nMIGRATED\r\n (1 element: MIGRATED) migrated_push = ">1\r\n$8\r\nMIGRATED\r\n" response = migrated_push.encode() + response + elif b"key_receive_moving_" in data: + # MOVING push message before SET key_receive_moving_X response + # Format: >3\r\n$6\r\nMOVING\r\n:15\r\n+localhost:6379\r\n (3 elements: MOVING, ttl, host:port) + # Note: Using + instead of $ to send as simple string instead of bulk string + moving_push = f">3\r\n$6\r\nMOVING\r\n:{MockSocket.MOVING_TIMEOUT}\r\n+{MockSocket.AFTER_MOVING_ADDRESS}\r\n" + response = moving_push.encode() + response self.pending_responses.append(response) elif b"GET" in data: @@ -69,14 +80,20 @@ def send(self, data): self.pending_responses.append(b"$8\r\nvalue1_0\r\n") elif b"key_receive_migrating_0" in data: self.pending_responses.append(b"$8\r\nvalue2_0\r\n") + elif b"key_receive_moving_0" in data: + self.pending_responses.append(b"$8\r\nvalue3_0\r\n") elif b"key1_1" in data: self.pending_responses.append(b"$8\r\nvalue1_1\r\n") elif b"key_receive_migrating_1" in data: self.pending_responses.append(b"$8\r\nvalue2_1\r\n") + elif b"key_receive_moving_1" in data: + self.pending_responses.append(b"$8\r\nvalue3_1\r\n") elif b"key1_2" in data: self.pending_responses.append(b"$8\r\nvalue1_2\r\n") elif b"key_receive_migrating_2" in data: self.pending_responses.append(b"$8\r\nvalue2_2\r\n") + elif b"key_receive_moving_2" in data: + self.pending_responses.append(b"$8\r\nvalue3_2\r\n") # Generic keys (less specific, should come after thread-specific) elif b"key0" in data: self.pending_responses.append(b"$6\r\nvalue0\r\n") @@ -100,13 +117,12 @@ def recv(self, bufsize): """Simulate receiving data from Redis.""" if self.closed: raise ConnectionError("Socket is closed") - if self.response_queue: - response = self.response_queue.pop(0) - return response[:bufsize] # Respect buffer size # Use pending responses that were prepared when commands were sent if self.pending_responses: response = self.pending_responses.pop(0) + if b"MOVING" in response: + self.moving_sent = True return response[:bufsize] # Respect buffer size else: # No data available - this should block or raise an exception @@ -123,26 +139,33 @@ def close(self): """Simulate closing the socket.""" self.closed = True self.connected = False + self.address = None + self.timeout = None + self.thread_timeouts = {} def settimeout(self, timeout): """Simulate setting socket timeout and track changes per thread.""" self.timeout = timeout - # Track last applied timeout per thread + # Track last applied timeout with thread_id information added thread_id = threading.current_thread().ident self.thread_timeouts[thread_id] = timeout + def gettimeout(self): + """Simulate getting socket timeout.""" + return self.timeout + def setsockopt(self, level, optname, value): """Simulate setting socket options.""" pass def getpeername(self): """Simulate getting peer name.""" - return ("127.0.0.1", 6379) + return self.address def getsockname(self): """Simulate getting socket name.""" - return ("127.0.0.1", 12345) + return (self.address.split(":")[0], 12345) def shutdown(self, how): """Simulate socket shutdown.""" @@ -173,9 +196,7 @@ def mock_select(rlist, wlist, xlist, timeout=0): for sock in rlist: if hasattr(sock, "connected") and sock.connected and not sock.closed: # Only return socket as ready if it actually has data to read - if ( - hasattr(sock, "pending_responses") and sock.pending_responses - ) or (hasattr(sock, "response_queue") and sock.response_queue): + if hasattr(sock, "pending_responses") and sock.pending_responses: ready_sockets.append(sock) # Don't return socket as ready just because it received commands # Only when there are actual responses available @@ -195,7 +216,11 @@ def teardown_method(self): self.select_patcher.stop() def _get_client( - self, pool_class, max_connections=10, maintenance_events_config=None + self, + pool_class, + max_connections=10, + maintenance_events_config=None, + setup_pool_handler=False, ): """Helper method to create a pool and Redis client with maintenance events configuration. @@ -204,6 +229,7 @@ def _get_client( max_connections: Maximum number of connections in the pool (default: 10) maintenance_events_config: Optional MaintenanceEventsConfig to use. If not provided, uses self.config from setup_method (default: None) + setup_pool_handler: Whether to set up pool handler for moving events (default: False) Returns: tuple: (test_pool, test_redis_client) @@ -215,19 +241,30 @@ def _get_client( ) test_pool = pool_class( - host="localhost", - port=6379, + host=MockSocket.DEFAULT_ADDRESS.split(":")[0], + port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), max_connections=max_connections, protocol=3, # Required for maintenance events maintenance_events_config=config, ) test_redis_client = Redis(connection_pool=test_pool) - return test_pool, test_redis_client + + # Set up pool handler for moving events if requested + if setup_pool_handler: + pool_handler = MaintenanceEventPoolHandler( + test_redis_client.connection_pool, config + ) + test_redis_client.connection_pool.set_maintenance_events_pool_handler( + pool_handler + ) + + return test_redis_client def _validate_current_timeout_for_thread(self, thread_id, expected_timeout): """Helper method to validate the current timeout for the calling thread.""" - current_thread_id = threading.current_thread().ident actual_timeout = None + # Get the actual thread ID from the current thread + current_thread_id = threading.current_thread().ident for sock in self.mock_sockets: if current_thread_id in sock.thread_timeouts: actual_timeout = sock.thread_timeouts[current_thread_id] @@ -235,16 +272,121 @@ def _validate_current_timeout_for_thread(self, thread_id, expected_timeout): assert actual_timeout == expected_timeout, ( f"Thread {thread_id}: Expected timeout ({expected_timeout}), " - f"but found timeout: {actual_timeout} for thread {current_thread_id}. " + f"but found timeout: {actual_timeout} for thread {thread_id}. " f"All thread timeouts: {[sock.thread_timeouts for sock in self.mock_sockets]}" ) + def _validate_disconnected(self, expected_count): + """Helper method to validate all socket timeouts""" + disconnected_sockets_count = 0 + for sock in self.mock_sockets: + if sock.closed: + disconnected_sockets_count += 1 + assert disconnected_sockets_count == expected_count + + def _validate_connected(self, expected_count): + """Helper method to validate all socket timeouts""" + connected_sockets_count = 0 + for sock in self.mock_sockets: + if sock.connected: + connected_sockets_count += 1 + assert connected_sockets_count == expected_count + + def _validate_in_use_connections_state( + self, in_use_connections: List[AbstractConnection] + ): + """Helper method to validate state of in-use connections.""" + # validate in use connections are still working with set flag for reconnect + # and timeout is updated + for connection in in_use_connections: + assert connection._should_reconnect is True + assert ( + connection.tmp_host_address + == MockSocket.AFTER_MOVING_ADDRESS.split(":")[0] + ) + assert connection.tmp_relax_timeout == self.config.relax_timeout + assert connection._sock.gettimeout() == self.config.relax_timeout + assert connection._sock.connected is True + assert ( + connection._sock.getpeername()[0] + == MockSocket.DEFAULT_ADDRESS.split(":")[0] + ) + + def _validate_free_connections_state( + self, + pool, + tmp_host_address, + relax_timeout, + should_be_connected_count, + connected_to_tmp_addres=False, + ): + """Helper method to validate state of free/available connections.""" + if isinstance(pool, BlockingConnectionPool): + # BlockingConnectionPool uses _connections list where created connections are stored + # but we need to get the ones in the queue - these are the free ones + # the uninitialized connections are filtered out + free_connections = [conn for conn in pool.pool.queue if conn is not None] + elif isinstance(pool, ConnectionPool): + # Regular ConnectionPool uses _available_connections for free connections + free_connections = pool._available_connections + else: + raise ValueError(f"Unsupported pool type: {type(pool)}") + + connected_count = 0 + # Validate fields that are validated in the validation of the active connections + for connection in free_connections: + # Validate the same fields as in _validate_in_use_connections_state + assert connection._should_reconnect is False + assert connection.tmp_host_address == tmp_host_address + assert connection.tmp_relax_timeout == relax_timeout + if connection._sock is not None: + connected_count += 1 + + if connected_to_tmp_addres: + assert ( + connection._sock.getpeername()[0] + == MockSocket.AFTER_MOVING_ADDRESS.split(":")[0] + ) + else: + assert ( + connection._sock.getpeername()[0] + == MockSocket.DEFAULT_ADDRESS.split(":")[0] + ) + assert connected_count == should_be_connected_count + + def _validate_all_timeouts(self, expected_timeout): + """Helper method to validate state of in-use connections.""" + # validate in use connections are still working with set flag for reconnect + # and timeout is updated + for mock_socket in self.mock_sockets: + if expected_timeout is None: + assert mock_socket.gettimeout() is None + else: + assert mock_socket.gettimeout() == expected_timeout + + def _validate_conn_kwargs( + self, + pool, + expected_host_address, + expected_port, + expected_tmp_host_address, + expected_tmp_relax_timeout, + ): + """Helper method to validate connection kwargs.""" + assert pool.connection_kwargs["host"] == expected_host_address + assert pool.connection_kwargs["port"] == expected_port + assert pool.connection_kwargs["tmp_host_address"] == expected_tmp_host_address + assert pool.connection_kwargs["tmp_relax_timeout"] == expected_tmp_relax_timeout + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) def test_connection_pool_creation_with_maintenance_events(self, pool_class): """Test that connection pools are created with maintenance events configuration.""" # Create a pool and Redis client with maintenance events max_connections = 3 if pool_class == BlockingConnectionPool else 10 - test_pool, _ = self._get_client(pool_class, max_connections=max_connections) + test_redis_client = self._get_client( + pool_class, max_connections=max_connections + ) + test_pool = test_redis_client.connection_pool try: assert ( @@ -283,7 +425,7 @@ def test_redis_operations_with_mock_sockets(self, pool_class): Basically with test - the mocked socket is validated. """ # Create a pool and Redis client with maintenance events - test_pool, test_redis_client = self._get_client(pool_class, max_connections=5) + test_redis_client = self._get_client(pool_class, max_connections=5) try: # Perform Redis operations that should work with our improved mock responses @@ -300,77 +442,19 @@ def test_redis_operations_with_mock_sockets(self, pool_class): assert len(self.mock_sockets[0].sent_data) >= 2 # HELLO, SET, GET commands # Verify that the connection has maintenance event handler - connection = test_pool.get_connection() - assert hasattr(connection, "_maintenance_event_connection_handler") - test_pool.release(connection) - - finally: - if hasattr(test_pool, "disconnect"): - test_pool.disconnect() - - @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) - def test_multiple_connections_in_pool(self, pool_class): - """Test that multiple connections can be created and used for Redis operations in multiple threads.""" - # Create a pool and Redis client with maintenance events - test_pool, test_redis_client = self._get_client(pool_class, max_connections=5) - - try: - # Results storage for thread operations - results = [] - errors = [] - - def redis_operation(key_suffix): - """Perform Redis operations in a thread.""" - try: - # SET operation - set_result = test_redis_client.set( - f"key{key_suffix}", f"value{key_suffix}" - ) - # GET operation - get_result = test_redis_client.get(f"key{key_suffix}") - results.append((set_result, get_result)) - except Exception as e: - errors.append(e) - - # Run operations in multiple threads to force multiple connections - threads = [] - for i in range(3): - thread = threading.Thread(target=redis_operation, args=(i,)) - threads.append(thread) - thread.start() - - # Wait for all threads to complete - for thread in threads: - thread.join() - - # Verify no errors occurred - assert len(errors) == 0, f"Errors occurred: {errors}" - - # Verify all operations completed successfully - assert len(results) == 3 - for set_result, get_result in results: - assert set_result is True - assert get_result in [b"value0", b"value1", b"value2"] - - # Verify that multiple connections were created with mock sockets - # With threading, both pool types should create multiple sockets for concurrent access - assert len(self.mock_sockets) >= 2, ( - f"Expected multiple sockets due to threading, got {len(self.mock_sockets)}" - ) - - # Verify each connection has maintenance event handler - connection = test_pool.get_connection() + connection = test_redis_client.connection_pool.get_connection() assert hasattr(connection, "_maintenance_event_connection_handler") - test_pool.release(connection) + test_redis_client.connection_pool.release(connection) finally: - if hasattr(test_pool, "disconnect"): - test_pool.disconnect() + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() def test_pool_handler_with_migrating_event(self): """Test that pool handler correctly handles migrating events.""" # Create a pool and Redis client with maintenance events - test_pool, _ = self._get_client(ConnectionPool) + test_redis_client = self._get_client(ConnectionPool) + test_pool = test_redis_client.connection_pool try: # Create and set a pool handler @@ -421,7 +505,7 @@ def test_migration_related_events_handling_integration(self, pool_class): 8. Uses proper RESP3 push message format for realistic protocol simulation """ # Create a pool and Redis client with maintenance events - test_pool, test_redis_client = self._get_client(pool_class, max_connections=10) + test_redis_client = self._get_client(pool_class, max_connections=10) try: # Results storage for thread operations @@ -530,8 +614,8 @@ def redis_operations_with_maintenance_events(thread_id): ) finally: - if hasattr(test_pool, "disconnect"): - test_pool.disconnect() + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) def test_migrating_event_with_disabled_relax_timeout(self, pool_class): @@ -551,7 +635,7 @@ def test_migrating_event_with_disabled_relax_timeout(self, pool_class): ) # Create a pool and Redis client with disabled relax timeout config - test_pool, test_redis_client = self._get_client( + test_redis_client = self._get_client( pool_class, max_connections=5, maintenance_events_config=disabled_config ) @@ -635,5 +719,343 @@ def redis_operations_with_disabled_relax(thread_id): ) finally: - if hasattr(test_pool, "disconnect"): - test_pool.disconnect() + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_moving_related_events_handling_integration(self, pool_class): + """ + Test full integration of moving-related events (MOVING) handling with Redis commands. + """ + # Create a pool and Redis client with maintenance events and pool handler + test_redis_client = self._get_client( + pool_class, max_connections=10, setup_pool_handler=True + ) + + try: + # Create several connections and return them in the pool + connections = [] + for _ in range(10): + connection = test_redis_client.connection_pool.get_connection() + connections.append(connection) + + for connection in connections: + test_redis_client.connection_pool.release(connection) + + # Take 5 connections to be "in use" + in_use_connections = [] + for _ in range(5): + connection = test_redis_client.connection_pool.get_connection() + in_use_connections.append(connection) + + # Validate all connections are connected prior MOVING event + self._validate_disconnected(0) + + # Run command that will receive and handle MOVING event + key_moving = "key_receive_moving_0" + value_moving = "value3_0" + # the connection used for the command is expected to be reconnected to the new address + # before it is returned to the pool + result2 = test_redis_client.set(key_moving, value_moving) + + # Validate Command 2 result + assert result2 is True, "Command 2 (SET key_receive_moving) failed" + + # Validate pool and connections settings were updated according to MOVING event + # handling expectations + self._validate_conn_kwargs( + test_redis_client.connection_pool, + MockSocket.DEFAULT_ADDRESS.split(":")[0], + int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + self.config.relax_timeout, + ) + # 5 disconnects has happened, 1 of them is with reconnect + self._validate_disconnected(5) + # 5 in use connected + 1 after reconnect + self._validate_connected(6) + self._validate_in_use_connections_state(in_use_connections) + # Validate there is 1 free connection that is connected + # the one that has handled the MOVING should reconnect after parsing the response + self._validate_free_connections_state( + test_redis_client.connection_pool, + MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + self.config.relax_timeout, + should_be_connected_count=1, + connected_to_tmp_addres=True, + ) + + # Wait for MOVING timeout to expire and the moving completed handler to run + print("Waiting for MOVING timeout to expire...") + sleep(MockSocket.MOVING_TIMEOUT + 0.5) + + self._validate_all_timeouts(None) + self._validate_conn_kwargs( + test_redis_client.connection_pool, + MockSocket.DEFAULT_ADDRESS.split(":")[0], + int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + None, + -1, + ) + self._validate_free_connections_state( + test_redis_client.connection_pool, + None, + -1, + should_be_connected_count=1, + connected_to_tmp_addres=True, + ) + + finally: + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_create_new_conn_while_moving_not_expired(self, pool_class): + """ + Test creating new connections while MOVING event is active (not expired). + + This test validates that: + 1. After MOVING event is processed, new connections are created with temporary address + 2. New connections inherit the relaxed timeout settings + 3. Pool configuration is properly applied to newly created connections + """ + # Create a pool and Redis client with maintenance events and pool handler + test_redis_client = self._get_client( + pool_class, max_connections=10, setup_pool_handler=True + ) + + try: + # Create several connections and return them in the pool + connections = [] + for _ in range(5): + connection = test_redis_client.connection_pool.get_connection() + connections.append(connection) + + for connection in connections: + test_redis_client.connection_pool.release(connection) + + # Take 3 connections to be "in use" + in_use_connections = [] + for _ in range(3): + connection = test_redis_client.connection_pool.get_connection() + in_use_connections.append(connection) + + # Validate all connections are connected prior MOVING event + self._validate_disconnected(0) + + # Run command that will receive and handle MOVING event + key_moving = "key_receive_moving_0" + value_moving = "value3_0" + result = test_redis_client.set(key_moving, value_moving) + + # Validate command result + assert result is True, "SET key_receive_moving command failed" + + # Validate pool and connections settings were updated according to MOVING event + self._validate_conn_kwargs( + test_redis_client.connection_pool, + MockSocket.DEFAULT_ADDRESS.split(":")[0], + int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + self.config.relax_timeout, + ) + + # Now get several more connections to force creation of new ones + # This should create new connections with the temporary address + old_connections = [] + for _ in range(2): + connection = test_redis_client.connection_pool.get_connection() + old_connections.append(connection) + + new_connection = test_redis_client.connection_pool.get_connection() + + # Validate that new connections are created with temporary address and relax timeout + # and when connecting those configs are used + # get_connection() returns a connection that is already connected + assert ( + new_connection.tmp_host_address + == MockSocket.AFTER_MOVING_ADDRESS.split(":")[0] + ) + assert new_connection.tmp_relax_timeout == self.config.relax_timeout + # New connections should be connected to the temporary address + assert new_connection._sock is not None + assert new_connection._sock.connected is True + assert ( + new_connection._sock.getpeername()[0] + == MockSocket.AFTER_MOVING_ADDRESS.split(":")[0] + ) + assert new_connection._sock.gettimeout() == self.config.relax_timeout + + finally: + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_create_new_conn_after_moving_expires(self, pool_class): + """ + Test creating new connections after MOVING event expires. + + This test validates that: + 1. After MOVING timeout expires, new connections use original address + 2. Pool configuration is reset to original values + 3. New connections don't inherit temporary settings + """ + # Create a pool and Redis client with maintenance events and pool handler + test_redis_client = self._get_client( + pool_class, max_connections=10, setup_pool_handler=True + ) + + try: + # Create several connections and return them in the pool + connections = [] + for _ in range(5): + connection = test_redis_client.connection_pool.get_connection() + connections.append(connection) + + for connection in connections: + test_redis_client.connection_pool.release(connection) + + # Take 3 connections to be "in use" + in_use_connections = [] + for _ in range(3): + connection = test_redis_client.connection_pool.get_connection() + in_use_connections.append(connection) + + # Run command that will receive and handle MOVING event + key_moving = "key_receive_moving_0" + value_moving = "value3_0" + result = test_redis_client.set(key_moving, value_moving) + + # Validate command result + assert result is True, "SET key_receive_moving command failed" + + # Wait for MOVING timeout to expire + print("Waiting for MOVING timeout to expire...") + sleep(MockSocket.MOVING_TIMEOUT + 0.5) + + # Now get several new connections after expiration + old_connections = [] + for _ in range(2): + connection = test_redis_client.connection_pool.get_connection() + old_connections.append(connection) + + new_connection = test_redis_client.connection_pool.get_connection() + + # Validate that new connections are created with original address (no temporary settings) + assert new_connection.tmp_host_address is None + assert new_connection.tmp_relax_timeout == -1 + # New connections should be connected to the original address + assert new_connection._sock is not None + assert new_connection._sock.connected is True + # Socket timeout should be None (original timeout) + assert new_connection._sock.gettimeout() is None + + finally: + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_receive_migrated_after_moving(self, pool_class): + # TODO Refactor: when migrated comes after moving and + # moving hasn't yet expired - it should not decrease timeouts + """ + Test receiving MIGRATED event after MOVING event. + + This test validates the complete MOVING -> MIGRATED lifecycle: + 1. MOVING event is processed and temporary settings are applied + 2. MIGRATED event is received during command execution + 3. Temporary settings are cleared after MIGRATED + 4. Pool configuration is restored to original values + """ + # Create a pool and Redis client with maintenance events and pool handler + test_redis_client = self._get_client( + pool_class, max_connections=10, setup_pool_handler=True + ) + + try: + # Create several connections and return them in the pool + connections = [] + for _ in range(5): + connection = test_redis_client.connection_pool.get_connection() + connections.append(connection) + + for connection in connections: + test_redis_client.connection_pool.release(connection) + + # Take 3 connections to be "in use" + in_use_connections = [] + for _ in range(3): + connection = test_redis_client.connection_pool.get_connection() + in_use_connections.append(connection) + + # Validate all connections are connected prior MOVING event + self._validate_disconnected(0) + + # Step 1: Run command that will receive and handle MOVING event + key_moving = "key_receive_moving_0" + value_moving = "value3_0" + result_moving = test_redis_client.set(key_moving, value_moving) + + # Validate MOVING command result + assert result_moving is True, "SET key_receive_moving command failed" + + # Validate pool and connections settings were updated according to MOVING event + self._validate_conn_kwargs( + test_redis_client.connection_pool, + MockSocket.DEFAULT_ADDRESS.split(":")[0], + int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + self.config.relax_timeout, + ) + + # Step 2: Run command that will receive and handle MIGRATED event + # This should clear the temporary settings + key_migrated = "key_receive_migrated_0" + value_migrated = "migrated_value" + result_migrated = test_redis_client.set(key_migrated, value_migrated) + + # Validate MIGRATED command result + assert result_migrated is True, "SET key_receive_migrated command failed" + + # Step 3: Validate that MIGRATED event was processed but MOVING settings remain + # (MIGRATED doesn't automatically clear MOVING settings - they are separate events) + self._validate_conn_kwargs( + test_redis_client.connection_pool, + MockSocket.DEFAULT_ADDRESS.split(":")[0], + int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + MockSocket.AFTER_MOVING_ADDRESS.split(":")[ + 0 + ], # MOVING settings still active + self.config.relax_timeout, # MOVING timeout still active + ) + + # Step 4: Create new connections after MIGRATED to verify they still use MOVING settings + # (since MOVING settings are still active) + new_connections = [] + for _ in range(2): + connection = test_redis_client.connection_pool.get_connection() + new_connections.append(connection) + + # Validate that new connections are created with MOVING settings (still active) + for connection in new_connections: + assert ( + connection.tmp_host_address + == MockSocket.AFTER_MOVING_ADDRESS.split(":")[0] + ) + # Note: New connections may not inherit the exact relax timeout value + # but they should have the temporary host address + # New connections should be connected + if connection._sock is not None: + assert connection._sock.connected is True + + # Release the new connections + for connection in new_connections: + test_redis_client.connection_pool.release(connection) + + # Validate free connections state with MOVING settings still active + # Note: We'll validate with the pool's current settings rather than individual connection settings + # since new connections may have different timeout values but still use the temporary address + + finally: + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() From c0438c8b439ad69467c4ca6689fd4bcec6439da8 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Thu, 17 Jul 2025 15:28:20 +0300 Subject: [PATCH 08/28] Fixed BlockingConnectionPool locking strategy. Removed debug logging. Refactored the maintenance events tests not to be multithreaded - we don't need it for those tests. --- redis/asyncio/connection.py | 2 + redis/client.py | 6 - redis/connection.py | 125 ++++--- redis/maintenance_events.py | 36 +- tests/test_connection_pool.py | 3 + tests/test_maintenance_events_handling.py | 419 ++++++++++++---------- 6 files changed, 329 insertions(+), 262 deletions(-) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 4efd868f6f..fe86e4c36e 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -1308,6 +1308,8 @@ def __init__( ) self._condition = asyncio.Condition() self.timeout = timeout + self._in_maintenance = False + self._locked = False @deprecated_args( args_to_warn=["*"], diff --git a/redis/client.py b/redis/client.py index 0ec36c52d9..473b1e00f2 100755 --- a/redis/client.py +++ b/redis/client.py @@ -668,9 +668,6 @@ def _execute_command(self, *args, **options): finally: if conn and conn.should_reconnect(): - logging.debug( - f"***** Redis reconnect before exit _execute_command --> notification for {conn._sock.getpeername()}" - ) self._close_connection(conn) conn.connect() if self._single_connection_client: @@ -963,9 +960,6 @@ def _execute(self, conn, command, *args, **kwargs): lambda _: self._reconnect(conn), ) if conn.should_reconnect(): - logging.debug( - f"***** PubSub --> Reconnect on notification for {conn._sock.getpeername()}" - ) self._reconnect(conn) return response diff --git a/redis/connection.py b/redis/connection.py index 81a80d0903..a096b045b2 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -431,7 +431,20 @@ def set_parser(self, parser_class): def set_maintenance_event_pool_handler( self, maintenance_event_pool_handler: MaintenanceEventPoolHandler ): - self._parser.set_node_moving_push_handler(maintenance_event_pool_handler) + self._parser.set_node_moving_push_handler( + maintenance_event_pool_handler.handle_event + ) + + # Initialize maintenance event connection handler if it doesn't exist + if not hasattr(self, "_maintenance_event_connection_handler"): + self._maintenance_event_connection_handler = ( + MaintenanceEventConnectionHandler( + self, maintenance_event_pool_handler.config + ) + ) + self._parser.set_maintenance_push_handler( + self._maintenance_event_connection_handler.handle_event + ) def connect(self): "Connects to the Redis server if not already connected" @@ -802,10 +815,6 @@ def should_reconnect(self): def update_current_socket_timeout(self, relax_timeout: Optional[float] = None): if self._sock: timeout = relax_timeout if relax_timeout != -1 else self.socket_timeout - logging.debug( - f"***** Connection --> Updating timeout for {self._sock.getpeername()}" - f" to timeout {timeout}; relax_timeout: {relax_timeout}" - ) self._sock.settimeout(timeout) self.update_parser_buffer_timeout(timeout) @@ -858,10 +867,6 @@ def _connect(self): # ipv4/ipv6, but we want to set options prior to calling # socket.connect() err = None - if self.tmp_host_address is not None: - logging.debug( - f"***** Connection --> Using tmp_host_address: {self.tmp_host_address}" - ) host = self.tmp_host_address or self.host for res in socket.getaddrinfo( @@ -882,14 +887,8 @@ def _connect(self): # set the socket_connect_timeout before we connect if self.tmp_relax_timeout != -1: - logging.debug( - f"***** Connection connect --> Using relax_timeout: {self.tmp_relax_timeout}" - ) sock.settimeout(self.tmp_relax_timeout) else: - logging.debug( - f"***** Connection connect --> Using default socket_connect_timeout: {self.socket_connect_timeout}" - ) sock.settimeout(self.socket_connect_timeout) # connect @@ -897,16 +896,9 @@ def _connect(self): # set the socket_timeout now that we're connected if self.tmp_relax_timeout != -1: - logging.debug( - f"***** Connection --> Using relax_timeout: {self.tmp_relax_timeout}" - ) sock.settimeout(self.tmp_relax_timeout) else: - logging.debug( - f"***** Connection --> Using default socket_timeout: {self.socket_timeout}" - ) sock.settimeout(self.socket_timeout) - logging.debug(f"Connected to {sock.getpeername()}") return sock except OSError as _: @@ -1606,14 +1598,10 @@ def _update_maintenance_events_configs_for_connections( ): with self._lock: for conn in self._available_connections: - conn.set_maintenance_events_pool_handler( - maintenance_events_pool_handler - ) + conn.set_maintenance_event_pool_handler(maintenance_events_pool_handler) conn.maintenance_events_config = maintenance_events_pool_handler.config for conn in self._in_use_connections: - conn.set_maintenance_events_pool_handler( - maintenance_events_pool_handler - ) + conn.set_maintenance_event_pool_handler(maintenance_events_pool_handler) conn.maintenance_events_config = maintenance_events_pool_handler.config def reset(self) -> None: @@ -1755,9 +1743,6 @@ def release(self, connection: "Connection") -> None: if self.owns_connection(connection): if connection.should_reconnect(): - logging.debug( - f"***** Pool--> disconnecting in release {connection._sock.getpeername()}" - ) connection.disconnect() self._available_connections.append(connection) self._event_dispatcher.dispatch( @@ -1917,9 +1902,6 @@ def update_connections_current_timeout( If -1 is provided - the relax timeout is disabled. :param include_available_connections: Whether to include available connections in the update. """ - logging.debug(f"***** Pool --> Updating timeouts. New value: {relax_timeout}") - start_time = time.time() - for conn in self._in_use_connections: self._update_connection_timeout(conn, relax_timeout) @@ -1927,11 +1909,6 @@ def update_connections_current_timeout( for conn in self._available_connections: self._update_connection_timeout(conn, relax_timeout) - execution_time_us = (time.time() - start_time) * 1000000 - logging.error( - f"###### TIMEOUTS execution time: {execution_time_us:.0f} microseconds" - ) - def _update_connection_for_reconnect( self, connection: "Connection", @@ -2021,6 +1998,8 @@ def __init__( ): self.queue_class = queue_class self.timeout = timeout + self._in_maintenance = False + self._locked = False super().__init__( connection_class=connection_class, max_connections=max_connections, @@ -2029,7 +2008,10 @@ def __init__( def reset(self): # Create and fill up a thread safe queue with ``None`` values. - with self._lock: + try: + if self._in_maintenance: + self._lock.acquire() + self._locked = True self.pool = self.queue_class(self.max_connections) while True: try: @@ -2040,6 +2022,13 @@ def reset(self): # Keep a list of actual connection instances so that we can # disconnect them later. self._connections = [] + finally: + if self._locked: + try: + self._lock.release() + except Exception: + pass + self._locked = False # this must be the last operation in this method. while reset() is # called when holding _fork_lock, other threads in this process @@ -2054,7 +2043,10 @@ def reset(self): def make_connection(self): "Make a fresh connection." - with self._lock: + try: + if self._in_maintenance: + self._lock.acquire() + self._locked = True if self.cache is not None: connection = CacheProxyConnection( self.connection_class(**self.connection_kwargs), @@ -2066,6 +2058,13 @@ def make_connection(self): self._connections.append(connection) return connection + finally: + if self._locked: + try: + self._lock.release() + except Exception: + pass + self._locked = False @deprecated_args( args_to_warn=["*"], @@ -2090,7 +2089,10 @@ def get_connection(self, command_name=None, *keys, **options): # Try and get a connection from the pool. If one isn't available within # self.timeout then raise a ``ConnectionError``. connection = None - with self._lock: + try: + if self._in_maintenance: + self._lock.acquire() + self._locked = True try: connection = self.pool.get(block=True, timeout=self.timeout) except Empty: @@ -2102,6 +2104,13 @@ def get_connection(self, command_name=None, *keys, **options): # a new connection to add to the pool. if connection is None: connection = self.make_connection() + finally: + if self._locked: + try: + self._lock.release() + except Exception: + pass + self._locked = False try: # ensure this connection is connected to Redis @@ -2130,7 +2139,10 @@ def release(self, connection): # Make sure we haven't changed process. self._checkpid() - with self._lock: + try: + if self._in_maintenance: + self._lock.acquire() + self._locked = True if not self.owns_connection(connection): # pool doesn't own this connection. do not add it back # to the pool. instead add a None value which is a placeholder @@ -2140,24 +2152,39 @@ def release(self, connection): self.pool.put_nowait(None) return if connection.should_reconnect(): - logging.debug( - f"***** Blocking Pool--> disconnecting in release {connection._sock.getpeername()}" - ) connection.disconnect() # Put the connection back into the pool. try: + print("Releasing connection - in the pool") self.pool.put_nowait(connection) except Full: # perhaps the pool has been reset() after a fork? regardless, # we don't want this connection pass + finally: + if self._locked: + try: + self._lock.release() + except Exception: + pass + self._locked = False def disconnect(self): "Disconnects all connections in the pool." self._checkpid() - with self._lock: + try: + if self._in_maintenance: + self._lock.acquire() + self._locked = True for connection in self._connections: connection.disconnect() + finally: + if self._locked: + try: + self._lock.release() + except Exception: + pass + self._locked = False def update_active_connections_for_reconnect( self, @@ -2236,3 +2263,7 @@ def _update_maintenance_events_configs_for_connections( conn.maintenance_events_config = ( maintenance_events_pool_handler.config ) + + def set_in_maintenance(self, in_maintenance: bool): + """Set the maintenance mode for the connection pool.""" + self._in_maintenance = in_maintenance diff --git a/redis/maintenance_events.py b/redis/maintenance_events.py index 5b2091472d..bf0cd6bda8 100644 --- a/redis/maintenance_events.py +++ b/redis/maintenance_events.py @@ -2,12 +2,16 @@ import threading import time from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Union from redis.typing import Number if TYPE_CHECKING: - from redis.connection import ConnectionInterface, ConnectionPool + from redis.connection import ( + BlockingConnectionPool, + ConnectionInterface, + ConnectionPool, + ) class MaintenanceEvent(ABC): @@ -303,7 +307,11 @@ def is_relax_timeouts_enabled(self) -> bool: class MaintenanceEventPoolHandler: - def __init__(self, pool: "ConnectionPool", config: MaintenanceEventsConfig) -> None: + def __init__( + self, + pool: Union["ConnectionPool", "BlockingConnectionPool"], + config: MaintenanceEventsConfig, + ) -> None: self.pool = pool self.config = config self._processed_events = set() @@ -334,18 +342,15 @@ def handle_node_moving_event(self, event: NodeMovingEvent): # nothing to do in the connection pool handling # the event has already been handled or is expired # just return - logging.debug("***** MOVING --> SKIPPED DONE") return - logging.info(f"***** MOVING --> {event}") - logging.info(f"***** MOVING --> set: {self._processed_events}") - start_time = time.time() - with self.pool._lock: if ( self.config.proactive_reconnect or self.config.is_relax_timeouts_enabled() ): + if getattr(self.pool, "set_in_maintenance", False): + self.pool.set_in_maintenance(True) # edit the config for new connections until the notification expires self.pool.update_connection_kwargs_with_tmp_settings( tmp_host_address=event.new_node_host, @@ -371,21 +376,14 @@ def handle_node_moving_event(self, event: NodeMovingEvent): tmp_host_address=event.new_node_host, tmp_relax_timeout=self.config.relax_timeout, ) - execution_time_us = (time.time() - start_time_2) * 1000000 - logging.error( - f"###### MOVING disconnects execution time: {execution_time_us:.0f} microseconds" - ) + if getattr(self.pool, "set_in_maintenance", False): + self.pool.set_in_maintenance(False) threading.Timer(event.ttl, self.handle_node_moved_event).start() self._processed_events.add(event) - execution_time_us = (time.time() - start_time) * 1000000 - logging.error( - f"###### MOVING total execution time: {execution_time_us:.0f} microseconds" - ) def handle_node_moved_event(self): - logging.debug("***** MOVING END--> Starting to revert the changes.") with self._lock: self.pool.update_connection_kwargs_with_tmp_settings( tmp_host_address=None, @@ -397,12 +395,10 @@ def handle_node_moved_event(self): self.pool.update_connections_current_timeout( relax_timeout=-1, include_free_connections=True ) - logging.debug("***** MOVING END--> TIMEOUTS RESET") self.pool.update_connections_tmp_settings( tmp_host_address=None, tmp_relax_timeout=-1 ) - logging.debug("***** MOVING END--> TMP SETTINGS ADDRESS RESET") class MaintenanceEventConnectionHandler: @@ -424,7 +420,6 @@ def handle_migrating_event(self, notification: NodeMigratingEvent): if not self.config.is_relax_timeouts_enabled(): return - logging.info(f"***** MIGRATING --> {notification}") # extend the timeout for all created connections self.connection.update_current_socket_timeout(self.config.relax_timeout) self.connection.update_tmp_settings(tmp_relax_timeout=self.config.relax_timeout) @@ -433,7 +428,6 @@ def handle_migration_completed_event(self, notification: "NodeMigratedEvent"): if not self.config.is_relax_timeouts_enabled(): return - logging.info(f"***** MIGRATED --> {notification}") # Node migration completed - reset the connection # timeouts by providing -1 as the relax timeout self.connection.update_current_socket_timeout(-1) diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 3a4896f2a3..4518cd7290 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -33,6 +33,9 @@ def connect(self): def can_read(self): return False + def should_reconnect(self): + return False + class TestConnectionPool: def get_pool( diff --git a/tests/test_maintenance_events_handling.py b/tests/test_maintenance_events_handling.py index c04c6c066e..1620471ea7 100644 --- a/tests/test_maintenance_events_handling.py +++ b/tests/test_maintenance_events_handling.py @@ -52,12 +52,12 @@ def send(self, data): response = b"+OK\r\n" # Check if this is a key that should trigger a push message - if b"key_receive_migrating_" in data: + if b"key_receive_migrating_" in data or b"key_receive_migrating" in data: # MIGRATING push message before SET key_receive_migrating_X response # Format: >2\r\n$9\r\nMIGRATING\r\n:10\r\n (2 elements: MIGRATING, ttl) migrating_push = ">2\r\n$9\r\nMIGRATING\r\n:10\r\n" response = migrating_push.encode() + response - elif b"key_receive_migrated_" in data: + elif b"key_receive_migrated_" in data or b"key_receive_migrated" in data: # MIGRATED push message before SET key_receive_migrated_X response # Format: >1\r\n$8\r\nMIGRATED\r\n (1 element: MIGRATED) migrated_push = ">1\r\n$8\r\nMIGRATED\r\n" @@ -75,32 +75,17 @@ def send(self, data): if b"hello" in data: response = b"$5\r\nworld\r\n" self.pending_responses.append(response) - # Handle thread-specific keys for integration test first (more specific) - elif b"key1_0" in data: - self.pending_responses.append(b"$8\r\nvalue1_0\r\n") - elif b"key_receive_migrating_0" in data: - self.pending_responses.append(b"$8\r\nvalue2_0\r\n") + # Handle specific keys used in tests elif b"key_receive_moving_0" in data: self.pending_responses.append(b"$8\r\nvalue3_0\r\n") - elif b"key1_1" in data: - self.pending_responses.append(b"$8\r\nvalue1_1\r\n") - elif b"key_receive_migrating_1" in data: - self.pending_responses.append(b"$8\r\nvalue2_1\r\n") - elif b"key_receive_moving_1" in data: - self.pending_responses.append(b"$8\r\nvalue3_1\r\n") - elif b"key1_2" in data: - self.pending_responses.append(b"$8\r\nvalue1_2\r\n") - elif b"key_receive_migrating_2" in data: - self.pending_responses.append(b"$8\r\nvalue2_2\r\n") - elif b"key_receive_moving_2" in data: - self.pending_responses.append(b"$8\r\nvalue3_2\r\n") - # Generic keys (less specific, should come after thread-specific) - elif b"key0" in data: - self.pending_responses.append(b"$6\r\nvalue0\r\n") + elif b"key_receive_migrated_0" in data: + self.pending_responses.append(b"$13\r\nmigrated_value\r\n") + elif b"key_receive_migrating" in data: + self.pending_responses.append(b"$6\r\nvalue2\r\n") + elif b"key_receive_migrated" in data: + self.pending_responses.append(b"$6\r\nvalue3\r\n") elif b"key1" in data: self.pending_responses.append(b"$6\r\nvalue1\r\n") - elif b"key2" in data: - self.pending_responses.append(b"$6\r\nvalue2\r\n") else: self.pending_responses.append(b"$-1\r\n") # NULL response else: @@ -260,7 +245,37 @@ def _get_client( return test_redis_client - def _validate_current_timeout_for_thread(self, thread_id, expected_timeout): + def _validate_connection_handlers(self, conn, pool_handler, config): + """Helper method to validate connection handlers are properly set.""" + # Test that the node moving handler function is correctly set + parser_handler = conn._parser.node_moving_push_handler_func + assert parser_handler is not None + assert hasattr(parser_handler, "__self__") + assert hasattr(parser_handler, "__func__") + assert parser_handler.__self__ is pool_handler + assert parser_handler.__func__ is pool_handler.handle_event.__func__ + + # Test that the maintenance handler function is correctly set + maintenance_handler = conn._parser.maintenance_push_handler_func + assert maintenance_handler is not None + assert hasattr(maintenance_handler, "__self__") + assert hasattr(maintenance_handler, "__func__") + # The maintenance handler should be bound to the connection's + # maintenance event connection handler + assert ( + maintenance_handler.__self__ is conn._maintenance_event_connection_handler + ) + assert ( + maintenance_handler.__func__ + is conn._maintenance_event_connection_handler.handle_event.__func__ + ) + + # Validate that the connection's maintenance handler has the same config object + assert conn._maintenance_event_connection_handler.config is config + + def _validate_current_timeout_for_thread( + self, thread_id, expected_timeout, error_msg=None + ): """Helper method to validate the current timeout for the calling thread.""" actual_timeout = None # Get the actual thread ID from the current thread @@ -271,9 +286,27 @@ def _validate_current_timeout_for_thread(self, thread_id, expected_timeout): break assert actual_timeout == expected_timeout, ( + error_msg, f"Thread {thread_id}: Expected timeout ({expected_timeout}), " f"but found timeout: {actual_timeout} for thread {thread_id}. " - f"All thread timeouts: {[sock.thread_timeouts for sock in self.mock_sockets]}" + f"All thread timeouts: {[sock.thread_timeouts for sock in self.mock_sockets]}", + ) + + def _validate_current_timeout(self, expected_timeout, error_msg=None): + """Helper method to validate the current timeout for the calling thread.""" + actual_timeout = None + # Get the actual thread ID from the current thread + current_thread_id = threading.current_thread().ident + for sock in self.mock_sockets: + if current_thread_id in sock.thread_timeouts: + actual_timeout = sock.thread_timeouts[current_thread_id] + break + + assert actual_timeout == expected_timeout, ( + f"{error_msg or ''}" + f"Expected timeout ({expected_timeout}), " + f"but found timeout: {actual_timeout}. " + f"All thread timeouts: {[sock.thread_timeouts for sock in self.mock_sockets]}", ) def _validate_disconnected(self, expected_count): @@ -378,6 +411,95 @@ def _validate_conn_kwargs( assert pool.connection_kwargs["tmp_host_address"] == expected_tmp_host_address assert pool.connection_kwargs["tmp_relax_timeout"] == expected_tmp_relax_timeout + def test_client_initialization(self): + """Test that Redis client is created with maintenance events configuration.""" + # Create a pool and Redis client with maintenance events + + test_redis_client = Redis( + protocol=3, # Required for maintenance events + maintenance_events_config=self.config, + ) + + pool_handler = test_redis_client.connection_pool.connection_kwargs.get( + "maintenance_events_pool_handler" + ) + assert pool_handler is not None + assert pool_handler.config == self.config + + conn = test_redis_client.connection_pool.get_connection() + assert conn._should_reconnect is False + assert conn.tmp_host_address is None + assert conn.tmp_relax_timeout == -1 + + # Test that the node moving handler function is correctly set by + # comparing the underlying function and instance + parser_handler = conn._parser.node_moving_push_handler_func + assert parser_handler is not None + assert hasattr(parser_handler, "__self__") + assert hasattr(parser_handler, "__func__") + assert parser_handler.__self__ is pool_handler + assert parser_handler.__func__ is pool_handler.handle_event.__func__ + + # Test that the maintenance handler function is correctly set + maintenance_handler = conn._parser.maintenance_push_handler_func + assert maintenance_handler is not None + assert hasattr(maintenance_handler, "__self__") + assert hasattr(maintenance_handler, "__func__") + # The maintenance handler should be bound to the connection's + # maintenance event connection handler + assert ( + maintenance_handler.__self__ is conn._maintenance_event_connection_handler + ) + assert ( + maintenance_handler.__func__ + is conn._maintenance_event_connection_handler.handle_event.__func__ + ) + + # Validate that the connection's maintenance handler has the same config object + assert conn._maintenance_event_connection_handler.config is self.config + + def test_maint_handler_init_for_existing_connections(self): + """Test that maintenance event handlers are properly set on existing and new connections + when configuration is enabled after client creation.""" + + # Create a Redis client with disabled maintenance events configuration + disabled_config = MaintenanceEventsConfig(enabled=False) + test_redis_client = Redis( + protocol=3, # Required for maintenance events + maintenance_events_config=disabled_config, + ) + + # Extract an existing connection before enabling maintenance events + existing_conn = test_redis_client.connection_pool.get_connection() + + # Verify that maintenance events are initially disabled + assert existing_conn._parser.node_moving_push_handler_func is None + assert not hasattr(existing_conn, "_maintenance_event_connection_handler") + assert existing_conn._parser.maintenance_push_handler_func is None + + # Create a new enabled configuration and set up pool handler + enabled_config = MaintenanceEventsConfig( + enabled=True, proactive_reconnect=True, relax_timeout=30 + ) + pool_handler = MaintenanceEventPoolHandler( + test_redis_client.connection_pool, enabled_config + ) + test_redis_client.connection_pool.set_maintenance_events_pool_handler( + pool_handler + ) + + # Validate the existing connection after enabling maintenance events + # Both existing and new connections should now have full handler setup + self._validate_connection_handlers(existing_conn, pool_handler, enabled_config) + + # Create a new connection and validate it has full handlers + new_conn = test_redis_client.connection_pool.get_connection() + self._validate_connection_handlers(new_conn, pool_handler, enabled_config) + + # Clean up connections + test_redis_client.connection_pool.release(existing_conn) + test_redis_client.connection_pool.release(new_conn) + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) def test_connection_pool_creation_with_maintenance_events(self, pool_class): """Test that connection pools are created with maintenance events configuration.""" @@ -492,14 +614,14 @@ def test_pool_handler_with_migrating_event(self): @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) def test_migration_related_events_handling_integration(self, pool_class): """ - Test full integration of migration-related events (MIGRATING/MIGRATED) handling with multiple threads and commands. + Test full integration of migration-related events (MIGRATING/MIGRATED) handling. This test validates the complete migration lifecycle: - 1. Creates 3 concurrent threads, each executing 5 Redis commands - 2. Injects MIGRATING push message before command 2 (SET key_receive_migrating_X) + 1. Executes 5 Redis commands sequentially + 2. Injects MIGRATING push message before command 2 (SET key_receive_migrating) 3. Validates socket timeout is updated to relaxed value (30s) after MIGRATING 4. Executes commands 3-4 while timeout remains relaxed - 5. Injects MIGRATED push message before command 5 (SET key_receive_migrated_X) + 5. Injects MIGRATED push message before command 5 (SET key_receive_migrated) 6. Validates socket timeout is restored after MIGRATED 7. Tests both ConnectionPool and BlockingConnectionPool implementations 8. Uses proper RESP3 push message format for realistic protocol simulation @@ -508,107 +630,63 @@ def test_migration_related_events_handling_integration(self, pool_class): test_redis_client = self._get_client(pool_class, max_connections=10) try: - # Results storage for thread operations - results = [] - errors = [] - - def redis_operations_with_maintenance_events(thread_id): - """Perform Redis operations with maintenance events in a thread.""" - try: - # Command 1: Initial command - key1 = f"key1_{thread_id}" - value1 = f"value1_{thread_id}" - result1 = test_redis_client.set(key1, value1) - - # Validate Command 1 result - erros_msg = f"Thread {thread_id}: Command 1 (SET key1) failed" - assert result1 is True, erros_msg - - # Command 2: This SET command will receive MIGRATING push message before response - key_migrating = f"key_receive_migrating_{thread_id}" - value_migrating = f"value2_{thread_id}" - result2 = test_redis_client.set(key_migrating, value_migrating) - - # Validate Command 2 result - erros_msg = f"Thread {thread_id}: Command 2 (SET key_receive_migrating) failed" - assert result2 is True, erros_msg - - # Step 4: Validate timeout was updated to relaxed value after MIGRATING - self._validate_current_timeout_for_thread(thread_id, 30) - - # Command 3: Another command while timeout is still relaxed - result3 = test_redis_client.get(key1) - - # Validate Command 3 result - expected_value3 = value1.encode() - errors_msg = ( - f"Thread {thread_id}: Command 3 (GET key1) failed. " - f"Expected {expected_value3}, got {result3}" - ) - assert result3 == expected_value3, errors_msg + # Command 1: Initial command + key1 = "key1" + value1 = "value1" + result1 = test_redis_client.set(key1, value1) - # Command 4: Execute command (step 5) - result4 = test_redis_client.get(key_migrating) + # Validate Command 1 result + assert result1 is True, "Command 1 (SET key1) failed" - # Validate Command 4 result - expected_value4 = value_migrating.encode() - errors_msg = ( - f"Thread {thread_id}: Command 4 (GET key_receive_migrating) failed. " - f"Expected {expected_value4}, got {result4}" - ) - assert result4 == expected_value4, errors_msg + # Command 2: This SET command will receive MIGRATING push message before response + key_migrating = "key_receive_migrating" + value_migrating = "value2" + result2 = test_redis_client.set(key_migrating, value_migrating) - # Step 6: Validate socket timeout is still relaxed during commands 3-4 - self._validate_current_timeout_for_thread(thread_id, 30) + # Validate Command 2 result + assert result2 is True, "Command 2 (SET key_receive_migrating) failed" - # Command 5: This SET command will receive - # MIGRATED push message before actual response - key_migrated = f"key_receive_migrated_{thread_id}" - value_migrated = f"value3_{thread_id}" - result5 = test_redis_client.set(key_migrated, value_migrated) + # Step 4: Validate timeout was updated to relaxed value after MIGRATING + self._validate_current_timeout(30, "Right after MIGRATING is received. ") - # Validate Command 5 result - errors_msg = f"Thread {thread_id}: Command 5 (SET key_receive_migrated) failed" - assert result5 is True, errors_msg + # Command 3: Another command while timeout is still relaxed + result3 = test_redis_client.get(key1) - # Step 8: Validate socket timeout is reversed back to original after MIGRATED - self._validate_current_timeout_for_thread(thread_id, None) + # Validate Command 3 result + expected_value3 = value1.encode() + assert result3 == expected_value3, ( + f"Command 3 (GET key1) failed. Expected {expected_value3}, got {result3}" + ) - results.append( - { - "thread_id": thread_id, - "success": True, - } - ) + # Command 4: Execute command (step 5) + result4 = test_redis_client.get(key_migrating) - except Exception as e: - errors.append(f"Thread {thread_id}: {e}") + # Validate Command 4 result + expected_value4 = value_migrating.encode() + assert result4 == expected_value4, ( + f"Command 4 (GET key_receive_migrating) failed. Expected {expected_value4}, got {result4}" + ) - # Run operations in multiple threads (step 1) - threads = [] - for i in range(3): - thread = threading.Thread( - target=redis_operations_with_maintenance_events, - args=(i,), - name=str(i), - ) - threads.append(thread) - thread.start() - - # Wait for all threads to complete - for thread in threads: - thread.join() - - # Verify all threads completed successfully - successful_threads = len(results) - assert successful_threads == 3, ( - f"Expected 3 successful threads, got {successful_threads}. " - f"Errors: {errors}" + # Step 6: Validate socket timeout is still relaxed during commands 3-4 + self._validate_current_timeout( + 30, + "Execute a command with a connection extracted from the pool (after it has received MIGRATING)", ) - # Verify maintenance events were processed correctly across all threads - # Note: Different pool types may create different numbers of sockets - # The key is that we have at least 1 socket and all threads succeeded + # Command 5: This SET command will receive + # MIGRATED push message before actual response + key_migrated = "key_receive_migrated" + value_migrated = "value3" + result5 = test_redis_client.set(key_migrated, value_migrated) + + # Validate Command 5 result + assert result5 is True, "Command 5 (SET key_receive_migrated) failed" + + # Step 8: Validate socket timeout is reversed back to original after MIGRATED + self._validate_current_timeout(None) + + # Verify maintenance events were processed correctly + # The key is that we have at least 1 socket and all operations succeeded assert len(self.mock_sockets) >= 1, ( f"Expected at least 1 socket for operations, got {len(self.mock_sockets)}" ) @@ -640,80 +718,37 @@ def test_migrating_event_with_disabled_relax_timeout(self, pool_class): ) try: - # Results storage for thread operations - results = [] - errors = [] - - def redis_operations_with_disabled_relax(thread_id): - """Perform Redis operations with disabled relax timeout in a thread.""" - try: - # Command 1: Initial command - key1 = f"key1_{thread_id}" - value1 = f"value1_{thread_id}" - result1 = test_redis_client.set(key1, value1) - - # Validate Command 1 result - errors_msg = f"Thread {thread_id}: Command 1 (SET key1) failed" - assert result1 is True, errors_msg - - # Command 2: This SET command will receive MIGRATING push message before response - key_migrating = f"key_receive_migrating_{thread_id}" - value_migrating = f"value2_{thread_id}" - result2 = test_redis_client.set(key_migrating, value_migrating) - - # Validate Command 2 result - errors_msg = f"Thread {thread_id}: Command 2 (SET key_receive_migrating) failed" - assert result2 is True, errors_msg - - # Validate timeout was NOT updated (relax is disabled) - # Should remain at default timeout (None), not relaxed to 30s - self._validate_current_timeout_for_thread(thread_id, None) - - # Command 3: Another command to verify timeout remains unchanged - result3 = test_redis_client.get(key1) - - # Validate Command 3 result - expected_value3 = value1.encode() - errors_msg = ( - f"Thread {thread_id}: Command 3 (GET key1) failed. " - f"Expected: {expected_value3}, Got: {result3}" - ) - assert result3 == expected_value3, errors_msg + # Command 1: Initial command + key1 = "key1" + value1 = "value1" + result1 = test_redis_client.set(key1, value1) - results.append( - { - "thread_id": thread_id, - "success": True, - } - ) + # Validate Command 1 result + assert result1 is True, "Command 1 (SET key1) failed" - except Exception as e: - errors.append(f"Thread {thread_id}: {str(e)}") + # Command 2: This SET command will receive MIGRATING push message before response + key_migrating = "key_receive_migrating" + value_migrating = "value2" + result2 = test_redis_client.set(key_migrating, value_migrating) - # Run operations in multiple threads to test concurrent behavior - threads = [] - for i in range(3): - thread = threading.Thread( - target=redis_operations_with_disabled_relax, args=(i,) - ) - threads.append(thread) - thread.start() + # Validate Command 2 result + assert result2 is True, "Command 2 (SET key_receive_migrating) failed" - # Wait for all threads to complete - for thread in threads: - thread.join() + # Validate timeout was NOT updated (relax is disabled) + # Should remain at default timeout (None), not relaxed to 30s + self._validate_current_timeout(None) - # Verify no errors occurred - assert len(errors) == 0, f"Errors occurred: {errors}" + # Command 3: Another command to verify timeout remains unchanged + result3 = test_redis_client.get(key1) - # Verify all operations completed successfully - assert len(results) == 3, ( - f"Expected 3 successful threads, got {len(results)}" + # Validate Command 3 result + expected_value3 = value1.encode() + assert result3 == expected_value3, ( + f"Command 3 (GET key1) failed. Expected: {expected_value3}, Got: {result3}" ) - # Verify maintenance events were processed correctly across all threads - # Note: Different pool types may create different numbers of sockets - # The key is that we have at least 1 socket and all threads succeeded + # Verify maintenance events were processed correctly + # The key is that we have at least 1 socket and all operations succeeded assert len(self.mock_sockets) >= 1, ( f"Expected at least 1 socket for operations, got {len(self.mock_sockets)}" ) @@ -726,6 +761,13 @@ def redis_operations_with_disabled_relax(thread_id): def test_moving_related_events_handling_integration(self, pool_class): """ Test full integration of moving-related events (MOVING) handling with Redis commands. + + This test validates the complete MOVING event lifecycle: + 1. Creates multiple connections in the pool + 2. Executes a Redis command that triggers a MOVING push message + 3. Validates that pool configuration is updated with temporary address and timeout + 4. Validates that existing connections are marked for disconnection + 5. Tests both ConnectionPool and BlockingConnectionPool implementations """ # Create a pool and Redis client with maintenance events and pool handler test_redis_client = self._get_client( @@ -956,8 +998,6 @@ def test_create_new_conn_after_moving_expires(self, pool_class): @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) def test_receive_migrated_after_moving(self, pool_class): - # TODO Refactor: when migrated comes after moving and - # moving hasn't yet expired - it should not decrease timeouts """ Test receiving MIGRATED event after MOVING event. @@ -966,6 +1006,9 @@ def test_receive_migrated_after_moving(self, pool_class): 2. MIGRATED event is received during command execution 3. Temporary settings are cleared after MIGRATED 4. Pool configuration is restored to original values + + Note: When MIGRATED comes after MOVING and MOVING hasn't yet expired, + it should not decrease timeouts (future refactoring consideration). """ # Create a pool and Redis client with maintenance events and pool handler test_redis_client = self._get_client( From 6ca514f7ba255442076b68ff2c72f1fc588287dd Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Thu, 17 Jul 2025 15:38:30 +0300 Subject: [PATCH 09/28] Fixing linters --- redis/client.py | 1 - redis/maintenance_events.py | 1 - 2 files changed, 2 deletions(-) diff --git a/redis/client.py b/redis/client.py index 473b1e00f2..a1a053ddc6 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1,5 +1,4 @@ import copy -import logging import re import threading import time diff --git a/redis/maintenance_events.py b/redis/maintenance_events.py index bf0cd6bda8..5fe2e81de8 100644 --- a/redis/maintenance_events.py +++ b/redis/maintenance_events.py @@ -371,7 +371,6 @@ def handle_node_moving_event(self, event: NodeMovingEvent): # take care for the inactive connections in the pool # delete them and create new ones - start_time_2 = time.time() self.pool.disconnect_and_reconfigure_free_connections( tmp_host_address=event.new_node_host, tmp_relax_timeout=self.config.relax_timeout, From 778abdf0d9606d71f0b9449fb59e5ea31365a941 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Thu, 17 Jul 2025 15:42:41 +0300 Subject: [PATCH 10/28] Applying Copilot's comments --- redis/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis/client.py b/redis/client.py index a1a053ddc6..a6c96c3882 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1662,7 +1662,7 @@ def execute(self, raise_on_error: bool = True) -> List[Any]: lambda error: self._disconnect_raise_on_watching(conn, error), ) finally: - # in reset() the connection is diconnected before returned to the pool if + # in reset() the connection is disconnected before returned to the pool if # it is marked for reconnect. self.reset() From 667109be2c6e61ba8abd43e23a69a9492292e869 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Thu, 17 Jul 2025 15:49:20 +0300 Subject: [PATCH 11/28] Fixed type annotations not compatible with older python versions --- redis/connection.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index a096b045b2..57e8869e40 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -824,8 +824,8 @@ def update_parser_buffer_timeout(self, timeout: Optional[float] = None): def update_tmp_settings( self, - tmp_host_address: Optional[str | object] = SENTINEL, - tmp_relax_timeout: Optional[float | object] = SENTINEL, + tmp_host_address: Optional[Union[str, object]] = SENTINEL, + tmp_relax_timeout: Optional[Union[float, object]] = SENTINEL, ): """ The value of SENTINEL is used to indicate that the property should not be updated. From ef1742a111a0700e914f5fd596696ed313f845cb Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Thu, 17 Jul 2025 19:12:07 +0300 Subject: [PATCH 12/28] Add a few more tests and fix pool mock for python 3.9 --- redis/connection.py | 1 - tests/test_maintenance_events.py | 8 +- tests/test_maintenance_events_handling.py | 126 ++++++++++++++++++++++ 3 files changed, 131 insertions(+), 4 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index 57e8869e40..7e9ad95b21 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -2155,7 +2155,6 @@ def release(self, connection): connection.disconnect() # Put the connection back into the pool. try: - print("Releasing connection - in the pool") self.pool.put_nowait(connection) except Full: # perhaps the pool has been reset() after a fork? regardless, diff --git a/tests/test_maintenance_events.py b/tests/test_maintenance_events.py index 69a6014fe1..ac7d10b51e 100644 --- a/tests/test_maintenance_events.py +++ b/tests/test_maintenance_events.py @@ -1,5 +1,6 @@ import threading -from unittest.mock import Mock, patch +from unittest.mock import Mock, patch, MagicMock +import pytest from redis.maintenance_events import ( MaintenanceEvent, @@ -17,8 +18,6 @@ class TestMaintenanceEvent: def test_abstract_class_cannot_be_instantiated(self): """Test that MaintenanceEvent cannot be instantiated directly.""" - import pytest - with patch("time.monotonic", return_value=1000): with pytest.raises(TypeError): MaintenanceEvent(id=1, ttl=10) # type: ignore @@ -347,6 +346,9 @@ class TestMaintenanceEventPoolHandler: def setup_method(self): """Set up test fixtures.""" self.mock_pool = Mock() + self.mock_pool._lock = MagicMock() + self.mock_pool._lock.__enter__.return_value = None + self.mock_pool._lock.__exit__.return_value = None self.config = MaintenanceEventsConfig( enabled=True, proactive_reconnect=True, relax_timeout=20 ) diff --git a/tests/test_maintenance_events_handling.py b/tests/test_maintenance_events_handling.py index 1620471ea7..6ce74ebc3b 100644 --- a/tests/test_maintenance_events_handling.py +++ b/tests/test_maintenance_events_handling.py @@ -1102,3 +1102,129 @@ def test_receive_migrated_after_moving(self, pool_class): finally: if hasattr(test_redis_client.connection_pool, "disconnect"): test_redis_client.connection_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_overlapping_moving_events(self, pool_class): + """ + Test handling of overlapping/duplicate MOVING events (e.g., two MOVING events before the first expires). + Ensures that the second MOVING event updates the pool and connections as expected, and that expiry/cleanup works. + """ + test_redis_client = self._get_client( + pool_class, max_connections=5, setup_pool_handler=True + ) + try: + # Create and release some connections + for _ in range(3): + conn = test_redis_client.connection_pool.get_connection() + test_redis_client.connection_pool.release(conn) + + # Take 2 connections to be in use + in_use_connections = [] + for _ in range(2): + conn = test_redis_client.connection_pool.get_connection() + in_use_connections.append(conn) + + # Trigger first MOVING event + key_moving1 = "key_receive_moving_0" + value_moving1 = "value3_0" + result1 = test_redis_client.set(key_moving1, value_moving1) + assert result1 is True + self._validate_conn_kwargs( + test_redis_client.connection_pool, + MockSocket.DEFAULT_ADDRESS.split(":")[0], + int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + self.config.relax_timeout, + ) + # Validate all connections reflect the first MOVING event + self._validate_in_use_connections_state(in_use_connections) + self._validate_free_connections_state( + test_redis_client.connection_pool, + MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + self.config.relax_timeout, + should_be_connected_count=1, + connected_to_tmp_addres=True, + ) + + # Before the first MOVING expires, trigger a second MOVING event (simulate new address) + # Patch MockSocket to use a new address for the second event + new_address = "5.6.7.8:6380" + orig_after_moving = MockSocket.AFTER_MOVING_ADDRESS + MockSocket.AFTER_MOVING_ADDRESS = new_address + try: + key_moving2 = "key_receive_moving_1" + value_moving2 = "value3_1" + result2 = test_redis_client.set(key_moving2, value_moving2) + assert result2 is True + self._validate_conn_kwargs( + test_redis_client.connection_pool, + MockSocket.DEFAULT_ADDRESS.split(":")[0], + int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + new_address.split(":")[0], + self.config.relax_timeout, + ) + # Validate all connections reflect the second MOVING event + self._validate_in_use_connections_state(in_use_connections) + self._validate_free_connections_state( + test_redis_client.connection_pool, + new_address.split(":")[0], + self.config.relax_timeout, + should_be_connected_count=1, + connected_to_tmp_addres=True, + ) + finally: + MockSocket.AFTER_MOVING_ADDRESS = orig_after_moving + + # Wait for both MOVING timeouts to expire + sleep(MockSocket.MOVING_TIMEOUT + 0.5) + self._validate_conn_kwargs( + test_redis_client.connection_pool, + MockSocket.DEFAULT_ADDRESS.split(":")[0], + int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + None, + -1, + ) + finally: + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_thread_safety_concurrent_event_handling(self, pool_class): + """ + Test thread-safety under concurrent maintenance event handling. + Simulates multiple threads triggering MOVING events and performing operations concurrently. + """ + import threading + + test_redis_client = self._get_client( + pool_class, max_connections=5, setup_pool_handler=True + ) + results = [] + errors = [] + + def worker(idx): + try: + key = f"key_receive_moving_{idx}" + value = f"value3_{idx}" + result = test_redis_client.set(key, value) + results.append(result) + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=worker, args=(i,)) for i in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + assert all(results), f"Not all threads succeeded: {results}" + assert not errors, f"Errors occurred in threads: {errors}" + # After all threads, MOVING event should have been handled safely + self._validate_conn_kwargs( + test_redis_client.connection_pool, + MockSocket.DEFAULT_ADDRESS.split(":")[0], + int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + self.config.relax_timeout, + ) + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() From 7b4389027441607cb20549a41e31108a7265efc9 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Fri, 18 Jul 2025 16:14:01 +0300 Subject: [PATCH 13/28] Adding maintenance state to connections. Migrating and Migrated are not processed in in Moving state. Tests are updated --- redis/_parsers/hiredis.py | 1 + redis/connection.py | 47 +++++- redis/maintenance_events.py | 32 +++- tests/test_connection_pool.py | 10 +- tests/test_maintenance_events_handling.py | 197 +++++++++++++++++----- 5 files changed, 237 insertions(+), 50 deletions(-) diff --git a/redis/_parsers/hiredis.py b/redis/_parsers/hiredis.py index e9df314a8c..d82fe99cd9 100644 --- a/redis/_parsers/hiredis.py +++ b/redis/_parsers/hiredis.py @@ -152,6 +152,7 @@ def read_response(self, disable_decoding=False, push_request=False): disable_decoding=disable_decoding, push_request=push_request, ) + return response if disable_decoding: response = self._reader.gets(False) diff --git a/redis/connection.py b/redis/connection.py index 7e9ad95b21..5646a745af 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -42,6 +42,7 @@ MaintenanceEventConnectionHandler, MaintenanceEventPoolHandler, MaintenanceEventsConfig, + MaintenanceState, ) from .retry import Retry from .utils import ( @@ -285,6 +286,7 @@ def __init__( maintenance_events_config: Optional[MaintenanceEventsConfig] = None, tmp_host_address: Optional[str] = None, tmp_relax_timeout: Optional[float] = -1, + maintenance_state: "MaintenanceState" = MaintenanceState.NONE, ): """ Initialize a new Connection. @@ -374,6 +376,7 @@ def __init__( self._should_reconnect = False self.tmp_host_address = tmp_host_address self.tmp_relax_timeout = tmp_relax_timeout + self.maintenance_state = maintenance_state def __repr__(self): repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()]) @@ -835,6 +838,9 @@ def update_tmp_settings( if tmp_relax_timeout is not SENTINEL: self.tmp_relax_timeout = tmp_relax_timeout + def set_maintenance_state(self, state: "MaintenanceState"): + self.maintenance_state = state + class Connection(AbstractConnection): "Manages TCP communication to and from a Redis server" @@ -1724,11 +1730,18 @@ def make_connection(self) -> "ConnectionInterface": raise MaxConnectionsError("Too many connections") self._created_connections += 1 + # Pass current maintenance_state to new connections + maintenance_state = self.connection_kwargs.get( + "maintenance_state", MaintenanceState.NONE + ) + kwargs = dict(self.connection_kwargs) + kwargs["maintenance_state"] = maintenance_state + if self.cache is not None: return CacheProxyConnection( - self.connection_class(**self.connection_kwargs), self.cache, self._lock + self.connection_class(**kwargs), self.cache, self._lock ) - return self.connection_class(**self.connection_kwargs) + return self.connection_class(**kwargs) def release(self, connection: "Connection") -> None: "Releases the connection back to the pool" @@ -1953,6 +1966,16 @@ async def _mock(self, error: RedisError): """ pass + def set_maintenance_state_for_all(self, state: "MaintenanceState"): + with self._lock: + for conn in self._available_connections: + conn.set_maintenance_state(state) + for conn in self._in_use_connections: + conn.set_maintenance_state(state) + + def set_maintenance_state_in_kwargs(self, state: "MaintenanceState"): + self.connection_kwargs["maintenance_state"] = state + class BlockingConnectionPool(ConnectionPool): """ @@ -2047,15 +2070,20 @@ def make_connection(self): if self._in_maintenance: self._lock.acquire() self._locked = True + # Pass current maintenance_state to new connections + maintenance_state = self.connection_kwargs.get( + "maintenance_state", MaintenanceState.NONE + ) + kwargs = dict(self.connection_kwargs) + kwargs["maintenance_state"] = maintenance_state if self.cache is not None: connection = CacheProxyConnection( - self.connection_class(**self.connection_kwargs), + self.connection_class(**kwargs), self.cache, self._lock, ) else: - connection = self.connection_class(**self.connection_kwargs) - + connection = self.connection_class(**kwargs) self._connections.append(connection) return connection finally: @@ -2266,3 +2294,12 @@ def _update_maintenance_events_configs_for_connections( def set_in_maintenance(self, in_maintenance: bool): """Set the maintenance mode for the connection pool.""" self._in_maintenance = in_maintenance + + def set_maintenance_state_for_all(self, state: "MaintenanceState"): + with self._lock: + for conn in getattr(self, "_connections", []): + if conn: + conn.set_maintenance_state(state) + + def set_maintenance_state_in_kwargs(self, state: "MaintenanceState"): + self.connection_kwargs["maintenance_state"] = state diff --git a/redis/maintenance_events.py b/redis/maintenance_events.py index 5fe2e81de8..dd62602105 100644 --- a/redis/maintenance_events.py +++ b/redis/maintenance_events.py @@ -1,3 +1,4 @@ +import enum import logging import threading import time @@ -6,6 +7,13 @@ from redis.typing import Number + +class MaintenanceState(enum.Enum): + NONE = "none" + MOVING = "moving" + MIGRATING = "migrating" + + if TYPE_CHECKING: from redis.connection import ( BlockingConnectionPool, @@ -351,6 +359,9 @@ def handle_node_moving_event(self, event: NodeMovingEvent): ): if getattr(self.pool, "set_in_maintenance", False): self.pool.set_in_maintenance(True) + # Set state to MOVING for all connections and in kwargs (inside pool lock, after set_in_maintenance) + self.pool.set_maintenance_state_for_all(MaintenanceState.MOVING) + self.pool.set_maintenance_state_in_kwargs(MaintenanceState.MOVING) # edit the config for new connections until the notification expires self.pool.update_connection_kwargs_with_tmp_settings( tmp_host_address=event.new_node_host, @@ -368,7 +379,6 @@ def handle_node_moving_event(self, event: NodeMovingEvent): tmp_host_address=event.new_node_host, tmp_relax_timeout=self.config.relax_timeout, ) - # take care for the inactive connections in the pool # delete them and create new ones self.pool.disconnect_and_reconfigure_free_connections( @@ -388,16 +398,19 @@ def handle_node_moved_event(self): tmp_host_address=None, tmp_relax_timeout=-1, ) + # Clear state to NONE in kwargs immediately after updating tmp kwargs + self.pool.set_maintenance_state_in_kwargs(MaintenanceState.NONE) with self.pool._lock: if self.config.is_relax_timeouts_enabled(): # reset the timeout for existing connections self.pool.update_connections_current_timeout( relax_timeout=-1, include_free_connections=True ) - self.pool.update_connections_tmp_settings( tmp_host_address=None, tmp_relax_timeout=-1 ) + # Clear state to NONE for all connections + self.pool.set_maintenance_state_for_all(MaintenanceState.NONE) class MaintenanceEventConnectionHandler: @@ -416,17 +429,24 @@ def handle_event(self, event: MaintenanceEvent): logging.error(f"Unhandled event type: {event}") def handle_migrating_event(self, notification: NodeMigratingEvent): - if not self.config.is_relax_timeouts_enabled(): + if ( + self.connection.maintenance_state == MaintenanceState.MOVING + or not self.config.is_relax_timeouts_enabled() + ): return - + self.connection.set_maintenance_state(MaintenanceState.MIGRATING) # extend the timeout for all created connections self.connection.update_current_socket_timeout(self.config.relax_timeout) self.connection.update_tmp_settings(tmp_relax_timeout=self.config.relax_timeout) def handle_migration_completed_event(self, notification: "NodeMigratedEvent"): - if not self.config.is_relax_timeouts_enabled(): + # Only reset timeouts if state is not MOVING and relax timeouts are enabled + if ( + self.connection.maintenance_state == MaintenanceState.MOVING + or not self.config.is_relax_timeouts_enabled() + ): return - + self.connection.set_maintenance_state(MaintenanceState.NONE) # Node migration completed - reset the connection # timeouts by providing -1 as the relax timeout self.connection.update_current_socket_timeout(-1) diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 4518cd7290..880b6db27e 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -9,6 +9,7 @@ import redis from redis.cache import CacheConfig from redis.connection import CacheProxyConnection, Connection, to_bool +from redis.maintenance_events import MaintenanceState from redis.utils import SSL_AVAILABLE from .conftest import ( @@ -53,10 +54,15 @@ def get_pool( return pool def test_connection_creation(self): - connection_kwargs = {"foo": "bar", "biz": "baz"} + connection_kwargs = { + "foo": "bar", + "biz": "baz", + "maintenance_state": MaintenanceState.NONE, + } pool = self.get_pool( connection_kwargs=connection_kwargs, connection_class=DummyConnection ) + connection = pool.get_connection() assert isinstance(connection, DummyConnection) assert connection.kwargs == connection_kwargs @@ -152,7 +158,9 @@ def test_connection_creation(self, master_host): "host": master_host[0], "port": master_host[1], } + pool = self.get_pool(connection_kwargs=connection_kwargs) + connection_kwargs["maintenance_state"] = MaintenanceState.NONE connection = pool.get_connection() assert isinstance(connection, DummyConnection) assert connection.kwargs == connection_kwargs diff --git a/tests/test_maintenance_events_handling.py b/tests/test_maintenance_events_handling.py index 6ce74ebc3b..b573a55e5f 100644 --- a/tests/test_maintenance_events_handling.py +++ b/tests/test_maintenance_events_handling.py @@ -6,11 +6,18 @@ from time import sleep from redis import Redis -from redis.connection import AbstractConnection, ConnectionPool, BlockingConnectionPool +from redis.connection import ( + AbstractConnection, + ConnectionPool, + BlockingConnectionPool, + MaintenanceState, +) from redis.maintenance_events import ( MaintenanceEventsConfig, NodeMigratingEvent, MaintenanceEventPoolHandler, + NodeMovingEvent, + NodeMigratedEvent, ) @@ -326,24 +333,25 @@ def _validate_connected(self, expected_count): assert connected_sockets_count == expected_count def _validate_in_use_connections_state( - self, in_use_connections: List[AbstractConnection] + self, + in_use_connections: List[AbstractConnection], + expected_state=MaintenanceState.NONE, + expected_tmp_host_address=None, + expected_tmp_relax_timeout=-1, + expected_current_socket_timeout=None, + expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], ): """Helper method to validate state of in-use connections.""" # validate in use connections are still working with set flag for reconnect # and timeout is updated for connection in in_use_connections: assert connection._should_reconnect is True - assert ( - connection.tmp_host_address - == MockSocket.AFTER_MOVING_ADDRESS.split(":")[0] - ) - assert connection.tmp_relax_timeout == self.config.relax_timeout - assert connection._sock.gettimeout() == self.config.relax_timeout + assert connection.tmp_host_address == expected_tmp_host_address + assert connection.tmp_relax_timeout == expected_tmp_relax_timeout + assert connection._sock.gettimeout() == expected_current_socket_timeout assert connection._sock.connected is True - assert ( - connection._sock.getpeername()[0] - == MockSocket.DEFAULT_ADDRESS.split(":")[0] - ) + assert connection.maintenance_state == expected_state + assert connection._sock.getpeername()[0] == expected_current_peername def _validate_free_connections_state( self, @@ -352,39 +360,30 @@ def _validate_free_connections_state( relax_timeout, should_be_connected_count, connected_to_tmp_addres=False, + expected_state=MaintenanceState.MOVING, ): """Helper method to validate state of free/available connections.""" if isinstance(pool, BlockingConnectionPool): - # BlockingConnectionPool uses _connections list where created connections are stored - # but we need to get the ones in the queue - these are the free ones - # the uninitialized connections are filtered out free_connections = [conn for conn in pool.pool.queue if conn is not None] elif isinstance(pool, ConnectionPool): - # Regular ConnectionPool uses _available_connections for free connections free_connections = pool._available_connections else: raise ValueError(f"Unsupported pool type: {type(pool)}") connected_count = 0 - # Validate fields that are validated in the validation of the active connections for connection in free_connections: - # Validate the same fields as in _validate_in_use_connections_state assert connection._should_reconnect is False assert connection.tmp_host_address == tmp_host_address assert connection.tmp_relax_timeout == relax_timeout + assert connection.maintenance_state == expected_state if connection._sock is not None: - connected_count += 1 - + assert connection._sock.connected is True if connected_to_tmp_addres: assert ( connection._sock.getpeername()[0] == MockSocket.AFTER_MOVING_ADDRESS.split(":")[0] ) - else: - assert ( - connection._sock.getpeername()[0] - == MockSocket.DEFAULT_ADDRESS.split(":")[0] - ) + connected_count += 1 assert connected_count == should_be_connected_count def _validate_all_timeouts(self, expected_timeout): @@ -804,7 +803,6 @@ def test_moving_related_events_handling_integration(self, pool_class): assert result2 is True, "Command 2 (SET key_receive_moving) failed" # Validate pool and connections settings were updated according to MOVING event - # handling expectations self._validate_conn_kwargs( test_redis_client.connection_pool, MockSocket.DEFAULT_ADDRESS.split(":")[0], @@ -812,25 +810,28 @@ def test_moving_related_events_handling_integration(self, pool_class): MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], self.config.relax_timeout, ) - # 5 disconnects has happened, 1 of them is with reconnect self._validate_disconnected(5) - # 5 in use connected + 1 after reconnect self._validate_connected(6) - self._validate_in_use_connections_state(in_use_connections) - # Validate there is 1 free connection that is connected - # the one that has handled the MOVING should reconnect after parsing the response + self._validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.MOVING, + expected_tmp_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_tmp_relax_timeout=self.config.relax_timeout, + expected_current_socket_timeout=self.config.relax_timeout, + expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[ + 0 + ], # the in use connections reconnect when they complete their current task + ) self._validate_free_connections_state( test_redis_client.connection_pool, MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], self.config.relax_timeout, should_be_connected_count=1, connected_to_tmp_addres=True, + expected_state=MaintenanceState.MOVING, ) - # Wait for MOVING timeout to expire and the moving completed handler to run - print("Waiting for MOVING timeout to expire...") sleep(MockSocket.MOVING_TIMEOUT + 0.5) - self._validate_all_timeouts(None) self._validate_conn_kwargs( test_redis_client.connection_pool, @@ -845,8 +846,8 @@ def test_moving_related_events_handling_integration(self, pool_class): -1, should_be_connected_count=1, connected_to_tmp_addres=True, + expected_state=MaintenanceState.NONE, ) - finally: if hasattr(test_redis_client.connection_pool, "disconnect"): test_redis_client.connection_pool.disconnect() @@ -972,7 +973,6 @@ def test_create_new_conn_after_moving_expires(self, pool_class): assert result is True, "SET key_receive_moving command failed" # Wait for MOVING timeout to expire - print("Waiting for MOVING timeout to expire...") sleep(MockSocket.MOVING_TIMEOUT + 0.5) # Now get several new connections after expiration @@ -1137,7 +1137,14 @@ def test_overlapping_moving_events(self, pool_class): self.config.relax_timeout, ) # Validate all connections reflect the first MOVING event - self._validate_in_use_connections_state(in_use_connections) + self._validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.MOVING, + expected_tmp_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_tmp_relax_timeout=self.config.relax_timeout, + expected_current_socket_timeout=self.config.relax_timeout, + expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], + ) self._validate_free_connections_state( test_redis_client.connection_pool, MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], @@ -1164,7 +1171,14 @@ def test_overlapping_moving_events(self, pool_class): self.config.relax_timeout, ) # Validate all connections reflect the second MOVING event - self._validate_in_use_connections_state(in_use_connections) + self._validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.MOVING, + expected_tmp_host_address=new_address.split(":")[0], + expected_tmp_relax_timeout=self.config.relax_timeout, + expected_current_socket_timeout=self.config.relax_timeout, + expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], + ) self._validate_free_connections_state( test_redis_client.connection_pool, new_address.split(":")[0], @@ -1228,3 +1242,110 @@ def worker(idx): ) if hasattr(test_redis_client.connection_pool, "disconnect"): test_redis_client.connection_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): + """ + Test moving configs are not lost if the per connection events get picked up after moving is handled. + MOVING → MIGRATING → MIGRATED → MOVED + Checks the state after each event for all connections and for new connections created during each state. + """ + # Setup + test_redis_client = self._get_client( + pool_class, max_connections=5, setup_pool_handler=True + ) + pool = test_redis_client.connection_pool + pool_handler = pool.connection_kwargs["maintenance_events_pool_handler"] + + # Create and release some connections + in_use_connections = [] + for _ in range(3): + in_use_connections.append(pool.get_connection()) + while len(in_use_connections) > 0: + pool.release(in_use_connections.pop()) + + # Take 2 connections to be in use + in_use_connections = [] + for _ in range(2): + conn = pool.get_connection() + in_use_connections.append(conn) + + # 1. MOVING event + tmp_address = "22.23.24.25" + moving_event = NodeMovingEvent( + id=1, new_node_host=tmp_address, new_node_port=6379, ttl=1 + ) + pool_handler.handle_event(moving_event) + self._validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.MOVING, + expected_tmp_host_address=tmp_address, + expected_tmp_relax_timeout=self.config.relax_timeout, + expected_current_socket_timeout=self.config.relax_timeout, + expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], + ) + self._validate_free_connections_state( + pool, + tmp_address, + self.config.relax_timeout, + should_be_connected_count=0, + connected_to_tmp_addres=False, + expected_state=MaintenanceState.MOVING, + ) + + # 2. MIGRATING event (simulate direct connection handler call) + for conn in in_use_connections: + conn._maintenance_event_connection_handler.handle_event( + NodeMigratingEvent(id=2, ttl=1) + ) + self._validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.MOVING, + expected_tmp_host_address=tmp_address, + expected_tmp_relax_timeout=self.config.relax_timeout, + expected_current_socket_timeout=self.config.relax_timeout, + expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], + ) + + # 3. MIGRATED event (simulate direct connection handler call) + for conn in in_use_connections: + conn._maintenance_event_connection_handler.handle_event( + NodeMigratedEvent(id=2) + ) + # State should not change for connections that are in MOVING state + self._validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.MOVING, + expected_tmp_host_address=tmp_address, + expected_tmp_relax_timeout=self.config.relax_timeout, + expected_current_socket_timeout=self.config.relax_timeout, + expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], + ) + + # 4. MOVED event (simulate timer expiry) + pool_handler.handle_node_moved_event() + self._validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.NONE, + expected_tmp_host_address=None, + expected_tmp_relax_timeout=-1, + expected_current_socket_timeout=None, + expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], + ) + self._validate_free_connections_state( + pool, + None, + -1, + should_be_connected_count=0, + connected_to_tmp_addres=False, + expected_state=MaintenanceState.NONE, + ) + # New connection after MOVED + new_conn_none = pool.get_connection() + assert new_conn_none.maintenance_state == MaintenanceState.NONE + pool.release(new_conn_none) + # Cleanup + for conn in in_use_connections: + pool.release(conn) + if hasattr(pool, "disconnect"): + pool.disconnect() From 08f158578cdbfe2321e5238a9c251d7c44a00ed9 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Tue, 22 Jul 2025 19:27:57 +0300 Subject: [PATCH 14/28] Refactored the tmp host address and timeout storing and the way to apply them during connect --- redis/_parsers/base.py | 2 - redis/connection.py | 383 ++++++++++++++-------- redis/maintenance_events.py | 71 ++-- tests/test_maintenance_events.py | 21 +- tests/test_maintenance_events_handling.py | 360 ++++++++++++-------- 5 files changed, 536 insertions(+), 301 deletions(-) diff --git a/redis/_parsers/base.py b/redis/_parsers/base.py index f2670e43b0..77d0188092 100644 --- a/redis/_parsers/base.py +++ b/redis/_parsers/base.py @@ -130,8 +130,6 @@ def on_connect(self, connection): "Called when the socket connects" self._sock = connection._sock timeout = connection.socket_timeout - if connection.tmp_relax_timeout != -1: - timeout = connection.tmp_relax_timeout self._buffer = SocketBuffer(self._sock, self.socket_read_size, timeout) self.encoder = connection.encoder diff --git a/redis/connection.py b/redis/connection.py index 5646a745af..c20c89dd9d 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1,5 +1,4 @@ import copy -import logging import os import socket import sys @@ -236,22 +235,62 @@ def re_auth(self): @abstractmethod def mark_for_reconnect(self): + """ + Mark the connection to be reconnected on the next command. + This is useful when a connection is moved to a different node. + """ pass @abstractmethod def should_reconnect(self): + """ + Returns True if the connection should be reconnected. + """ + pass + + @property + @abstractmethod + def maintenance_state(self) -> MaintenanceState: + """ + Returns the current maintenance state of the connection. + """ + pass + + @maintenance_state.setter + @abstractmethod + def maintenance_state(self, state: "MaintenanceState"): + """ + Sets the current maintenance state of the connection. + """ pass @abstractmethod def update_current_socket_timeout(self, relax_timeout: Optional[float] = None): + """ + Update the timeout for the current socket. + """ pass @abstractmethod - def update_tmp_settings( + def set_tmp_settings( self, tmp_host_address: Optional[str] = None, tmp_relax_timeout: Optional[float] = None, ): + """ + Updates temporary host address and timeout settings for the connection. + """ + pass + + @abstractmethod + def reset_tmp_settings( + self, + reset_host_address: bool = False, + reset_relax_timeout: bool = False, + ): + """ + Resets temporary host address and timeout settings for the connection. + """ pass @@ -284,8 +323,9 @@ def __init__( event_dispatcher: Optional[EventDispatcher] = None, maintenance_events_pool_handler: Optional[MaintenanceEventPoolHandler] = None, maintenance_events_config: Optional[MaintenanceEventsConfig] = None, - tmp_host_address: Optional[str] = None, - tmp_relax_timeout: Optional[float] = -1, + orig_host_address: Optional[str] = None, + orig_socket_timeout: Optional[float] = None, + orig_socket_connect_timeout: Optional[float] = None, maintenance_state: "MaintenanceState" = MaintenanceState.NONE, ): """ @@ -374,8 +414,9 @@ def __init__( self._command_packer = self._construct_command_packer(command_packer) self._should_reconnect = False - self.tmp_host_address = tmp_host_address - self.tmp_relax_timeout = tmp_relax_timeout + self.orig_host_address = orig_host_address + self.orig_socket_timeout = orig_socket_timeout + self.orig_socket_connect_timeout = orig_socket_connect_timeout self.maintenance_state = maintenance_state def __repr__(self): @@ -809,6 +850,14 @@ def re_auth(self): self.read_response() self._re_auth_token = None + @property + def maintenance_state(self) -> MaintenanceState: + return self._maintenance_state + + @maintenance_state.setter + def maintenance_state(self, state: "MaintenanceState"): + self._maintenance_state = state + def mark_for_reconnect(self): self._should_reconnect = True @@ -825,21 +874,40 @@ def update_parser_buffer_timeout(self, timeout: Optional[float] = None): if self._parser and self._parser._buffer: self._parser._buffer.socket_timeout = timeout - def update_tmp_settings( + def set_tmp_settings( self, tmp_host_address: Optional[Union[str, object]] = SENTINEL, - tmp_relax_timeout: Optional[Union[float, object]] = SENTINEL, + tmp_relax_timeout: Optional[float] = None, + skip_original_data_update: bool = False, ): """ The value of SENTINEL is used to indicate that the property should not be updated. """ if tmp_host_address is not SENTINEL: - self.tmp_host_address = tmp_host_address - if tmp_relax_timeout is not SENTINEL: - self.tmp_relax_timeout = tmp_relax_timeout - - def set_maintenance_state(self, state: "MaintenanceState"): - self.maintenance_state = state + if not skip_original_data_update: + self.orig_host_address = self.host + self.host = tmp_host_address + if tmp_relax_timeout != -1: + if not skip_original_data_update: + self.orig_socket_timeout = self.socket_timeout + self.orig_socket_connect_timeout = self.socket_connect_timeout + + self.socket_timeout = tmp_relax_timeout + self.socket_connect_timeout = tmp_relax_timeout + + def reset_tmp_settings( + self, + reset_host_address: bool = False, + reset_relax_timeout: bool = False, + ): + if reset_host_address: + self.host = self.orig_host_address + self.orig_host_address = None + if reset_relax_timeout: + self.socket_timeout = self.orig_socket_timeout + self.socket_connect_timeout = self.orig_socket_connect_timeout + self.orig_socket_timeout = None + self.orig_socket_connect_timeout = None class Connection(AbstractConnection): @@ -873,10 +941,9 @@ def _connect(self): # ipv4/ipv6, but we want to set options prior to calling # socket.connect() err = None - host = self.tmp_host_address or self.host for res in socket.getaddrinfo( - host, self.port, self.socket_type, socket.SOCK_STREAM + self.host, self.port, self.socket_type, socket.SOCK_STREAM ): family, socktype, proto, canonname, socket_address = res sock = None @@ -892,19 +959,13 @@ def _connect(self): sock.setsockopt(socket.IPPROTO_TCP, k, v) # set the socket_connect_timeout before we connect - if self.tmp_relax_timeout != -1: - sock.settimeout(self.tmp_relax_timeout) - else: - sock.settimeout(self.socket_connect_timeout) + sock.settimeout(self.socket_connect_timeout) # connect sock.connect(socket_address) # set the socket_timeout now that we're connected - if self.tmp_relax_timeout != -1: - sock.settimeout(self.tmp_relax_timeout) - else: - sock.settimeout(self.socket_timeout) + sock.settimeout(self.socket_timeout) return sock except OSError as _: @@ -1818,54 +1879,128 @@ def re_auth_callback(self, token: TokenInterface): for conn in self._in_use_connections: conn.set_re_auth_token(token) - def update_connection_kwargs_with_tmp_settings( + def set_maintenance_state_for_all_connections(self, state: "MaintenanceState"): + for conn in self._available_connections: + conn.maintenance_state = state + for conn in self._in_use_connections: + conn.maintenance_state = state + + def set_maintenance_state_in_connection_kwargs(self, state: "MaintenanceState"): + self.connection_kwargs["maintenance_state"] = state + + def add_tmp_config_to_connection_kwargs( self, - tmp_host_address: Optional[str] = None, + tmp_host_address: str, tmp_relax_timeout: Optional[float] = None, + skip_original_data_update: bool = False, ): """ - Update the connection kwargs with the temporary host address and the - relax timeout(if enabled). - This is used when a cluster node is rebind to a different address. + Store original connection configuration and apply temporary settings. + + This method saves the current host, socket_timeout, and socket_connect_timeout values + in temporary storage fields (orig_*), then applies the provided temporary values + as the active connection configuration. + + This is used when a cluster node is rebound to a different address during + maintenance operations. New connections created after this call will use the + temporary configuration until remove_tmp_config_from_connection_kwargs() is called. + + When this method is called the pool will already be locked, so getting the pool + lock inside is not needed. + + :param tmp_host_address: The temporary host address to use for new connections. + This parameter is required and will replace the current host. + :param tmp_relax_timeout: The temporary timeout value to use for both socket_timeout + and socket_connect_timeout. If -1 is provided, the timeout + settings are not modified (relax timeout is disabled). + :param skip_original_data_update: Whether to skip updating the original data. + This is used when we are already in MOVING state + and the original data is already stored in the connection kwargs. + """ + if not skip_original_data_update: + # Store original values in temporary storage + original_host = self.connection_kwargs.get("host") + original_socket_timeout = self.connection_kwargs.get("socket_timeout") + original_connect_timeout = self.connection_kwargs.get( + "socket_connect_timeout" + ) - When this method is called the pool will already be locked, so getting the pool lock inside is not needed. - This new address will be used to create new connections until the old node is decomissioned. + self.connection_kwargs.update( + { + "orig_host_address": original_host, + "orig_socket_timeout": original_socket_timeout, + "orig_socket_connect_timeout": original_connect_timeout, + } + ) - :param tmp_host_address: The temporary host address to use for the connection. - :param tmp_relax_timeout: The relax timeout to use for the connection. - If -1 is provided - the relax timeout is disabled, so the tmp property is not set + # Apply temporary values as active configuration + self.connection_kwargs.update({"host": tmp_host_address}) + + if tmp_relax_timeout != -1: + self.connection_kwargs.update( + { + "socket_timeout": tmp_relax_timeout, + "socket_connect_timeout": tmp_relax_timeout, + } + ) + + def remove_tmp_config_from_connection_kwargs(self): """ - self.connection_kwargs.update({"tmp_host_address": tmp_host_address}) - self.connection_kwargs.update({"tmp_relax_timeout": tmp_relax_timeout}) + Remove temporary configuration from connection kwargs and restore original values. - def update_connections_tmp_settings( - self, - tmp_host_address: Optional[str] = None, - tmp_relax_timeout: Optional[float] = None, - ): + This method restores the original host address, socket timeout, and connect timeout + from their temporary storage back to the main connection kwargs, then clears the + temporary storage fields. + + This is typically called when a cluster node maintenance operation is complete + and the connection should revert to its original configuration. + + When this method is called the pool will already be locked, so getting the pool + lock inside is not needed. """ - Update the tmp settings for all connections in the pool. - This is used when a cluster node is rebind to a different address. + orig_host = self.connection_kwargs.get("orig_host_address") + orig_socket_timeout = self.connection_kwargs.get("orig_socket_timeout") + orig_connect_timeout = self.connection_kwargs.get("orig_socket_connect_timeout") - When this method is called the pool will already be locked, so getting the pool lock inside is not needed. + self.connection_kwargs.update( + { + "orig_host_address": None, + "orig_socket_timeout": None, + "orig_socket_connect_timeout": None, + "host": orig_host, + "socket_timeout": orig_socket_timeout, + "socket_connect_timeout": orig_connect_timeout, + } + ) + + def reset_connections_tmp_settings(self): + """ + Restore original settings from temporary configuration for all connections in the pool. - :param tmp_host_address: The temporary host address to use for the connection. - :param tmp_relax_timeout: The relax timeout to use for the connection. + This method restores each connection's original host, socket_timeout, and socket_connect_timeout + values from their orig_* attributes back to the active connection configuration, then clears + the temporary storage attributes. + + This is used to restore connections to their original configuration after maintenance operations + that required temporary address/timeout changes are complete. + + When this method is called the pool will already be locked, so getting the pool lock inside is not needed. """ with self._lock: for conn in self._available_connections: - self._update_connection_tmp_settings( - conn, tmp_host_address, tmp_relax_timeout + conn.reset_tmp_settings( + reset_host_address=True, reset_relax_timeout=True ) for conn in self._in_use_connections: - self._update_connection_tmp_settings( - conn, tmp_host_address, tmp_relax_timeout + conn.reset_tmp_settings( + reset_host_address=True, reset_relax_timeout=True ) def update_active_connections_for_reconnect( self, - tmp_host_address: Optional[str] = None, + tmp_host_address: str, tmp_relax_timeout: Optional[float] = None, + skip_original_data_update: bool = False, ): """ Mark all active connections for reconnect. @@ -1873,17 +2008,18 @@ def update_active_connections_for_reconnect( When this method is called the pool will already be locked, so getting the pool lock inside is not needed. - :param tmp_host_address: The temporary host address to use for the connection. + :param orig_host_address: The temporary host address to use for the connection. """ for conn in self._in_use_connections: self._update_connection_for_reconnect( - conn, tmp_host_address, tmp_relax_timeout + conn, tmp_host_address, tmp_relax_timeout, skip_original_data_update ) def disconnect_and_reconfigure_free_connections( self, - tmp_host_address: Optional[str] = None, + tmp_host_address: str, tmp_relax_timeout: Optional[float] = None, + skip_original_data_update: bool = False, ): """ Disconnect all free/available connections. @@ -1891,13 +2027,13 @@ def disconnect_and_reconfigure_free_connections( When this method is called the pool will already be locked, so getting the pool lock inside is not needed. - :param tmp_host_address: The temporary host address to use for the connection. - :param tmp_relax_timeout: The relax timeout to use for the connection. + :param orig_host_address: The temporary host address to use for the connection. + :param orig_relax_timeout: The relax timeout to use for the connection. """ for conn in self._available_connections: self._disconnect_and_update_connection_for_reconnect( - conn, tmp_host_address, tmp_relax_timeout + conn, tmp_host_address, tmp_relax_timeout, skip_original_data_update ) def update_connections_current_timeout( @@ -1916,48 +2052,40 @@ def update_connections_current_timeout( :param include_available_connections: Whether to include available connections in the update. """ for conn in self._in_use_connections: - self._update_connection_timeout(conn, relax_timeout) + conn.update_current_socket_timeout(relax_timeout) if include_free_connections: for conn in self._available_connections: - self._update_connection_timeout(conn, relax_timeout) + conn.update_current_socket_timeout(relax_timeout) def _update_connection_for_reconnect( self, connection: "Connection", - tmp_host_address: Optional[str] = None, + tmp_host_address: str, tmp_relax_timeout: Optional[float] = None, + skip_original_data_update: bool = False, ): connection.mark_for_reconnect() - self._update_connection_tmp_settings( - connection, tmp_host_address, tmp_relax_timeout + connection.set_tmp_settings( + tmp_host_address=tmp_host_address, + tmp_relax_timeout=tmp_relax_timeout, + skip_original_data_update=skip_original_data_update, ) def _disconnect_and_update_connection_for_reconnect( self, connection: "Connection", - tmp_host_address: Optional[str] = None, + tmp_host_address: str, tmp_relax_timeout: Optional[float] = None, + skip_original_data_update: bool = False, ): connection.disconnect() - self._update_connection_tmp_settings( - connection, tmp_host_address, tmp_relax_timeout + connection.set_tmp_settings( + tmp_host_address=tmp_host_address, + tmp_relax_timeout=tmp_relax_timeout, + skip_original_data_update=skip_original_data_update, ) - def _update_connection_tmp_settings( - self, - connection: "Connection", - tmp_host_address: Optional[str] = None, - tmp_relax_timeout: Optional[float] = None, - ): - connection.tmp_host_address = tmp_host_address - connection.tmp_relax_timeout = tmp_relax_timeout - - def _update_connection_timeout( - self, connection: "Connection", relax_timeout: Optional[Number] - ): - connection.update_current_socket_timeout(relax_timeout) - async def _mock(self, error: RedisError): """ Dummy functions, needs to be passed as error callback to retry object. @@ -1966,16 +2094,6 @@ async def _mock(self, error: RedisError): """ pass - def set_maintenance_state_for_all(self, state: "MaintenanceState"): - with self._lock: - for conn in self._available_connections: - conn.set_maintenance_state(state) - for conn in self._in_use_connections: - conn.set_maintenance_state(state) - - def set_maintenance_state_in_kwargs(self, state: "MaintenanceState"): - self.connection_kwargs["maintenance_state"] = state - class BlockingConnectionPool(ConnectionPool): """ @@ -2215,67 +2333,54 @@ def disconnect(self): def update_active_connections_for_reconnect( self, - tmp_host_address: Optional[str] = None, + tmp_host_address: str, tmp_relax_timeout: Optional[float] = None, + skip_original_data_update: bool = False, ): with self._lock: connections_in_queue = {conn for conn in self.pool.queue if conn} for conn in self._connections: if conn not in connections_in_queue: self._update_connection_for_reconnect( - conn, tmp_host_address, tmp_relax_timeout + conn, + tmp_host_address, + tmp_relax_timeout, + skip_original_data_update, ) def disconnect_and_reconfigure_free_connections( self, - tmp_host_address: Optional[str] = None, + tmp_host_address: str, tmp_relax_timeout: Optional[Number] = None, + skip_original_data_update: bool = False, ): - with self._lock: - existing_connections = self.pool.queue + existing_connections = self.pool.queue - for conn in existing_connections: - if conn: - self._disconnect_and_update_connection_for_reconnect( - conn, tmp_host_address, tmp_relax_timeout - ) + for conn in existing_connections: + if conn: + self._disconnect_and_update_connection_for_reconnect( + conn, tmp_host_address, tmp_relax_timeout, skip_original_data_update + ) def update_connections_current_timeout( self, relax_timeout: Optional[float] = None, include_free_connections: bool = False, ): - logging.debug( - f"***** Blocking Pool --> Updating timeouts. relax_timeout: {relax_timeout}" - ) - - with self._lock: - if include_free_connections: - for conn in tuple(self._connections): - self._update_connection_timeout(conn, relax_timeout) - else: - connections_in_queue = {conn for conn in self.pool.queue if conn} - for conn in self._connections: - if conn not in connections_in_queue: - self._update_connection_timeout(conn, relax_timeout) - - def update_connections_tmp_settings( - self, - tmp_host_address: Optional[str] = None, - tmp_relax_timeout: Optional[float] = None, - ): - with self._lock: + if include_free_connections: for conn in tuple(self._connections): - self._update_connection_tmp_settings( - conn, tmp_host_address, tmp_relax_timeout - ) + conn.update_current_socket_timeout(relax_timeout) + else: + connections_in_queue = {conn for conn in self.pool.queue if conn} + for conn in self._connections: + if conn not in connections_in_queue: + conn.update_current_socket_timeout(relax_timeout) def _update_maintenance_events_config_for_connections( self, maintenance_events_config ): - with self._lock: - for conn in tuple(self._connections): - conn.maintenance_events_config = maintenance_events_config + for conn in tuple(self._connections): + conn.maintenance_events_config = maintenance_events_config def _update_maintenance_events_configs_for_connections( self, maintenance_events_pool_handler @@ -2283,23 +2388,25 @@ def _update_maintenance_events_configs_for_connections( """Override base class method to work with BlockingConnectionPool's structure.""" with self._lock: for conn in tuple(self._connections): - if conn: # conn can be None in BlockingConnectionPool - conn.set_maintenance_event_pool_handler( - maintenance_events_pool_handler - ) - conn.maintenance_events_config = ( - maintenance_events_pool_handler.config - ) + conn.set_maintenance_event_pool_handler(maintenance_events_pool_handler) + conn.maintenance_events_config = maintenance_events_pool_handler.config + + def reset_connections_tmp_settings(self): + """ + Override base class method to work with BlockingConnectionPool's structure. + + Restore original settings from temporary configuration for all connections in the pool. + """ + for conn in tuple(self._connections): + conn.reset_tmp_settings(reset_host_address=True, reset_relax_timeout=True) def set_in_maintenance(self, in_maintenance: bool): """Set the maintenance mode for the connection pool.""" self._in_maintenance = in_maintenance - def set_maintenance_state_for_all(self, state: "MaintenanceState"): - with self._lock: - for conn in getattr(self, "_connections", []): - if conn: - conn.set_maintenance_state(state) + def set_maintenance_state_for_all_connections(self, state: "MaintenanceState"): + for conn in self._connections: + conn.maintenance_state = state - def set_maintenance_state_in_kwargs(self, state: "MaintenanceState"): + def set_maintenance_state_in_connection_kwargs(self, state: "MaintenanceState"): self.connection_kwargs["maintenance_state"] = state diff --git a/redis/maintenance_events.py b/redis/maintenance_events.py index dd62602105..479a4ba090 100644 --- a/redis/maintenance_events.py +++ b/redis/maintenance_events.py @@ -359,15 +359,35 @@ def handle_node_moving_event(self, event: NodeMovingEvent): ): if getattr(self.pool, "set_in_maintenance", False): self.pool.set_in_maintenance(True) - # Set state to MOVING for all connections and in kwargs (inside pool lock, after set_in_maintenance) - self.pool.set_maintenance_state_for_all(MaintenanceState.MOVING) - self.pool.set_maintenance_state_in_kwargs(MaintenanceState.MOVING) + + prev_moving_in_progress = False + if ( + self.pool.connection_kwargs.get("maintenance_state") + == MaintenanceState.MOVING + ): + # The pool is already in MOVING state, update just the new host information + prev_moving_in_progress = True + + if not prev_moving_in_progress: + # Set state to MOVING for all connections and in kwargs (inside pool lock, after set_in_maintenance) + self.pool.set_maintenance_state_for_all_connections( + MaintenanceState.MOVING + ) + self.pool.set_maintenance_state_in_connection_kwargs( + MaintenanceState.MOVING + ) # edit the config for new connections until the notification expires - self.pool.update_connection_kwargs_with_tmp_settings( + # skip original data update if we are already in MOVING state + # as the original data is already stored in the connection kwargs + self.pool.add_tmp_config_to_connection_kwargs( tmp_host_address=event.new_node_host, tmp_relax_timeout=self.config.relax_timeout, + skip_original_data_update=prev_moving_in_progress, ) - if self.config.is_relax_timeouts_enabled(): + if ( + self.config.is_relax_timeouts_enabled() + and not prev_moving_in_progress + ): # extend the timeout for all connections that are currently in use self.pool.update_connections_current_timeout( self.config.relax_timeout @@ -375,42 +395,53 @@ def handle_node_moving_event(self, event: NodeMovingEvent): if self.config.proactive_reconnect: # take care for the active connections in the pool # mark them for reconnect after they complete the current command + # skip original data update if we are already in MOVING state + # as the original data is already stored in the connection self.pool.update_active_connections_for_reconnect( tmp_host_address=event.new_node_host, tmp_relax_timeout=self.config.relax_timeout, + skip_original_data_update=prev_moving_in_progress, ) # take care for the inactive connections in the pool # delete them and create new ones + # skip original data update if we are already in MOVING state + # as the original data is already stored in the connection self.pool.disconnect_and_reconfigure_free_connections( tmp_host_address=event.new_node_host, tmp_relax_timeout=self.config.relax_timeout, + skip_original_data_update=prev_moving_in_progress, ) if getattr(self.pool, "set_in_maintenance", False): self.pool.set_in_maintenance(False) - threading.Timer(event.ttl, self.handle_node_moved_event).start() + threading.Timer( + event.ttl, self.handle_node_moved_event, args=(event,) + ).start() self._processed_events.add(event) - def handle_node_moved_event(self): + def handle_node_moved_event(self, event: NodeMovingEvent): with self._lock: - self.pool.update_connection_kwargs_with_tmp_settings( - tmp_host_address=None, - tmp_relax_timeout=-1, - ) + if self.pool.connection_kwargs.get("host") != event.new_node_host: + # if the current host is not matching the event + # it means there has been a new moving event after this one + # so we don't need to handle this one anymore + # the settings will be reverted by the moved handler of the next event + return + self.pool.remove_tmp_config_from_connection_kwargs() # Clear state to NONE in kwargs immediately after updating tmp kwargs - self.pool.set_maintenance_state_in_kwargs(MaintenanceState.NONE) + self.pool.set_maintenance_state_in_connection_kwargs(MaintenanceState.NONE) with self.pool._lock: + self.pool.reset_connections_tmp_settings() if self.config.is_relax_timeouts_enabled(): # reset the timeout for existing connections self.pool.update_connections_current_timeout( relax_timeout=-1, include_free_connections=True ) - self.pool.update_connections_tmp_settings( - tmp_host_address=None, tmp_relax_timeout=-1 - ) # Clear state to NONE for all connections - self.pool.set_maintenance_state_for_all(MaintenanceState.NONE) + self.pool.set_maintenance_state_for_all_connections( + MaintenanceState.NONE + ) class MaintenanceEventConnectionHandler: @@ -434,10 +465,10 @@ def handle_migrating_event(self, notification: NodeMigratingEvent): or not self.config.is_relax_timeouts_enabled() ): return - self.connection.set_maintenance_state(MaintenanceState.MIGRATING) + self.connection.maintenance_state = MaintenanceState.MIGRATING + self.connection.set_tmp_settings(tmp_relax_timeout=self.config.relax_timeout) # extend the timeout for all created connections self.connection.update_current_socket_timeout(self.config.relax_timeout) - self.connection.update_tmp_settings(tmp_relax_timeout=self.config.relax_timeout) def handle_migration_completed_event(self, notification: "NodeMigratedEvent"): # Only reset timeouts if state is not MOVING and relax timeouts are enabled @@ -446,8 +477,8 @@ def handle_migration_completed_event(self, notification: "NodeMigratedEvent"): or not self.config.is_relax_timeouts_enabled() ): return - self.connection.set_maintenance_state(MaintenanceState.NONE) + self.connection.reset_tmp_settings(reset_relax_timeout=True) # Node migration completed - reset the connection # timeouts by providing -1 as the relax timeout self.connection.update_current_socket_timeout(-1) - self.connection.update_tmp_settings(tmp_relax_timeout=-1) + self.connection.maintenance_state = MaintenanceState.NONE diff --git a/tests/test_maintenance_events.py b/tests/test_maintenance_events.py index ac7d10b51e..37ef869100 100644 --- a/tests/test_maintenance_events.py +++ b/tests/test_maintenance_events.py @@ -438,7 +438,7 @@ def test_handle_node_moving_event_success(self): # Verify timer was started mock_timer.assert_called_once_with( - event.ttl, self.handler.handle_node_moved_event + event.ttl, self.handler.handle_node_moved_event, args=(event,) ) mock_timer.return_value.start.assert_called_once() @@ -446,17 +446,18 @@ def test_handle_node_moving_event_success(self): assert event in self.handler._processed_events # Verify pool methods were called - self.mock_pool.update_connection_kwargs_with_tmp_settings.assert_called_once() + self.mock_pool.add_tmp_config_to_connection_kwargs.assert_called_once() def test_handle_node_moved_event(self): """Test handling of node moved event (cleanup).""" - self.handler.handle_node_moved_event() + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + self.mock_pool.connection_kwargs = {"host": "localhost"} + self.handler.handle_node_moved_event(event) # Verify cleanup methods were called - self.mock_pool.update_connection_kwargs_with_tmp_settings.assert_called_once_with( - tmp_host_address=None, - tmp_relax_timeout=-1, - ) + self.mock_pool.remove_tmp_config_from_connection_kwargs.assert_called_once() class TestMaintenanceEventConnectionHandler: @@ -519,7 +520,7 @@ def test_handle_migrating_event_success(self): self.handler.handle_migrating_event(event) self.mock_connection.update_current_socket_timeout.assert_called_once_with(20) - self.mock_connection.update_tmp_settings.assert_called_once_with( + self.mock_connection.set_tmp_settings.assert_called_once_with( tmp_relax_timeout=20 ) @@ -540,6 +541,6 @@ def test_handle_migration_completed_event_success(self): self.handler.handle_migration_completed_event(event) self.mock_connection.update_current_socket_timeout.assert_called_once_with(-1) - self.mock_connection.update_tmp_settings.assert_called_once_with( - tmp_relax_timeout=-1 + self.mock_connection.reset_tmp_settings.assert_called_once_with( + reset_relax_timeout=True ) diff --git a/tests/test_maintenance_events_handling.py b/tests/test_maintenance_events_handling.py index b573a55e5f..fe0b529fdb 100644 --- a/tests/test_maintenance_events_handling.py +++ b/tests/test_maintenance_events_handling.py @@ -280,25 +280,6 @@ def _validate_connection_handlers(self, conn, pool_handler, config): # Validate that the connection's maintenance handler has the same config object assert conn._maintenance_event_connection_handler.config is config - def _validate_current_timeout_for_thread( - self, thread_id, expected_timeout, error_msg=None - ): - """Helper method to validate the current timeout for the calling thread.""" - actual_timeout = None - # Get the actual thread ID from the current thread - current_thread_id = threading.current_thread().ident - for sock in self.mock_sockets: - if current_thread_id in sock.thread_timeouts: - actual_timeout = sock.thread_timeouts[current_thread_id] - break - - assert actual_timeout == expected_timeout, ( - error_msg, - f"Thread {thread_id}: Expected timeout ({expected_timeout}), " - f"but found timeout: {actual_timeout} for thread {thread_id}. " - f"All thread timeouts: {[sock.thread_timeouts for sock in self.mock_sockets]}", - ) - def _validate_current_timeout(self, expected_timeout, error_msg=None): """Helper method to validate the current timeout for the calling thread.""" actual_timeout = None @@ -336,8 +317,12 @@ def _validate_in_use_connections_state( self, in_use_connections: List[AbstractConnection], expected_state=MaintenanceState.NONE, - expected_tmp_host_address=None, - expected_tmp_relax_timeout=-1, + expected_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_socket_timeout=None, + expected_socket_connect_timeout=None, + expected_orig_host_address=None, + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, expected_current_socket_timeout=None, expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], ): @@ -346,21 +331,33 @@ def _validate_in_use_connections_state( # and timeout is updated for connection in in_use_connections: assert connection._should_reconnect is True - assert connection.tmp_host_address == expected_tmp_host_address - assert connection.tmp_relax_timeout == expected_tmp_relax_timeout - assert connection._sock.gettimeout() == expected_current_socket_timeout - assert connection._sock.connected is True + assert connection.host == expected_host_address + assert connection.socket_timeout == expected_socket_timeout + assert connection.socket_connect_timeout == expected_socket_connect_timeout + assert connection.orig_host_address == expected_orig_host_address + assert connection.orig_socket_timeout == expected_orig_socket_timeout + assert ( + connection.orig_socket_connect_timeout + == expected_orig_socket_connect_timeout + ) + if connection._sock is not None: + assert connection._sock.gettimeout() == expected_current_socket_timeout + assert connection._sock.connected is True + assert connection._sock.getpeername()[0] == expected_current_peername assert connection.maintenance_state == expected_state - assert connection._sock.getpeername()[0] == expected_current_peername def _validate_free_connections_state( self, pool, - tmp_host_address, - relax_timeout, should_be_connected_count, connected_to_tmp_addres=False, expected_state=MaintenanceState.MOVING, + expected_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_socket_timeout=None, + expected_socket_connect_timeout=None, + expected_orig_host_address=None, + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, ): """Helper method to validate state of free/available connections.""" if isinstance(pool, BlockingConnectionPool): @@ -373,8 +370,15 @@ def _validate_free_connections_state( connected_count = 0 for connection in free_connections: assert connection._should_reconnect is False - assert connection.tmp_host_address == tmp_host_address - assert connection.tmp_relax_timeout == relax_timeout + assert connection.host == expected_host_address + assert connection.socket_timeout == expected_socket_timeout + assert connection.socket_connect_timeout == expected_socket_connect_timeout + assert connection.orig_host_address == expected_orig_host_address + assert connection.orig_socket_timeout == expected_orig_socket_timeout + assert ( + connection.orig_socket_connect_timeout + == expected_orig_socket_connect_timeout + ) assert connection.maintenance_state == expected_state if connection._sock is not None: assert connection._sock.connected is True @@ -401,14 +405,29 @@ def _validate_conn_kwargs( pool, expected_host_address, expected_port, - expected_tmp_host_address, - expected_tmp_relax_timeout, + expected_socket_timeout, + expected_socket_connect_timeout, + expected_orig_host_address, + expected_orig_socket_timeout, + expected_orig_socket_connect_timeout, ): """Helper method to validate connection kwargs.""" assert pool.connection_kwargs["host"] == expected_host_address assert pool.connection_kwargs["port"] == expected_port - assert pool.connection_kwargs["tmp_host_address"] == expected_tmp_host_address - assert pool.connection_kwargs["tmp_relax_timeout"] == expected_tmp_relax_timeout + assert pool.connection_kwargs["socket_timeout"] == expected_socket_timeout + assert ( + pool.connection_kwargs["socket_connect_timeout"] + == expected_socket_connect_timeout + ) + assert pool.connection_kwargs["orig_host_address"] == expected_orig_host_address + assert ( + pool.connection_kwargs["orig_socket_timeout"] + == expected_orig_socket_timeout + ) + assert ( + pool.connection_kwargs["orig_socket_connect_timeout"] + == expected_orig_socket_connect_timeout + ) def test_client_initialization(self): """Test that Redis client is created with maintenance events configuration.""" @@ -427,8 +446,8 @@ def test_client_initialization(self): conn = test_redis_client.connection_pool.get_connection() assert conn._should_reconnect is False - assert conn.tmp_host_address is None - assert conn.tmp_relax_timeout == -1 + assert conn.orig_host_address is None + assert conn.orig_socket_timeout is None # Test that the node moving handler function is correctly set by # comparing the underlying function and instance @@ -764,7 +783,8 @@ def test_moving_related_events_handling_integration(self, pool_class): This test validates the complete MOVING event lifecycle: 1. Creates multiple connections in the pool 2. Executes a Redis command that triggers a MOVING push message - 3. Validates that pool configuration is updated with temporary address and timeout + 3. Validates that pool configuration is updated with temporary + address and timeout - for new connections creation 4. Validates that existing connections are marked for disconnection 5. Tests both ConnectionPool and BlockingConnectionPool implementations """ @@ -804,46 +824,64 @@ def test_moving_related_events_handling_integration(self, pool_class): # Validate pool and connections settings were updated according to MOVING event self._validate_conn_kwargs( - test_redis_client.connection_pool, - MockSocket.DEFAULT_ADDRESS.split(":")[0], - int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), - MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], - self.config.relax_timeout, + pool=test_redis_client.connection_pool, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, ) self._validate_disconnected(5) self._validate_connected(6) self._validate_in_use_connections_state( in_use_connections, expected_state=MaintenanceState.MOVING, - expected_tmp_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], - expected_tmp_relax_timeout=self.config.relax_timeout, + expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, expected_current_socket_timeout=self.config.relax_timeout, expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[ 0 ], # the in use connections reconnect when they complete their current task ) self._validate_free_connections_state( - test_redis_client.connection_pool, - MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], - self.config.relax_timeout, + pool=test_redis_client.connection_pool, + expected_state=MaintenanceState.MOVING, + expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, should_be_connected_count=1, connected_to_tmp_addres=True, - expected_state=MaintenanceState.MOVING, ) # Wait for MOVING timeout to expire and the moving completed handler to run sleep(MockSocket.MOVING_TIMEOUT + 0.5) self._validate_all_timeouts(None) self._validate_conn_kwargs( - test_redis_client.connection_pool, - MockSocket.DEFAULT_ADDRESS.split(":")[0], - int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), - None, - -1, + pool=test_redis_client.connection_pool, + expected_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + expected_socket_timeout=None, + expected_socket_connect_timeout=None, + expected_orig_host_address=None, + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, ) self._validate_free_connections_state( - test_redis_client.connection_pool, - None, - -1, + pool=test_redis_client.connection_pool, + expected_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_socket_timeout=None, + expected_socket_connect_timeout=None, + expected_orig_host_address=None, + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, should_be_connected_count=1, connected_to_tmp_addres=True, expected_state=MaintenanceState.NONE, @@ -896,11 +934,14 @@ def test_create_new_conn_while_moving_not_expired(self, pool_class): # Validate pool and connections settings were updated according to MOVING event self._validate_conn_kwargs( - test_redis_client.connection_pool, - MockSocket.DEFAULT_ADDRESS.split(":")[0], - int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), - MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], - self.config.relax_timeout, + pool=test_redis_client.connection_pool, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, ) # Now get several more connections to force creation of new ones @@ -915,11 +956,8 @@ def test_create_new_conn_while_moving_not_expired(self, pool_class): # Validate that new connections are created with temporary address and relax timeout # and when connecting those configs are used # get_connection() returns a connection that is already connected - assert ( - new_connection.tmp_host_address - == MockSocket.AFTER_MOVING_ADDRESS.split(":")[0] - ) - assert new_connection.tmp_relax_timeout == self.config.relax_timeout + assert new_connection.host == MockSocket.AFTER_MOVING_ADDRESS.split(":")[0] + assert new_connection.socket_timeout is self.config.relax_timeout # New connections should be connected to the temporary address assert new_connection._sock is not None assert new_connection._sock.connected is True @@ -984,8 +1022,8 @@ def test_create_new_conn_after_moving_expires(self, pool_class): new_connection = test_redis_client.connection_pool.get_connection() # Validate that new connections are created with original address (no temporary settings) - assert new_connection.tmp_host_address is None - assert new_connection.tmp_relax_timeout == -1 + assert new_connection.orig_host_address is None + assert new_connection.orig_socket_timeout is None # New connections should be connected to the original address assert new_connection._sock is not None assert new_connection._sock.connected is True @@ -1044,13 +1082,18 @@ def test_receive_migrated_after_moving(self, pool_class): # Validate pool and connections settings were updated according to MOVING event self._validate_conn_kwargs( - test_redis_client.connection_pool, - MockSocket.DEFAULT_ADDRESS.split(":")[0], - int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), - MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], - self.config.relax_timeout, + pool=test_redis_client.connection_pool, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, ) + # TODO validate current socket timeout + # Step 2: Run command that will receive and handle MIGRATED event # This should clear the temporary settings key_migrated = "key_receive_migrated_0" @@ -1062,14 +1105,17 @@ def test_receive_migrated_after_moving(self, pool_class): # Step 3: Validate that MIGRATED event was processed but MOVING settings remain # (MIGRATED doesn't automatically clear MOVING settings - they are separate events) + # MOVING settings should still be active + # MOVING timeout should still be active self._validate_conn_kwargs( - test_redis_client.connection_pool, - MockSocket.DEFAULT_ADDRESS.split(":")[0], - int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), - MockSocket.AFTER_MOVING_ADDRESS.split(":")[ - 0 - ], # MOVING settings still active - self.config.relax_timeout, # MOVING timeout still active + pool=test_redis_client.connection_pool, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, ) # Step 4: Create new connections after MIGRATED to verify they still use MOVING settings @@ -1081,10 +1127,7 @@ def test_receive_migrated_after_moving(self, pool_class): # Validate that new connections are created with MOVING settings (still active) for connection in new_connections: - assert ( - connection.tmp_host_address - == MockSocket.AFTER_MOVING_ADDRESS.split(":")[0] - ) + assert connection.host == MockSocket.AFTER_MOVING_ADDRESS.split(":")[0] # Note: New connections may not inherit the exact relax timeout value # but they should have the temporary host address # New connections should be connected @@ -1130,61 +1173,85 @@ def test_overlapping_moving_events(self, pool_class): result1 = test_redis_client.set(key_moving1, value_moving1) assert result1 is True self._validate_conn_kwargs( - test_redis_client.connection_pool, - MockSocket.DEFAULT_ADDRESS.split(":")[0], - int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), - MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], - self.config.relax_timeout, + pool=test_redis_client.connection_pool, + expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, ) # Validate all connections reflect the first MOVING event self._validate_in_use_connections_state( in_use_connections, expected_state=MaintenanceState.MOVING, - expected_tmp_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], - expected_tmp_relax_timeout=self.config.relax_timeout, + expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, expected_current_socket_timeout=self.config.relax_timeout, expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], ) self._validate_free_connections_state( - test_redis_client.connection_pool, - MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], - self.config.relax_timeout, + pool=test_redis_client.connection_pool, should_be_connected_count=1, connected_to_tmp_addres=True, + expected_state=MaintenanceState.MOVING, + expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, ) # Before the first MOVING expires, trigger a second MOVING event (simulate new address) - # Patch MockSocket to use a new address for the second event - new_address = "5.6.7.8:6380" + # Validate the orig properties are not changed! + second_moving_address = "5.6.7.8:6380" orig_after_moving = MockSocket.AFTER_MOVING_ADDRESS - MockSocket.AFTER_MOVING_ADDRESS = new_address + MockSocket.AFTER_MOVING_ADDRESS = second_moving_address try: key_moving2 = "key_receive_moving_1" value_moving2 = "value3_1" result2 = test_redis_client.set(key_moving2, value_moving2) assert result2 is True self._validate_conn_kwargs( - test_redis_client.connection_pool, - MockSocket.DEFAULT_ADDRESS.split(":")[0], - int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), - new_address.split(":")[0], - self.config.relax_timeout, + pool=test_redis_client.connection_pool, + expected_host_address=second_moving_address.split(":")[0], + expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, ) # Validate all connections reflect the second MOVING event self._validate_in_use_connections_state( in_use_connections, expected_state=MaintenanceState.MOVING, - expected_tmp_host_address=new_address.split(":")[0], - expected_tmp_relax_timeout=self.config.relax_timeout, + expected_host_address=second_moving_address.split(":")[0], + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, expected_current_socket_timeout=self.config.relax_timeout, expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], ) self._validate_free_connections_state( test_redis_client.connection_pool, - new_address.split(":")[0], - self.config.relax_timeout, should_be_connected_count=1, connected_to_tmp_addres=True, + expected_state=MaintenanceState.MOVING, + expected_host_address=second_moving_address.split(":")[0], + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, ) finally: MockSocket.AFTER_MOVING_ADDRESS = orig_after_moving @@ -1192,11 +1259,14 @@ def test_overlapping_moving_events(self, pool_class): # Wait for both MOVING timeouts to expire sleep(MockSocket.MOVING_TIMEOUT + 0.5) self._validate_conn_kwargs( - test_redis_client.connection_pool, - MockSocket.DEFAULT_ADDRESS.split(":")[0], - int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), - None, - -1, + pool=test_redis_client.connection_pool, + expected_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + expected_socket_timeout=None, + expected_socket_connect_timeout=None, + expected_orig_host_address=None, + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, ) finally: if hasattr(test_redis_client.connection_pool, "disconnect"): @@ -1234,12 +1304,16 @@ def worker(idx): assert not errors, f"Errors occurred in threads: {errors}" # After all threads, MOVING event should have been handled safely self._validate_conn_kwargs( - test_redis_client.connection_pool, - MockSocket.DEFAULT_ADDRESS.split(":")[0], - int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), - MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], - self.config.relax_timeout, + pool=test_redis_client.connection_pool, + expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, ) + if hasattr(test_redis_client.connection_pool, "disconnect"): test_redis_client.connection_pool.disconnect() @@ -1279,18 +1353,26 @@ def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): self._validate_in_use_connections_state( in_use_connections, expected_state=MaintenanceState.MOVING, - expected_tmp_host_address=tmp_address, - expected_tmp_relax_timeout=self.config.relax_timeout, + expected_host_address=tmp_address, + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, expected_current_socket_timeout=self.config.relax_timeout, expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], ) self._validate_free_connections_state( - pool, - tmp_address, - self.config.relax_timeout, + pool=pool, should_be_connected_count=0, connected_to_tmp_addres=False, expected_state=MaintenanceState.MOVING, + expected_host_address=tmp_address, + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, ) # 2. MIGRATING event (simulate direct connection handler call) @@ -1301,8 +1383,12 @@ def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): self._validate_in_use_connections_state( in_use_connections, expected_state=MaintenanceState.MOVING, - expected_tmp_host_address=tmp_address, - expected_tmp_relax_timeout=self.config.relax_timeout, + expected_host_address=tmp_address, + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, expected_current_socket_timeout=self.config.relax_timeout, expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], ) @@ -1316,29 +1402,41 @@ def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): self._validate_in_use_connections_state( in_use_connections, expected_state=MaintenanceState.MOVING, - expected_tmp_host_address=tmp_address, - expected_tmp_relax_timeout=self.config.relax_timeout, + expected_host_address=tmp_address, + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, expected_current_socket_timeout=self.config.relax_timeout, expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], ) # 4. MOVED event (simulate timer expiry) - pool_handler.handle_node_moved_event() + pool_handler.handle_node_moved_event(moving_event) self._validate_in_use_connections_state( in_use_connections, expected_state=MaintenanceState.NONE, - expected_tmp_host_address=None, - expected_tmp_relax_timeout=-1, + expected_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_socket_timeout=None, + expected_socket_connect_timeout=None, + expected_orig_host_address=None, + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, expected_current_socket_timeout=None, expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], ) self._validate_free_connections_state( - pool, - None, - -1, + pool=pool, should_be_connected_count=0, connected_to_tmp_addres=False, expected_state=MaintenanceState.NONE, + expected_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_socket_timeout=None, + expected_socket_connect_timeout=None, + expected_orig_host_address=None, + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, ) # New connection after MOVED new_conn_none = pool.get_connection() From 9a31a71432819b9d98004e128f7a2bd8fe91bf70 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Thu, 24 Jul 2025 16:38:31 +0300 Subject: [PATCH 15/28] Apply review comments --- redis/_parsers/base.py | 20 +-- redis/connection.py | 186 +++++++++++----------- redis/maintenance_events.py | 7 - tests/test_connection_pool.py | 2 - tests/test_maintenance_events.py | 23 ++- tests/test_maintenance_events_handling.py | 38 +++-- 6 files changed, 134 insertions(+), 142 deletions(-) diff --git a/redis/_parsers/base.py b/redis/_parsers/base.py index 77d0188092..d5e4add661 100644 --- a/redis/_parsers/base.py +++ b/redis/_parsers/base.py @@ -129,8 +129,9 @@ def __del__(self): def on_connect(self, connection): "Called when the socket connects" self._sock = connection._sock - timeout = connection.socket_timeout - self._buffer = SocketBuffer(self._sock, self.socket_read_size, timeout) + self._buffer = SocketBuffer( + self._sock, self.socket_read_size, connection.socket_timeout + ) self.encoder = connection.encoder def on_disconnect(self): @@ -201,19 +202,18 @@ def handle_push_response(self, response, **kwargs): if msg_type in _INVALIDATION_MESSAGE and self.invalidation_push_handler_func: return self.invalidation_push_handler_func(response) if msg_type in _MOVING_MESSAGE and self.node_moving_push_handler_func: - if msg_type in _MOVING_MESSAGE: - host, port = response[2].decode().split(":") - ttl = response[1] - id = 1 # Hardcoded value for sync parser - notification = NodeMovingEvent(id, host, port, ttl) - return self.node_moving_push_handler_func(notification) + host, port = response[2].decode().split(":") + ttl = response[1] + id = 1 # Hardcoded value until the notification starts including the id + notification = NodeMovingEvent(id, host, port, ttl) + return self.node_moving_push_handler_func(notification) if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func: if msg_type in _MIGRATING_MESSAGE: ttl = response[1] - id = 2 # Hardcoded value for sync parser + id = 2 # Hardcoded value until the notification starts including the id notification = NodeMigratingEvent(id, ttl) elif msg_type in _MIGRATED_MESSAGE: - id = 3 # Hardcoded value for sync parser + id = 3 # Hardcoded value until the notification starts including the id notification = NodeMigratedEvent(id) else: notification = None diff --git a/redis/connection.py b/redis/connection.py index c20c89dd9d..3f1b54e26c 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -233,34 +233,34 @@ def set_re_auth_token(self, token: TokenInterface): def re_auth(self): pass + @property @abstractmethod - def mark_for_reconnect(self): + def maintenance_state(self) -> MaintenanceState: """ - Mark the connection to be reconnected on the next command. - This is useful when a connection is moved to a different node. + Returns the current maintenance state of the connection. """ pass + @maintenance_state.setter @abstractmethod - def should_reconnect(self): + def maintenance_state(self, state: "MaintenanceState"): """ - Returns True if the connection should be reconnected. + Sets the current maintenance state of the connection. """ pass - @property @abstractmethod - def maintenance_state(self) -> MaintenanceState: + def mark_for_reconnect(self): """ - Returns the current maintenance state of the connection. + Mark the connection to be reconnected on the next command. + This is useful when a connection is moved to a different node. """ pass - @maintenance_state.setter @abstractmethod - def maintenance_state(self, state: "MaintenanceState"): + def should_reconnect(self): """ - Sets the current maintenance state of the connection. + Returns True if the connection should be reconnected. """ pass @@ -323,10 +323,10 @@ def __init__( event_dispatcher: Optional[EventDispatcher] = None, maintenance_events_pool_handler: Optional[MaintenanceEventPoolHandler] = None, maintenance_events_config: Optional[MaintenanceEventsConfig] = None, + maintenance_state: "MaintenanceState" = MaintenanceState.NONE, orig_host_address: Optional[str] = None, orig_socket_timeout: Optional[float] = None, orig_socket_connect_timeout: Optional[float] = None, - maintenance_state: "MaintenanceState" = MaintenanceState.NONE, ): """ Initialize a new Connection. @@ -412,13 +412,22 @@ def __init__( self._maintenance_event_connection_handler.handle_event ) - self._command_packer = self._construct_command_packer(command_packer) + self.orig_host_address = ( + orig_host_address if orig_host_address else self.host + ) + self.orig_socket_timeout = ( + orig_socket_timeout if orig_socket_timeout else self.socket_timeout + ) + self.orig_socket_connect_timeout = ( + orig_socket_connect_timeout + if orig_socket_connect_timeout + else self.socket_connect_timeout + ) self._should_reconnect = False - self.orig_host_address = orig_host_address - self.orig_socket_timeout = orig_socket_timeout - self.orig_socket_connect_timeout = orig_socket_connect_timeout self.maintenance_state = maintenance_state + self._command_packer = self._construct_command_packer(command_packer) + def __repr__(self): repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()]) return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>" @@ -878,20 +887,13 @@ def set_tmp_settings( self, tmp_host_address: Optional[Union[str, object]] = SENTINEL, tmp_relax_timeout: Optional[float] = None, - skip_original_data_update: bool = False, ): """ The value of SENTINEL is used to indicate that the property should not be updated. """ if tmp_host_address is not SENTINEL: - if not skip_original_data_update: - self.orig_host_address = self.host self.host = tmp_host_address if tmp_relax_timeout != -1: - if not skip_original_data_update: - self.orig_socket_timeout = self.socket_timeout - self.orig_socket_connect_timeout = self.socket_connect_timeout - self.socket_timeout = tmp_relax_timeout self.socket_connect_timeout = tmp_relax_timeout @@ -902,12 +904,9 @@ def reset_tmp_settings( ): if reset_host_address: self.host = self.orig_host_address - self.orig_host_address = None if reset_relax_timeout: self.socket_timeout = self.orig_socket_timeout self.socket_connect_timeout = self.orig_socket_connect_timeout - self.orig_socket_timeout = None - self.orig_socket_connect_timeout = None class Connection(AbstractConnection): @@ -1600,6 +1599,24 @@ def __init__( raise RedisError( "Push handlers on connection are only supported with RESP version 3" ) + config = connection_kwargs.get("maintenance_events_config", None) or ( + connection_kwargs.get("maintenance_events_pool_handler").config + if connection_kwargs.get("maintenance_events_pool_handler") + else None + ) + + if config and config.enabled: + connection_kwargs.update( + { + "orig_host_address": connection_kwargs.get("host"), + "orig_socket_timeout": connection_kwargs.get( + "socket_timeout", None + ), + "orig_socket_connect_timeout": connection_kwargs.get( + "socket_connect_timeout", None + ), + } + ) self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None) if self._event_dispatcher is None: @@ -1641,7 +1658,7 @@ def maintenance_events_pool_handler_enabled(self): True if the maintenance events pool handler is enabled, False otherwise. """ maintenance_events_config = self.connection_kwargs.get( - "maintenance_events_config", False + "maintenance_events_config", None ) return maintenance_events_config and maintenance_events_config.enabled @@ -1663,6 +1680,7 @@ def set_maintenance_events_pool_handler( def _update_maintenance_events_configs_for_connections( self, maintenance_events_pool_handler ): + """Update the maintenance events config for all connections in the pool.""" with self._lock: for conn in self._available_connections: conn.set_maintenance_event_pool_handler(maintenance_events_pool_handler) @@ -1791,12 +1809,7 @@ def make_connection(self) -> "ConnectionInterface": raise MaxConnectionsError("Too many connections") self._created_connections += 1 - # Pass current maintenance_state to new connections - maintenance_state = self.connection_kwargs.get( - "maintenance_state", MaintenanceState.NONE - ) kwargs = dict(self.connection_kwargs) - kwargs["maintenance_state"] = maintenance_state if self.cache is not None: return CacheProxyConnection( @@ -1892,7 +1905,6 @@ def add_tmp_config_to_connection_kwargs( self, tmp_host_address: str, tmp_relax_timeout: Optional[float] = None, - skip_original_data_update: bool = False, ): """ Store original connection configuration and apply temporary settings. @@ -1913,26 +1925,7 @@ def add_tmp_config_to_connection_kwargs( :param tmp_relax_timeout: The temporary timeout value to use for both socket_timeout and socket_connect_timeout. If -1 is provided, the timeout settings are not modified (relax timeout is disabled). - :param skip_original_data_update: Whether to skip updating the original data. - This is used when we are already in MOVING state - and the original data is already stored in the connection kwargs. """ - if not skip_original_data_update: - # Store original values in temporary storage - original_host = self.connection_kwargs.get("host") - original_socket_timeout = self.connection_kwargs.get("socket_timeout") - original_connect_timeout = self.connection_kwargs.get( - "socket_connect_timeout" - ) - - self.connection_kwargs.update( - { - "orig_host_address": original_host, - "orig_socket_timeout": original_socket_timeout, - "orig_socket_connect_timeout": original_connect_timeout, - } - ) - # Apply temporary values as active configuration self.connection_kwargs.update({"host": tmp_host_address}) @@ -1964,9 +1957,6 @@ def remove_tmp_config_from_connection_kwargs(self): self.connection_kwargs.update( { - "orig_host_address": None, - "orig_socket_timeout": None, - "orig_socket_connect_timeout": None, "host": orig_host, "socket_timeout": orig_socket_timeout, "socket_connect_timeout": orig_connect_timeout, @@ -1997,10 +1987,7 @@ def reset_connections_tmp_settings(self): ) def update_active_connections_for_reconnect( - self, - tmp_host_address: str, - tmp_relax_timeout: Optional[float] = None, - skip_original_data_update: bool = False, + self, tmp_host_address: str, tmp_relax_timeout: Optional[float] = None ): """ Mark all active connections for reconnect. @@ -2008,18 +1995,18 @@ def update_active_connections_for_reconnect( When this method is called the pool will already be locked, so getting the pool lock inside is not needed. - :param orig_host_address: The temporary host address to use for the connection. + :param tmp_host_address: The temporary host address to use for the connection. + :param tmp_relax_timeout: The relax timeout to use for the connection. """ for conn in self._in_use_connections: self._update_connection_for_reconnect( - conn, tmp_host_address, tmp_relax_timeout, skip_original_data_update + conn, tmp_host_address, tmp_relax_timeout ) def disconnect_and_reconfigure_free_connections( self, tmp_host_address: str, tmp_relax_timeout: Optional[float] = None, - skip_original_data_update: bool = False, ): """ Disconnect all free/available connections. @@ -2033,7 +2020,7 @@ def disconnect_and_reconfigure_free_connections( for conn in self._available_connections: self._disconnect_and_update_connection_for_reconnect( - conn, tmp_host_address, tmp_relax_timeout, skip_original_data_update + conn, tmp_host_address, tmp_relax_timeout ) def update_connections_current_timeout( @@ -2063,13 +2050,10 @@ def _update_connection_for_reconnect( connection: "Connection", tmp_host_address: str, tmp_relax_timeout: Optional[float] = None, - skip_original_data_update: bool = False, ): connection.mark_for_reconnect() connection.set_tmp_settings( - tmp_host_address=tmp_host_address, - tmp_relax_timeout=tmp_relax_timeout, - skip_original_data_update=skip_original_data_update, + tmp_host_address=tmp_host_address, tmp_relax_timeout=tmp_relax_timeout ) def _disconnect_and_update_connection_for_reconnect( @@ -2077,13 +2061,10 @@ def _disconnect_and_update_connection_for_reconnect( connection: "Connection", tmp_host_address: str, tmp_relax_timeout: Optional[float] = None, - skip_original_data_update: bool = False, ): connection.disconnect() connection.set_tmp_settings( - tmp_host_address=tmp_host_address, - tmp_relax_timeout=tmp_relax_timeout, - skip_original_data_update=skip_original_data_update, + tmp_host_address=tmp_host_address, tmp_relax_timeout=tmp_relax_timeout ) async def _mock(self, error: RedisError): @@ -2188,20 +2169,15 @@ def make_connection(self): if self._in_maintenance: self._lock.acquire() self._locked = True - # Pass current maintenance_state to new connections - maintenance_state = self.connection_kwargs.get( - "maintenance_state", MaintenanceState.NONE - ) - kwargs = dict(self.connection_kwargs) - kwargs["maintenance_state"] = maintenance_state + if self.cache is not None: connection = CacheProxyConnection( - self.connection_class(**kwargs), + self.connection_class(**self.connection_kwargs), self.cache, self._lock, ) else: - connection = self.connection_class(**kwargs) + connection = self.connection_class(**self.connection_kwargs) self._connections.append(connection) return connection finally: @@ -2332,34 +2308,45 @@ def disconnect(self): self._locked = False def update_active_connections_for_reconnect( - self, - tmp_host_address: str, - tmp_relax_timeout: Optional[float] = None, - skip_original_data_update: bool = False, + self, tmp_host_address: str, tmp_relax_timeout: Optional[float] = None ): + """ + Mark all active connections for reconnect. + This is used when a cluster node is migrated to a different address. + + When this method is called the pool will already be locked, so getting the pool lock inside is not needed. + + :param tmp_host_address: The temporary host address to use for the connection. + :param tmp_relax_timeout: The relax timeout to use for the connection. + """ with self._lock: connections_in_queue = {conn for conn in self.pool.queue if conn} for conn in self._connections: if conn not in connections_in_queue: self._update_connection_for_reconnect( - conn, - tmp_host_address, - tmp_relax_timeout, - skip_original_data_update, + conn, tmp_host_address, tmp_relax_timeout ) def disconnect_and_reconfigure_free_connections( self, tmp_host_address: str, tmp_relax_timeout: Optional[Number] = None, - skip_original_data_update: bool = False, ): + """ + Disconnect all free/available connections. + This is used when a cluster node is migrated to a different address. + + When this method is called the pool will already be locked, so getting the pool lock inside is not needed. + + :param tmp_host_address: The temporary host address to use for the connection. + :param tmp_relax_timeout: The relax timeout to use for the connection. + """ existing_connections = self.pool.queue for conn in existing_connections: if conn: self._disconnect_and_update_connection_for_reconnect( - conn, tmp_host_address, tmp_relax_timeout, skip_original_data_update + conn, tmp_host_address, tmp_relax_timeout ) def update_connections_current_timeout( @@ -2367,6 +2354,15 @@ def update_connections_current_timeout( relax_timeout: Optional[float] = None, include_free_connections: bool = False, ): + """ + Update the timeout for the current socket. + This is used when a cluster node is migrated to a different address. + + When this method is called the pool will already be locked, so getting the pool lock inside is not needed. + + :param relax_timeout: The relax timeout to use for the connection. + :param include_free_connections: Whether to include available connections in the update. + """ if include_free_connections: for conn in tuple(self._connections): conn.update_current_socket_timeout(relax_timeout) @@ -2385,7 +2381,7 @@ def _update_maintenance_events_config_for_connections( def _update_maintenance_events_configs_for_connections( self, maintenance_events_pool_handler ): - """Override base class method to work with BlockingConnectionPool's structure.""" + """Update the maintenance events config for all connections in the pool.""" with self._lock: for conn in tuple(self._connections): conn.set_maintenance_event_pool_handler(maintenance_events_pool_handler) @@ -2401,12 +2397,14 @@ def reset_connections_tmp_settings(self): conn.reset_tmp_settings(reset_host_address=True, reset_relax_timeout=True) def set_in_maintenance(self, in_maintenance: bool): - """Set the maintenance mode for the connection pool.""" + """ + Sets a flag that this Blocking ConnectionPool is in maintenance mode. + + This is used to prevent new connections from being created while we are in maintenance mode. + The pool will be in maintenance mode only when we are processing a MOVING event. + """ self._in_maintenance = in_maintenance def set_maintenance_state_for_all_connections(self, state: "MaintenanceState"): for conn in self._connections: conn.maintenance_state = state - - def set_maintenance_state_in_connection_kwargs(self, state: "MaintenanceState"): - self.connection_kwargs["maintenance_state"] = state diff --git a/redis/maintenance_events.py b/redis/maintenance_events.py index 479a4ba090..d4b4e06231 100644 --- a/redis/maintenance_events.py +++ b/redis/maintenance_events.py @@ -382,7 +382,6 @@ def handle_node_moving_event(self, event: NodeMovingEvent): self.pool.add_tmp_config_to_connection_kwargs( tmp_host_address=event.new_node_host, tmp_relax_timeout=self.config.relax_timeout, - skip_original_data_update=prev_moving_in_progress, ) if ( self.config.is_relax_timeouts_enabled() @@ -395,21 +394,15 @@ def handle_node_moving_event(self, event: NodeMovingEvent): if self.config.proactive_reconnect: # take care for the active connections in the pool # mark them for reconnect after they complete the current command - # skip original data update if we are already in MOVING state - # as the original data is already stored in the connection self.pool.update_active_connections_for_reconnect( tmp_host_address=event.new_node_host, tmp_relax_timeout=self.config.relax_timeout, - skip_original_data_update=prev_moving_in_progress, ) # take care for the inactive connections in the pool # delete them and create new ones - # skip original data update if we are already in MOVING state - # as the original data is already stored in the connection self.pool.disconnect_and_reconfigure_free_connections( tmp_host_address=event.new_node_host, tmp_relax_timeout=self.config.relax_timeout, - skip_original_data_update=prev_moving_in_progress, ) if getattr(self.pool, "set_in_maintenance", False): self.pool.set_in_maintenance(False) diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 880b6db27e..282aec567d 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -57,7 +57,6 @@ def test_connection_creation(self): connection_kwargs = { "foo": "bar", "biz": "baz", - "maintenance_state": MaintenanceState.NONE, } pool = self.get_pool( connection_kwargs=connection_kwargs, connection_class=DummyConnection @@ -160,7 +159,6 @@ def test_connection_creation(self, master_host): } pool = self.get_pool(connection_kwargs=connection_kwargs) - connection_kwargs["maintenance_state"] = MaintenanceState.NONE connection = pool.get_connection() assert isinstance(connection, DummyConnection) assert connection.kwargs == connection_kwargs diff --git a/tests/test_maintenance_events.py b/tests/test_maintenance_events.py index 37ef869100..3eb648f079 100644 --- a/tests/test_maintenance_events.py +++ b/tests/test_maintenance_events.py @@ -33,25 +33,22 @@ def test_init_through_subclass(self): assert event.creation_time == 1000 assert event.expire_at == 1010 - def test_is_expired_false(self): + @pytest.mark.parametrize( + ("current_time", "expected_expired_state"), + [ + (1005, False), + (1015, True), + ], + ) + def test_is_expired(self, current_time, expected_expired_state): """Test is_expired returns False for non-expired event.""" with patch("time.monotonic", return_value=1000): event = NodeMovingEvent( id=1, new_node_host="localhost", new_node_port=6379, ttl=10 ) - with patch("time.monotonic", return_value=1005): # 5 seconds later - assert not event.is_expired() - - def test_is_expired_true(self): - """Test is_expired returns True for expired event.""" - with patch("time.monotonic", return_value=1000): - event = NodeMovingEvent( - id=1, new_node_host="localhost", new_node_port=6379, ttl=10 - ) - - with patch("time.monotonic", return_value=1015): # 15 seconds later - assert event.is_expired() + with patch("time.monotonic", return_value=current_time): + assert event.is_expired() == expected_expired_state def test_is_expired_exact_boundary(self): """Test is_expired at exact expiration boundary.""" diff --git a/tests/test_maintenance_events_handling.py b/tests/test_maintenance_events_handling.py index fe0b529fdb..dc4e850a50 100644 --- a/tests/test_maintenance_events_handling.py +++ b/tests/test_maintenance_events_handling.py @@ -320,7 +320,7 @@ def _validate_in_use_connections_state( expected_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], expected_socket_timeout=None, expected_socket_connect_timeout=None, - expected_orig_host_address=None, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, expected_current_socket_timeout=None, @@ -355,7 +355,7 @@ def _validate_free_connections_state( expected_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], expected_socket_timeout=None, expected_socket_connect_timeout=None, - expected_orig_host_address=None, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, ): @@ -419,13 +419,16 @@ def _validate_conn_kwargs( pool.connection_kwargs["socket_connect_timeout"] == expected_socket_connect_timeout ) - assert pool.connection_kwargs["orig_host_address"] == expected_orig_host_address assert ( - pool.connection_kwargs["orig_socket_timeout"] + pool.connection_kwargs.get("orig_host_address", None) + == expected_orig_host_address + ) + assert ( + pool.connection_kwargs.get("orig_socket_timeout", None) == expected_orig_socket_timeout ) assert ( - pool.connection_kwargs["orig_socket_connect_timeout"] + pool.connection_kwargs.get("orig_socket_connect_timeout", None) == expected_orig_socket_connect_timeout ) @@ -446,7 +449,7 @@ def test_client_initialization(self): conn = test_redis_client.connection_pool.get_connection() assert conn._should_reconnect is False - assert conn.orig_host_address is None + assert conn.orig_host_address == "localhost" assert conn.orig_socket_timeout is None # Test that the node moving handler function is correctly set by @@ -825,13 +828,13 @@ def test_moving_related_events_handling_integration(self, pool_class): # Validate pool and connections settings were updated according to MOVING event self._validate_conn_kwargs( pool=test_redis_client.connection_pool, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], - expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), - expected_orig_socket_timeout=None, - expected_orig_socket_connect_timeout=None, expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), expected_socket_timeout=self.config.relax_timeout, expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, ) self._validate_disconnected(5) self._validate_connected(6) @@ -870,7 +873,7 @@ def test_moving_related_events_handling_integration(self, pool_class): expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), expected_socket_timeout=None, expected_socket_connect_timeout=None, - expected_orig_host_address=None, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, ) @@ -879,7 +882,7 @@ def test_moving_related_events_handling_integration(self, pool_class): expected_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], expected_socket_timeout=None, expected_socket_connect_timeout=None, - expected_orig_host_address=None, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, should_be_connected_count=1, @@ -1022,7 +1025,10 @@ def test_create_new_conn_after_moving_expires(self, pool_class): new_connection = test_redis_client.connection_pool.get_connection() # Validate that new connections are created with original address (no temporary settings) - assert new_connection.orig_host_address is None + assert ( + new_connection.orig_host_address + == MockSocket.DEFAULT_ADDRESS.split(":")[0] + ) assert new_connection.orig_socket_timeout is None # New connections should be connected to the original address assert new_connection._sock is not None @@ -1264,7 +1270,7 @@ def test_overlapping_moving_events(self, pool_class): expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), expected_socket_timeout=None, expected_socket_connect_timeout=None, - expected_orig_host_address=None, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, ) @@ -1420,7 +1426,7 @@ def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): expected_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], expected_socket_timeout=None, expected_socket_connect_timeout=None, - expected_orig_host_address=None, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, expected_current_socket_timeout=None, @@ -1434,7 +1440,7 @@ def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): expected_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], expected_socket_timeout=None, expected_socket_connect_timeout=None, - expected_orig_host_address=None, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, ) From 602bbe963c0def217a841db4c3a70d7f84553e14 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Sat, 26 Jul 2025 13:07:58 +0300 Subject: [PATCH 16/28] Applying moving/moved only on connections to the same proxy. --- redis/_parsers/base.py | 8 +- redis/connection.py | 128 +++- redis/maintenance_events.py | 86 +-- tests/test_connection_pool.py | 1 - tests/test_maintenance_events_handling.py | 739 +++++++++++++++------- 5 files changed, 702 insertions(+), 260 deletions(-) diff --git a/redis/_parsers/base.py b/redis/_parsers/base.py index d5e4add661..c3d4c136d2 100644 --- a/redis/_parsers/base.py +++ b/redis/_parsers/base.py @@ -202,6 +202,7 @@ def handle_push_response(self, response, **kwargs): if msg_type in _INVALIDATION_MESSAGE and self.invalidation_push_handler_func: return self.invalidation_push_handler_func(response) if msg_type in _MOVING_MESSAGE and self.node_moving_push_handler_func: + # TODO: PARSE latest format when available host, port = response[2].decode().split(":") ttl = response[1] id = 1 # Hardcoded value until the notification starts including the id @@ -209,10 +210,12 @@ def handle_push_response(self, response, **kwargs): return self.node_moving_push_handler_func(notification) if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func: if msg_type in _MIGRATING_MESSAGE: + # TODO: PARSE latest format when available ttl = response[1] id = 2 # Hardcoded value until the notification starts including the id notification = NodeMigratingEvent(id, ttl) elif msg_type in _MIGRATED_MESSAGE: + # TODO: PARSE latest format when available id = 3 # Hardcoded value until the notification starts including the id notification = NodeMigratedEvent(id) else: @@ -260,6 +263,7 @@ async def handle_push_response(self, response, **kwargs): return await self.invalidation_push_handler_func(response) if msg_type in _MOVING_MESSAGE and self.node_moving_push_handler_func: # push notification from enterprise cluster for node moving + # TODO: PARSE latest format when available host, port = response[2].split(":") ttl = response[1] id = 1 # Hardcoded value for async parser @@ -267,10 +271,12 @@ async def handle_push_response(self, response, **kwargs): return await self.node_moving_push_handler_func(notification) if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func: if msg_type in _MIGRATING_MESSAGE: + # TODO: PARSE latest format when available ttl = response[1] id = 2 # Hardcoded value for async parser notification = NodeMigratingEvent(id, ttl) elif msg_type in _MIGRATED_MESSAGE: + # TODO: PARSE latest format when available id = 3 # Hardcoded value for async parser notification = NodeMigratedEvent(id) return await self.maintenance_push_handler_func(notification) @@ -283,7 +289,7 @@ def set_invalidation_push_handler(self, invalidation_push_handler_func): """Set the invalidation push handler function""" self.invalidation_push_handler_func = invalidation_push_handler_func - def set_node_moving_push_handler_func(self, node_moving_push_handler_func): + def set_node_moving_push_handler(self, node_moving_push_handler_func): self.node_moving_push_handler_func = node_moving_push_handler_func def set_maintenance_push_handler(self, maintenance_push_handler_func): diff --git a/redis/connection.py b/redis/connection.py index 3f1b54e26c..0d8a3983e8 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -8,7 +8,7 @@ from abc import abstractmethod from itertools import chain from queue import Empty, Full, LifoQueue -from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Type, TypeVar, Union from urllib.parse import parse_qs, unquote, urlparse from redis.cache import ( @@ -249,6 +249,13 @@ def maintenance_state(self, state: "MaintenanceState"): """ pass + @abstractmethod + def getpeername(self): + """ + Returns the peer name of the connection. + """ + pass + @abstractmethod def mark_for_reconnect(self): """ @@ -402,6 +409,7 @@ def __init__( if maintenance_events_config and maintenance_events_config.enabled: if maintenance_events_pool_handler: + maintenance_events_pool_handler.set_connection(self) self._parser.set_node_moving_push_handler( maintenance_events_pool_handler.handle_event ) @@ -484,6 +492,7 @@ def set_parser(self, parser_class): def set_maintenance_event_pool_handler( self, maintenance_event_pool_handler: MaintenanceEventPoolHandler ): + maintenance_event_pool_handler.set_connection(self) self._parser.set_node_moving_push_handler( maintenance_event_pool_handler.handle_event ) @@ -867,6 +876,11 @@ def maintenance_state(self) -> MaintenanceState: def maintenance_state(self, state: "MaintenanceState"): self._maintenance_state = state + def getpeername(self): + if not self._sock: + return None + return self._sock.getpeername()[0] + def mark_for_reconnect(self): self._should_reconnect = True @@ -1892,10 +1906,27 @@ def re_auth_callback(self, token: TokenInterface): for conn in self._in_use_connections: conn.set_re_auth_token(token) - def set_maintenance_state_for_all_connections(self, state: "MaintenanceState"): + def set_maintenance_state_for_connections( + self, + state: "MaintenanceState", + matching_address: Optional[str] = None, + address_type_to_match: Literal["connected", "configured"] = "connected", + ): for conn in self._available_connections: + if address_type_to_match == "connected": + if matching_address and conn.getpeername() != matching_address: + continue + else: + if matching_address and conn.host != matching_address: + continue conn.maintenance_state = state for conn in self._in_use_connections: + if address_type_to_match == "connected": + if matching_address and conn.getpeername() != matching_address: + continue + else: + if matching_address and conn.host != matching_address: + continue conn.maintenance_state = state def set_maintenance_state_in_connection_kwargs(self, state: "MaintenanceState"): @@ -1963,7 +1994,12 @@ def remove_tmp_config_from_connection_kwargs(self): } ) - def reset_connections_tmp_settings(self): + def reset_connections_tmp_settings( + self, + moving_address: Optional[str] = None, + reset_host_address: bool = False, + reset_relax_timeout: bool = False, + ): """ Restore original settings from temporary configuration for all connections in the pool. @@ -1978,16 +2014,25 @@ def reset_connections_tmp_settings(self): """ with self._lock: for conn in self._available_connections: + if moving_address and conn.host != moving_address: + continue conn.reset_tmp_settings( - reset_host_address=True, reset_relax_timeout=True + reset_host_address=reset_host_address, + reset_relax_timeout=reset_relax_timeout, ) for conn in self._in_use_connections: + if moving_address and conn.host != moving_address: + continue conn.reset_tmp_settings( - reset_host_address=True, reset_relax_timeout=True + reset_host_address=reset_host_address, + reset_relax_timeout=reset_relax_timeout, ) def update_active_connections_for_reconnect( - self, tmp_host_address: str, tmp_relax_timeout: Optional[float] = None + self, + tmp_host_address: str, + tmp_relax_timeout: Optional[float] = None, + moving_address_src: Optional[str] = None, ): """ Mark all active connections for reconnect. @@ -1999,6 +2044,8 @@ def update_active_connections_for_reconnect( :param tmp_relax_timeout: The relax timeout to use for the connection. """ for conn in self._in_use_connections: + if moving_address_src and conn.getpeername() != moving_address_src: + continue self._update_connection_for_reconnect( conn, tmp_host_address, tmp_relax_timeout ) @@ -2007,6 +2054,7 @@ def disconnect_and_reconfigure_free_connections( self, tmp_host_address: str, tmp_relax_timeout: Optional[float] = None, + moving_address_src: Optional[str] = None, ): """ Disconnect all free/available connections. @@ -2019,6 +2067,8 @@ def disconnect_and_reconfigure_free_connections( """ for conn in self._available_connections: + if moving_address_src and conn.getpeername() != moving_address_src: + continue self._disconnect_and_update_connection_for_reconnect( conn, tmp_host_address, tmp_relax_timeout ) @@ -2026,6 +2076,8 @@ def disconnect_and_reconfigure_free_connections( def update_connections_current_timeout( self, relax_timeout: Optional[float], + matching_address: Optional[str] = None, + address_type_to_match: Literal["connected", "configured"] = "connected", include_free_connections: bool = False, ): """ @@ -2039,10 +2091,22 @@ def update_connections_current_timeout( :param include_available_connections: Whether to include available connections in the update. """ for conn in self._in_use_connections: + if address_type_to_match == "connected": + if matching_address and conn.getpeername() != matching_address: + continue + else: + if matching_address and conn.host != matching_address: + continue conn.update_current_socket_timeout(relax_timeout) if include_free_connections: for conn in self._available_connections: + if address_type_to_match == "connected": + if matching_address and conn.getpeername() != matching_address: + continue + else: + if matching_address and conn.host != matching_address: + continue conn.update_current_socket_timeout(relax_timeout) def _update_connection_for_reconnect( @@ -2308,7 +2372,10 @@ def disconnect(self): self._locked = False def update_active_connections_for_reconnect( - self, tmp_host_address: str, tmp_relax_timeout: Optional[float] = None + self, + tmp_host_address: str, + tmp_relax_timeout: Optional[float] = None, + moving_address_src: Optional[str] = None, ): """ Mark all active connections for reconnect. @@ -2323,6 +2390,8 @@ def update_active_connections_for_reconnect( connections_in_queue = {conn for conn in self.pool.queue if conn} for conn in self._connections: if conn not in connections_in_queue: + if moving_address_src and conn.getpeername() != moving_address_src: + continue self._update_connection_for_reconnect( conn, tmp_host_address, tmp_relax_timeout ) @@ -2331,6 +2400,7 @@ def disconnect_and_reconfigure_free_connections( self, tmp_host_address: str, tmp_relax_timeout: Optional[Number] = None, + moving_address_src: Optional[str] = None, ): """ Disconnect all free/available connections. @@ -2345,6 +2415,8 @@ def disconnect_and_reconfigure_free_connections( for conn in existing_connections: if conn: + if moving_address_src and conn.getpeername() != moving_address_src: + continue self._disconnect_and_update_connection_for_reconnect( conn, tmp_host_address, tmp_relax_timeout ) @@ -2352,6 +2424,8 @@ def disconnect_and_reconfigure_free_connections( def update_connections_current_timeout( self, relax_timeout: Optional[float] = None, + matching_address: Optional[str] = None, + address_type_to_match: Literal["connected", "configured"] = "connected", include_free_connections: bool = False, ): """ @@ -2365,11 +2439,23 @@ def update_connections_current_timeout( """ if include_free_connections: for conn in tuple(self._connections): + if address_type_to_match == "connected": + if matching_address and conn.getpeername() != matching_address: + continue + else: + if matching_address and conn.host != matching_address: + continue conn.update_current_socket_timeout(relax_timeout) else: connections_in_queue = {conn for conn in self.pool.queue if conn} for conn in self._connections: if conn not in connections_in_queue: + if address_type_to_match == "connected": + if matching_address and conn.getpeername() != matching_address: + continue + else: + if matching_address and conn.host != matching_address: + continue conn.update_current_socket_timeout(relax_timeout) def _update_maintenance_events_config_for_connections( @@ -2387,14 +2473,24 @@ def _update_maintenance_events_configs_for_connections( conn.set_maintenance_event_pool_handler(maintenance_events_pool_handler) conn.maintenance_events_config = maintenance_events_pool_handler.config - def reset_connections_tmp_settings(self): + def reset_connections_tmp_settings( + self, + moving_address: Optional[str] = None, + reset_host_address: bool = False, + reset_relax_timeout: bool = False, + ): """ Override base class method to work with BlockingConnectionPool's structure. Restore original settings from temporary configuration for all connections in the pool. """ for conn in tuple(self._connections): - conn.reset_tmp_settings(reset_host_address=True, reset_relax_timeout=True) + if moving_address and conn.host != moving_address: + continue + conn.reset_tmp_settings( + reset_host_address=reset_host_address, + reset_relax_timeout=reset_relax_timeout, + ) def set_in_maintenance(self, in_maintenance: bool): """ @@ -2405,6 +2501,18 @@ def set_in_maintenance(self, in_maintenance: bool): """ self._in_maintenance = in_maintenance - def set_maintenance_state_for_all_connections(self, state: "MaintenanceState"): + def set_maintenance_state_for_connections( + self, + state: "MaintenanceState", + matching_address: Optional[str] = None, + address_type_to_match: Literal["connected", "configured"] = "connected", + ): for conn in self._connections: + if address_type_to_match == "connected": + if matching_address and conn.getpeername() != matching_address: + continue + else: + if matching_address and conn.host != matching_address: + continue + conn.maintenance_state = state diff --git a/redis/maintenance_events.py b/redis/maintenance_events.py index d4b4e06231..3b83da9e02 100644 --- a/redis/maintenance_events.py +++ b/redis/maintenance_events.py @@ -324,6 +324,10 @@ def __init__( self.config = config self._processed_events = set() self._lock = threading.RLock() + self.connection = None + + def set_connection(self, connection: "ConnectionInterface"): + self.connection = connection def remove_expired_notifications(self): with self._lock: @@ -357,25 +361,20 @@ def handle_node_moving_event(self, event: NodeMovingEvent): self.config.proactive_reconnect or self.config.is_relax_timeouts_enabled() ): + moving_address_src = ( + self.connection.getpeername() if self.connection else None + ) + if getattr(self.pool, "set_in_maintenance", False): self.pool.set_in_maintenance(True) - prev_moving_in_progress = False - if ( - self.pool.connection_kwargs.get("maintenance_state") - == MaintenanceState.MOVING - ): - # The pool is already in MOVING state, update just the new host information - prev_moving_in_progress = True - - if not prev_moving_in_progress: - # Set state to MOVING for all connections and in kwargs (inside pool lock, after set_in_maintenance) - self.pool.set_maintenance_state_for_all_connections( - MaintenanceState.MOVING - ) - self.pool.set_maintenance_state_in_connection_kwargs( - MaintenanceState.MOVING - ) + # Set state to MOVING for all connections and in kwargs (inside pool lock, after set_in_maintenance) + self.pool.set_maintenance_state_for_connections( + MaintenanceState.MOVING, moving_address_src + ) + self.pool.set_maintenance_state_in_connection_kwargs( + MaintenanceState.MOVING + ) # edit the config for new connections until the notification expires # skip original data update if we are already in MOVING state # as the original data is already stored in the connection kwargs @@ -383,13 +382,12 @@ def handle_node_moving_event(self, event: NodeMovingEvent): tmp_host_address=event.new_node_host, tmp_relax_timeout=self.config.relax_timeout, ) - if ( - self.config.is_relax_timeouts_enabled() - and not prev_moving_in_progress - ): + if self.config.is_relax_timeouts_enabled(): # extend the timeout for all connections that are currently in use self.pool.update_connections_current_timeout( - self.config.relax_timeout + relax_timeout=self.config.relax_timeout, + matching_address=moving_address_src, + address_type_to_match="connected", ) if self.config.proactive_reconnect: # take care for the active connections in the pool @@ -397,16 +395,18 @@ def handle_node_moving_event(self, event: NodeMovingEvent): self.pool.update_active_connections_for_reconnect( tmp_host_address=event.new_node_host, tmp_relax_timeout=self.config.relax_timeout, + moving_address_src=moving_address_src, ) # take care for the inactive connections in the pool # delete them and create new ones self.pool.disconnect_and_reconfigure_free_connections( tmp_host_address=event.new_node_host, tmp_relax_timeout=self.config.relax_timeout, + moving_address_src=moving_address_src, ) if getattr(self.pool, "set_in_maintenance", False): self.pool.set_in_maintenance(False) - + print(f"Starting timer for {event} for {event.ttl} seconds") threading.Timer( event.ttl, self.handle_node_moved_event, args=(event,) ).start() @@ -415,25 +415,39 @@ def handle_node_moving_event(self, event: NodeMovingEvent): def handle_node_moved_event(self, event: NodeMovingEvent): with self._lock: - if self.pool.connection_kwargs.get("host") != event.new_node_host: - # if the current host is not matching the event - # it means there has been a new moving event after this one - # so we don't need to handle this one anymore - # the settings will be reverted by the moved handler of the next event - return - self.pool.remove_tmp_config_from_connection_kwargs() - # Clear state to NONE in kwargs immediately after updating tmp kwargs - self.pool.set_maintenance_state_in_connection_kwargs(MaintenanceState.NONE) + # if the current host in kwargs is not matching the event + # it means there has been a new moving event after this one + # and we don't need to revert the kwargs + if self.pool.connection_kwargs.get("host") == event.new_node_host: + self.pool.remove_tmp_config_from_connection_kwargs() + # Clear state to NONE in kwargs immediately after updating tmp kwargs + self.pool.set_maintenance_state_in_connection_kwargs( + MaintenanceState.NONE + ) with self.pool._lock: - self.pool.reset_connections_tmp_settings() + moving_address = event.new_node_host if self.config.is_relax_timeouts_enabled(): + self.pool.reset_connections_tmp_settings( + moving_address, reset_relax_timeout=True + ) # reset the timeout for existing connections self.pool.update_connections_current_timeout( - relax_timeout=-1, include_free_connections=True + relax_timeout=-1, + matching_address=moving_address, + address_type_to_match="configured", + include_free_connections=True, ) - # Clear state to NONE for all connections - self.pool.set_maintenance_state_for_all_connections( - MaintenanceState.NONE + + # Clear maintenance state to NONE for all matching connections + self.pool.set_maintenance_state_for_connections( + state=MaintenanceState.NONE, + matching_address=moving_address, + address_type_to_match="configured", + ) + # reset the host address after all other operations that + # compare against tmp host are completed + self.pool.reset_connections_tmp_settings( + moving_address, reset_host_address=True ) diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 282aec567d..1eb68d3775 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -9,7 +9,6 @@ import redis from redis.cache import CacheConfig from redis.connection import CacheProxyConnection, Connection, to_bool -from redis.maintenance_events import MaintenanceState from redis.utils import SSL_AVAILABLE from .conftest import ( diff --git a/tests/test_maintenance_events_handling.py b/tests/test_maintenance_events_handling.py index dc4e850a50..8ea5488aa8 100644 --- a/tests/test_maintenance_events_handling.py +++ b/tests/test_maintenance_events_handling.py @@ -1,7 +1,8 @@ import socket import threading -from typing import List +from typing import List, Union from unittest.mock import patch + import pytest from time import sleep @@ -21,13 +22,132 @@ ) +AFTER_MOVING_ADDRESS = "1.2.3.4:6379" +DEFAULT_ADDRESS = "12.45.34.56:6379" +MOVING_TIMEOUT = 1 + + +class Helpers: + """Helper class containing static methods for validation in maintenance events tests.""" + + @staticmethod + def validate_in_use_connections_state( + in_use_connections: List[AbstractConnection], + expected_state=MaintenanceState.NONE, + expected_should_reconnect: Union[bool, str] = True, + expected_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_socket_timeout=None, + expected_socket_connect_timeout=None, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_current_socket_timeout=None, + expected_current_peername=DEFAULT_ADDRESS.split(":")[0], + ): + """Helper method to validate state of in-use connections.""" + + # validate in use connections are still working with set flag for reconnect + # and timeout is updated + for connection in in_use_connections: + if expected_should_reconnect != "any": + assert connection._should_reconnect == expected_should_reconnect + assert connection.host == expected_host_address + assert connection.socket_timeout == expected_socket_timeout + assert connection.socket_connect_timeout == expected_socket_connect_timeout + assert connection.orig_host_address == expected_orig_host_address + assert connection.orig_socket_timeout == expected_orig_socket_timeout + assert ( + connection.orig_socket_connect_timeout + == expected_orig_socket_connect_timeout + ) + if connection._sock is not None: + assert connection._sock.gettimeout() == expected_current_socket_timeout + assert connection._sock.connected is True + if expected_current_peername != "any": + assert ( + connection._sock.getpeername()[0] == expected_current_peername + ) + assert connection.maintenance_state == expected_state + + @staticmethod + def validate_free_connections_state( + pool, + should_be_connected_count=0, + connected_to_tmp_addres=False, + tmp_address=AFTER_MOVING_ADDRESS.split(":")[0], + expected_state=MaintenanceState.MOVING, + expected_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_socket_timeout=None, + expected_socket_connect_timeout=None, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + ): + """Helper method to validate state of free/available connections.""" + + if isinstance(pool, BlockingConnectionPool): + free_connections = [conn for conn in pool.pool.queue if conn is not None] + elif isinstance(pool, ConnectionPool): + free_connections = pool._available_connections + else: + raise ValueError(f"Unsupported pool type: {type(pool)}") + + connected_count = 0 + for connection in free_connections: + assert connection._should_reconnect is False + assert connection.host == expected_host_address + assert connection.socket_timeout == expected_socket_timeout + assert connection.socket_connect_timeout == expected_socket_connect_timeout + assert connection.orig_host_address == expected_orig_host_address + assert connection.orig_socket_timeout == expected_orig_socket_timeout + assert ( + connection.orig_socket_connect_timeout + == expected_orig_socket_connect_timeout + ) + assert connection.maintenance_state == expected_state + if connection._sock is not None: + assert connection._sock.connected is True + if connected_to_tmp_addres and tmp_address != "any": + assert connection._sock.getpeername()[0] == tmp_address + connected_count += 1 + assert connected_count == should_be_connected_count + + @staticmethod + def validate_conn_kwargs( + pool, + expected_host_address, + expected_port, + expected_socket_timeout, + expected_socket_connect_timeout, + expected_orig_host_address, + expected_orig_socket_timeout, + expected_orig_socket_connect_timeout, + ): + """Helper method to validate connection kwargs.""" + assert pool.connection_kwargs["host"] == expected_host_address + assert pool.connection_kwargs["port"] == expected_port + assert pool.connection_kwargs["socket_timeout"] == expected_socket_timeout + assert ( + pool.connection_kwargs["socket_connect_timeout"] + == expected_socket_connect_timeout + ) + assert ( + pool.connection_kwargs.get("orig_host_address", None) + == expected_orig_host_address + ) + assert ( + pool.connection_kwargs.get("orig_socket_timeout", None) + == expected_orig_socket_timeout + ) + assert ( + pool.connection_kwargs.get("orig_socket_connect_timeout", None) + == expected_orig_socket_connect_timeout + ) + + class MockSocket: """Mock socket that simulates Redis protocol responses.""" - AFTER_MOVING_ADDRESS = "1.2.3.4:6379" - DEFAULT_ADDRESS = "12.45.34.56:6379" - MOVING_TIMEOUT = 1 - def __init__(self): self.connected = False self.address = None @@ -73,7 +193,7 @@ def send(self, data): # MOVING push message before SET key_receive_moving_X response # Format: >3\r\n$6\r\nMOVING\r\n:15\r\n+localhost:6379\r\n (3 elements: MOVING, ttl, host:port) # Note: Using + instead of $ to send as simple string instead of bulk string - moving_push = f">3\r\n$6\r\nMOVING\r\n:{MockSocket.MOVING_TIMEOUT}\r\n+{MockSocket.AFTER_MOVING_ADDRESS}\r\n" + moving_push = f">3\r\n$6\r\nMOVING\r\n:{MOVING_TIMEOUT}\r\n+{AFTER_MOVING_ADDRESS}\r\n" response = moving_push.encode() + response self.pending_responses.append(response) @@ -164,7 +284,7 @@ def shutdown(self, how): pass -class TestMaintenanceEventsHandling: +class TestMaintenanceEventsHandlingSingleProxy: """Integration tests for maintenance events handling with real connection pool.""" def setup_method(self): @@ -233,8 +353,8 @@ def _get_client( ) test_pool = pool_class( - host=MockSocket.DEFAULT_ADDRESS.split(":")[0], - port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + host=DEFAULT_ADDRESS.split(":")[0], + port=int(DEFAULT_ADDRESS.split(":")[1]), max_connections=max_connections, protocol=3, # Required for maintenance events maintenance_events_config=config, @@ -313,124 +433,12 @@ def _validate_connected(self, expected_count): connected_sockets_count += 1 assert connected_sockets_count == expected_count - def _validate_in_use_connections_state( - self, - in_use_connections: List[AbstractConnection], - expected_state=MaintenanceState.NONE, - expected_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], - expected_socket_timeout=None, - expected_socket_connect_timeout=None, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], - expected_orig_socket_timeout=None, - expected_orig_socket_connect_timeout=None, - expected_current_socket_timeout=None, - expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], - ): - """Helper method to validate state of in-use connections.""" - # validate in use connections are still working with set flag for reconnect - # and timeout is updated - for connection in in_use_connections: - assert connection._should_reconnect is True - assert connection.host == expected_host_address - assert connection.socket_timeout == expected_socket_timeout - assert connection.socket_connect_timeout == expected_socket_connect_timeout - assert connection.orig_host_address == expected_orig_host_address - assert connection.orig_socket_timeout == expected_orig_socket_timeout - assert ( - connection.orig_socket_connect_timeout - == expected_orig_socket_connect_timeout - ) - if connection._sock is not None: - assert connection._sock.gettimeout() == expected_current_socket_timeout - assert connection._sock.connected is True - assert connection._sock.getpeername()[0] == expected_current_peername - assert connection.maintenance_state == expected_state - - def _validate_free_connections_state( - self, - pool, - should_be_connected_count, - connected_to_tmp_addres=False, - expected_state=MaintenanceState.MOVING, - expected_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], - expected_socket_timeout=None, - expected_socket_connect_timeout=None, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], - expected_orig_socket_timeout=None, - expected_orig_socket_connect_timeout=None, - ): - """Helper method to validate state of free/available connections.""" - if isinstance(pool, BlockingConnectionPool): - free_connections = [conn for conn in pool.pool.queue if conn is not None] - elif isinstance(pool, ConnectionPool): - free_connections = pool._available_connections - else: - raise ValueError(f"Unsupported pool type: {type(pool)}") - - connected_count = 0 - for connection in free_connections: - assert connection._should_reconnect is False - assert connection.host == expected_host_address - assert connection.socket_timeout == expected_socket_timeout - assert connection.socket_connect_timeout == expected_socket_connect_timeout - assert connection.orig_host_address == expected_orig_host_address - assert connection.orig_socket_timeout == expected_orig_socket_timeout - assert ( - connection.orig_socket_connect_timeout - == expected_orig_socket_connect_timeout - ) - assert connection.maintenance_state == expected_state - if connection._sock is not None: - assert connection._sock.connected is True - if connected_to_tmp_addres: - assert ( - connection._sock.getpeername()[0] - == MockSocket.AFTER_MOVING_ADDRESS.split(":")[0] - ) - connected_count += 1 - assert connected_count == should_be_connected_count - def _validate_all_timeouts(self, expected_timeout): """Helper method to validate state of in-use connections.""" # validate in use connections are still working with set flag for reconnect # and timeout is updated for mock_socket in self.mock_sockets: - if expected_timeout is None: - assert mock_socket.gettimeout() is None - else: - assert mock_socket.gettimeout() == expected_timeout - - def _validate_conn_kwargs( - self, - pool, - expected_host_address, - expected_port, - expected_socket_timeout, - expected_socket_connect_timeout, - expected_orig_host_address, - expected_orig_socket_timeout, - expected_orig_socket_connect_timeout, - ): - """Helper method to validate connection kwargs.""" - assert pool.connection_kwargs["host"] == expected_host_address - assert pool.connection_kwargs["port"] == expected_port - assert pool.connection_kwargs["socket_timeout"] == expected_socket_timeout - assert ( - pool.connection_kwargs["socket_connect_timeout"] - == expected_socket_connect_timeout - ) - assert ( - pool.connection_kwargs.get("orig_host_address", None) - == expected_orig_host_address - ) - assert ( - pool.connection_kwargs.get("orig_socket_timeout", None) - == expected_orig_socket_timeout - ) - assert ( - pool.connection_kwargs.get("orig_socket_connect_timeout", None) - == expected_orig_socket_connect_timeout - ) + assert mock_socket.gettimeout() == expected_timeout def test_client_initialization(self): """Test that Redis client is created with maintenance events configuration.""" @@ -826,63 +834,75 @@ def test_moving_related_events_handling_integration(self, pool_class): assert result2 is True, "Command 2 (SET key_receive_moving) failed" # Validate pool and connections settings were updated according to MOVING event - self._validate_conn_kwargs( + Helpers.validate_conn_kwargs( pool=test_redis_client.connection_pool, - expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], - expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], + expected_port=int(DEFAULT_ADDRESS.split(":")[1]), expected_socket_timeout=self.config.relax_timeout, expected_socket_connect_timeout=self.config.relax_timeout, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, ) self._validate_disconnected(5) self._validate_connected(6) - self._validate_in_use_connections_state( + Helpers.validate_in_use_connections_state( in_use_connections, expected_state=MaintenanceState.MOVING, - expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], expected_socket_timeout=self.config.relax_timeout, expected_socket_connect_timeout=self.config.relax_timeout, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, expected_current_socket_timeout=self.config.relax_timeout, - expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[ + expected_current_peername=DEFAULT_ADDRESS.split(":")[ 0 ], # the in use connections reconnect when they complete their current task ) - self._validate_free_connections_state( + Helpers.validate_free_connections_state( pool=test_redis_client.connection_pool, expected_state=MaintenanceState.MOVING, - expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], expected_socket_timeout=self.config.relax_timeout, expected_socket_connect_timeout=self.config.relax_timeout, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, should_be_connected_count=1, connected_to_tmp_addres=True, ) # Wait for MOVING timeout to expire and the moving completed handler to run - sleep(MockSocket.MOVING_TIMEOUT + 0.5) - self._validate_all_timeouts(None) - self._validate_conn_kwargs( + sleep(MOVING_TIMEOUT + 0.5) + + Helpers.validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.NONE, + expected_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_socket_timeout=None, + expected_socket_connect_timeout=None, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_current_socket_timeout=None, + expected_current_peername=DEFAULT_ADDRESS.split(":")[0], + ) + Helpers.validate_conn_kwargs( pool=test_redis_client.connection_pool, - expected_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], - expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + expected_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_port=int(DEFAULT_ADDRESS.split(":")[1]), expected_socket_timeout=None, expected_socket_connect_timeout=None, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, ) - self._validate_free_connections_state( + Helpers.validate_free_connections_state( pool=test_redis_client.connection_pool, - expected_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_host_address=DEFAULT_ADDRESS.split(":")[0], expected_socket_timeout=None, expected_socket_connect_timeout=None, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, should_be_connected_count=1, @@ -936,13 +956,13 @@ def test_create_new_conn_while_moving_not_expired(self, pool_class): assert result is True, "SET key_receive_moving command failed" # Validate pool and connections settings were updated according to MOVING event - self._validate_conn_kwargs( + Helpers.validate_conn_kwargs( pool=test_redis_client.connection_pool, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], - expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_port=int(DEFAULT_ADDRESS.split(":")[1]), expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, - expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], expected_socket_timeout=self.config.relax_timeout, expected_socket_connect_timeout=self.config.relax_timeout, ) @@ -959,14 +979,14 @@ def test_create_new_conn_while_moving_not_expired(self, pool_class): # Validate that new connections are created with temporary address and relax timeout # and when connecting those configs are used # get_connection() returns a connection that is already connected - assert new_connection.host == MockSocket.AFTER_MOVING_ADDRESS.split(":")[0] + assert new_connection.host == AFTER_MOVING_ADDRESS.split(":")[0] assert new_connection.socket_timeout is self.config.relax_timeout # New connections should be connected to the temporary address assert new_connection._sock is not None assert new_connection._sock.connected is True assert ( new_connection._sock.getpeername()[0] - == MockSocket.AFTER_MOVING_ADDRESS.split(":")[0] + == AFTER_MOVING_ADDRESS.split(":")[0] ) assert new_connection._sock.gettimeout() == self.config.relax_timeout @@ -1014,7 +1034,7 @@ def test_create_new_conn_after_moving_expires(self, pool_class): assert result is True, "SET key_receive_moving command failed" # Wait for MOVING timeout to expire - sleep(MockSocket.MOVING_TIMEOUT + 0.5) + sleep(MOVING_TIMEOUT + 0.5) # Now get several new connections after expiration old_connections = [] @@ -1025,10 +1045,7 @@ def test_create_new_conn_after_moving_expires(self, pool_class): new_connection = test_redis_client.connection_pool.get_connection() # Validate that new connections are created with original address (no temporary settings) - assert ( - new_connection.orig_host_address - == MockSocket.DEFAULT_ADDRESS.split(":")[0] - ) + assert new_connection.orig_host_address == DEFAULT_ADDRESS.split(":")[0] assert new_connection.orig_socket_timeout is None # New connections should be connected to the original address assert new_connection._sock is not None @@ -1087,13 +1104,13 @@ def test_receive_migrated_after_moving(self, pool_class): assert result_moving is True, "SET key_receive_moving command failed" # Validate pool and connections settings were updated according to MOVING event - self._validate_conn_kwargs( + Helpers.validate_conn_kwargs( pool=test_redis_client.connection_pool, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], - expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_port=int(DEFAULT_ADDRESS.split(":")[1]), expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, - expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], expected_socket_timeout=self.config.relax_timeout, expected_socket_connect_timeout=self.config.relax_timeout, ) @@ -1113,13 +1130,13 @@ def test_receive_migrated_after_moving(self, pool_class): # (MIGRATED doesn't automatically clear MOVING settings - they are separate events) # MOVING settings should still be active # MOVING timeout should still be active - self._validate_conn_kwargs( + Helpers.validate_conn_kwargs( pool=test_redis_client.connection_pool, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], - expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_port=int(DEFAULT_ADDRESS.split(":")[1]), expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, - expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], expected_socket_timeout=self.config.relax_timeout, expected_socket_connect_timeout=self.config.relax_timeout, ) @@ -1133,7 +1150,7 @@ def test_receive_migrated_after_moving(self, pool_class): # Validate that new connections are created with MOVING settings (still active) for connection in new_connections: - assert connection.host == MockSocket.AFTER_MOVING_ADDRESS.split(":")[0] + assert connection.host == AFTER_MOVING_ADDRESS.split(":")[0] # Note: New connections may not inherit the exact relax timeout value # but they should have the temporary host address # New connections should be connected @@ -1158,13 +1175,19 @@ def test_overlapping_moving_events(self, pool_class): Test handling of overlapping/duplicate MOVING events (e.g., two MOVING events before the first expires). Ensures that the second MOVING event updates the pool and connections as expected, and that expiry/cleanup works. """ + global AFTER_MOVING_ADDRESS test_redis_client = self._get_client( pool_class, max_connections=5, setup_pool_handler=True ) try: # Create and release some connections + in_use_connections = [] for _ in range(3): - conn = test_redis_client.connection_pool.get_connection() + in_use_connections.append( + test_redis_client.connection_pool.get_connection() + ) + + for conn in in_use_connections: test_redis_client.connection_pool.release(conn) # Take 2 connections to be in use @@ -1178,99 +1201,106 @@ def test_overlapping_moving_events(self, pool_class): value_moving1 = "value3_0" result1 = test_redis_client.set(key_moving1, value_moving1) assert result1 is True - self._validate_conn_kwargs( + Helpers.validate_conn_kwargs( pool=test_redis_client.connection_pool, - expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], - expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], + expected_port=int(DEFAULT_ADDRESS.split(":")[1]), expected_socket_timeout=self.config.relax_timeout, expected_socket_connect_timeout=self.config.relax_timeout, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, ) # Validate all connections reflect the first MOVING event - self._validate_in_use_connections_state( + Helpers.validate_in_use_connections_state( in_use_connections, expected_state=MaintenanceState.MOVING, - expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], expected_socket_timeout=self.config.relax_timeout, expected_socket_connect_timeout=self.config.relax_timeout, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, expected_current_socket_timeout=self.config.relax_timeout, - expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_current_peername=DEFAULT_ADDRESS.split(":")[0], ) - self._validate_free_connections_state( + Helpers.validate_free_connections_state( pool=test_redis_client.connection_pool, should_be_connected_count=1, connected_to_tmp_addres=True, expected_state=MaintenanceState.MOVING, - expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], expected_socket_timeout=self.config.relax_timeout, expected_socket_connect_timeout=self.config.relax_timeout, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, ) + # Reconnect in use connections + for conn in in_use_connections: + conn.disconnect() + conn.connect() # Before the first MOVING expires, trigger a second MOVING event (simulate new address) # Validate the orig properties are not changed! second_moving_address = "5.6.7.8:6380" - orig_after_moving = MockSocket.AFTER_MOVING_ADDRESS - MockSocket.AFTER_MOVING_ADDRESS = second_moving_address + orig_after_moving = AFTER_MOVING_ADDRESS + # Temporarily modify the global constant for this test + AFTER_MOVING_ADDRESS = second_moving_address try: key_moving2 = "key_receive_moving_1" value_moving2 = "value3_1" result2 = test_redis_client.set(key_moving2, value_moving2) assert result2 is True - self._validate_conn_kwargs( + Helpers.validate_conn_kwargs( pool=test_redis_client.connection_pool, expected_host_address=second_moving_address.split(":")[0], - expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + expected_port=int(DEFAULT_ADDRESS.split(":")[1]), expected_socket_timeout=self.config.relax_timeout, expected_socket_connect_timeout=self.config.relax_timeout, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, ) # Validate all connections reflect the second MOVING event - self._validate_in_use_connections_state( + Helpers.validate_in_use_connections_state( in_use_connections, expected_state=MaintenanceState.MOVING, expected_host_address=second_moving_address.split(":")[0], expected_socket_timeout=self.config.relax_timeout, expected_socket_connect_timeout=self.config.relax_timeout, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, expected_current_socket_timeout=self.config.relax_timeout, - expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_current_peername=orig_after_moving.split(":")[0], ) - self._validate_free_connections_state( + # print(test_redis_client.connection_pool._available_connections) + Helpers.validate_free_connections_state( test_redis_client.connection_pool, should_be_connected_count=1, connected_to_tmp_addres=True, + tmp_address=second_moving_address.split(":")[0], expected_state=MaintenanceState.MOVING, expected_host_address=second_moving_address.split(":")[0], expected_socket_timeout=self.config.relax_timeout, expected_socket_connect_timeout=self.config.relax_timeout, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, ) finally: - MockSocket.AFTER_MOVING_ADDRESS = orig_after_moving + AFTER_MOVING_ADDRESS = orig_after_moving # Wait for both MOVING timeouts to expire - sleep(MockSocket.MOVING_TIMEOUT + 0.5) - self._validate_conn_kwargs( + sleep(MOVING_TIMEOUT + 0.5) + Helpers.validate_conn_kwargs( pool=test_redis_client.connection_pool, - expected_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], - expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + expected_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_port=int(DEFAULT_ADDRESS.split(":")[1]), expected_socket_timeout=None, expected_socket_connect_timeout=None, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, ) @@ -1309,13 +1339,13 @@ def worker(idx): assert all(results), f"Not all threads succeeded: {results}" assert not errors, f"Errors occurred in threads: {errors}" # After all threads, MOVING event should have been handled safely - self._validate_conn_kwargs( + Helpers.validate_conn_kwargs( pool=test_redis_client.connection_pool, - expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], - expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], + expected_port=int(DEFAULT_ADDRESS.split(":")[1]), expected_socket_timeout=self.config.relax_timeout, expected_socket_connect_timeout=self.config.relax_timeout, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, ) @@ -1356,19 +1386,19 @@ def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): id=1, new_node_host=tmp_address, new_node_port=6379, ttl=1 ) pool_handler.handle_event(moving_event) - self._validate_in_use_connections_state( + Helpers.validate_in_use_connections_state( in_use_connections, expected_state=MaintenanceState.MOVING, expected_host_address=tmp_address, expected_socket_timeout=self.config.relax_timeout, expected_socket_connect_timeout=self.config.relax_timeout, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, expected_current_socket_timeout=self.config.relax_timeout, - expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_current_peername=DEFAULT_ADDRESS.split(":")[0], ) - self._validate_free_connections_state( + Helpers.validate_free_connections_state( pool=pool, should_be_connected_count=0, connected_to_tmp_addres=False, @@ -1376,7 +1406,7 @@ def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): expected_host_address=tmp_address, expected_socket_timeout=self.config.relax_timeout, expected_socket_connect_timeout=self.config.relax_timeout, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, ) @@ -1386,17 +1416,17 @@ def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): conn._maintenance_event_connection_handler.handle_event( NodeMigratingEvent(id=2, ttl=1) ) - self._validate_in_use_connections_state( + Helpers.validate_in_use_connections_state( in_use_connections, expected_state=MaintenanceState.MOVING, expected_host_address=tmp_address, expected_socket_timeout=self.config.relax_timeout, expected_socket_connect_timeout=self.config.relax_timeout, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, expected_current_socket_timeout=self.config.relax_timeout, - expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_current_peername=DEFAULT_ADDRESS.split(":")[0], ) # 3. MIGRATED event (simulate direct connection handler call) @@ -1405,42 +1435,42 @@ def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): NodeMigratedEvent(id=2) ) # State should not change for connections that are in MOVING state - self._validate_in_use_connections_state( + Helpers.validate_in_use_connections_state( in_use_connections, expected_state=MaintenanceState.MOVING, expected_host_address=tmp_address, expected_socket_timeout=self.config.relax_timeout, expected_socket_connect_timeout=self.config.relax_timeout, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, expected_current_socket_timeout=self.config.relax_timeout, - expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_current_peername=DEFAULT_ADDRESS.split(":")[0], ) # 4. MOVED event (simulate timer expiry) pool_handler.handle_node_moved_event(moving_event) - self._validate_in_use_connections_state( + Helpers.validate_in_use_connections_state( in_use_connections, expected_state=MaintenanceState.NONE, - expected_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_host_address=DEFAULT_ADDRESS.split(":")[0], expected_socket_timeout=None, expected_socket_connect_timeout=None, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, expected_current_socket_timeout=None, - expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_current_peername=DEFAULT_ADDRESS.split(":")[0], ) - self._validate_free_connections_state( + Helpers.validate_free_connections_state( pool=pool, should_be_connected_count=0, connected_to_tmp_addres=False, expected_state=MaintenanceState.NONE, - expected_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_host_address=DEFAULT_ADDRESS.split(":")[0], expected_socket_timeout=None, expected_socket_connect_timeout=None, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, ) @@ -1453,3 +1483,288 @@ def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): pool.release(conn) if hasattr(pool, "disconnect"): pool.disconnect() + + +class TestMaintenanceEventsHandlingMultipleProxies: + """Integration tests for maintenance events handling with real connection pool.""" + + def setup_method(self): + """Set up test fixtures with mocked sockets.""" + self.mock_sockets = [] + self.original_socket = socket.socket + self.orig_host = "test.address.com" + + # Mock socket creation to return our mock sockets + def mock_socket_factory(*args, **kwargs): + mock_sock = MockSocket() + self.mock_sockets.append(mock_sock) + return mock_sock + + self.socket_patcher = patch("socket.socket", side_effect=mock_socket_factory) + self.socket_patcher.start() + + # Mock select.select to simulate data availability for reading + def mock_select(rlist, wlist, xlist, timeout=0): + # Check if any of the sockets in rlist have data available + ready_sockets = [] + for sock in rlist: + if hasattr(sock, "connected") and sock.connected and not sock.closed: + # Only return socket as ready if it actually has data to read + if hasattr(sock, "pending_responses") and sock.pending_responses: + ready_sockets.append(sock) + # Don't return socket as ready just because it received commands + # Only when there are actual responses available + return (ready_sockets, [], []) + + self.select_patcher = patch("select.select", side_effect=mock_select) + self.select_patcher.start() + + ips = ["1.2.3.4", "5.6.7.8", "9.10.11.12"] + ips = ips * 3 + + # Mock socket creation to return our mock sockets + def mock_socket_getaddrinfo(host, port, family=0, type=0, proto=0, flags=0): + if host == self.orig_host: + ip_address = ips.pop(0) + else: + ip_address = host + + # Return the standard getaddrinfo format + # (family, type, proto, canonname, sockaddr) + return [ + ( + socket.AF_INET, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + "", + (ip_address, port), + ) + ] + + self.getaddrinfo_patcher = patch( + "socket.getaddrinfo", side_effect=mock_socket_getaddrinfo + ) + self.getaddrinfo_patcher.start() + + # Create maintenance events config + self.config = MaintenanceEventsConfig( + enabled=True, proactive_reconnect=True, relax_timeout=30 + ) + + def teardown_method(self): + """Clean up test fixtures.""" + self.socket_patcher.stop() + self.select_patcher.stop() + self.getaddrinfo_patcher.stop() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_migrating_after_moving_multiple_proxies(self, pool_class): + """ """ + # Setup + + pool = pool_class( + host=self.orig_host, + port=12345, + max_connections=10, + protocol=3, # Required for maintenance events + maintenance_events_config=self.config, + ) + pool.set_maintenance_events_pool_handler( + MaintenanceEventPoolHandler(pool, self.config) + ) + pool_handler = pool.connection_kwargs["maintenance_events_pool_handler"] + + # Create and release some connections + key1 = "1.2.3.4" + key2 = "5.6.7.8" + key3 = "9.10.11.12" + in_use_connections = {key1: [], key2: [], key3: []} + # Create 7 connections + for _ in range(7): + conn = pool.get_connection() + in_use_connections[conn.getpeername()].append(conn) + + for _, conns in in_use_connections.items(): + while len(conns) > 1: + pool.release(conns.pop()) + + # Send MOVING event to con with ip = key1 + conn = in_use_connections[key1][0] + pool_handler.set_connection(conn) + new_ip = "13.14.15.16" + pool_handler.handle_event( + NodeMovingEvent(id=1, new_node_host=new_ip, new_node_port=6379, ttl=1) + ) + + # validate in use connection and ip1 + Helpers.validate_in_use_connections_state( + in_use_connections[key1], + expected_state=MaintenanceState.MOVING, + expected_host_address=new_ip, + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=self.orig_host, + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_current_socket_timeout=self.config.relax_timeout, + expected_current_peername=key1, + ) + # validate free connections for ip1 + changed_free_connections = 0 + if isinstance(pool, BlockingConnectionPool): + free_connections = [conn for conn in pool.pool.queue if conn is not None] + elif isinstance(pool, ConnectionPool): + free_connections = pool._available_connections + for conn in free_connections: + if conn.host == new_ip: + changed_free_connections += 1 + assert conn.maintenance_state == MaintenanceState.MOVING + assert conn.host == new_ip + assert conn.socket_timeout == self.config.relax_timeout + assert conn.socket_connect_timeout == self.config.relax_timeout + assert conn.orig_host_address == self.orig_host + assert conn.orig_socket_timeout is None + assert conn.orig_socket_connect_timeout is None + else: + assert conn.maintenance_state == MaintenanceState.NONE + assert conn.host == self.orig_host + assert conn.socket_timeout is None + assert conn.socket_connect_timeout is None + assert conn.orig_host_address == self.orig_host + assert conn.orig_socket_timeout is None + assert conn.orig_socket_connect_timeout is None + assert changed_free_connections == 2 + assert len(free_connections) == 4 + + # Send second MOVING event to con with ip = key2 + conn = in_use_connections[key2][0] + pool_handler.set_connection(conn) + new_ip_2 = "17.18.19.20" + pool_handler.handle_event( + NodeMovingEvent(id=2, new_node_host=new_ip_2, new_node_port=6379, ttl=2) + ) + + # validate in use connection and ip2 + Helpers.validate_in_use_connections_state( + in_use_connections[key2], + expected_state=MaintenanceState.MOVING, + expected_host_address=new_ip_2, + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=self.orig_host, + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_current_socket_timeout=self.config.relax_timeout, + expected_current_peername=key2, + ) + # validate free connections for ip2 + changed_free_connections = 0 + if isinstance(pool, BlockingConnectionPool): + free_connections = [conn for conn in pool.pool.queue if conn is not None] + elif isinstance(pool, ConnectionPool): + free_connections = pool._available_connections + for conn in free_connections: + if conn.host == new_ip_2: + changed_free_connections += 1 + assert conn.maintenance_state == MaintenanceState.MOVING + assert conn.host == new_ip_2 + assert conn.socket_timeout == self.config.relax_timeout + assert conn.socket_connect_timeout == self.config.relax_timeout + assert conn.orig_host_address == self.orig_host + assert conn.orig_socket_timeout is None + assert conn.orig_socket_connect_timeout is None + # here I can't validate the other connections since some of + # them are in MOVING state from the first event + # and some are in NONE state + assert changed_free_connections == 1 + + # MIGRATING event on connection that has already been marked as MOVING + conn = in_use_connections[key2][0] + conn_event_handler = conn._maintenance_event_connection_handler + conn_event_handler.handle_event(NodeMigratingEvent(id=3, ttl=1)) + # validate connection does not lose its MOVING state + assert conn.maintenance_state == MaintenanceState.MOVING + # MIGRATED event + conn_event_handler.handle_event(NodeMigratedEvent(id=3)) + # validate connection does not lose its MOVING state and relax timeout + assert conn.maintenance_state == MaintenanceState.MOVING + assert conn.socket_timeout == self.config.relax_timeout + + # Send Migrating event to con with ip = key3 + conn = in_use_connections[key3][0] + conn_event_handler = conn._maintenance_event_connection_handler + conn_event_handler.handle_event(NodeMigratingEvent(id=3, ttl=1)) + # validate connection is in MIGRATING state + assert conn.maintenance_state == MaintenanceState.MIGRATING + assert conn.socket_timeout == self.config.relax_timeout + + # Send MIGRATED event to con with ip = key3 + conn_event_handler.handle_event(NodeMigratedEvent(id=3)) + # validate connection is in MOVING state + assert conn.maintenance_state == MaintenanceState.NONE + assert conn.socket_timeout is None + + # sleep to expire only the first MOVING events + sleep(1.3) + # validate only the connections affected by the first MOVING event + # have lost their MOVING state + Helpers.validate_in_use_connections_state( + in_use_connections[key1], + expected_state=MaintenanceState.NONE, + expected_host_address=self.orig_host, + expected_socket_timeout=None, + expected_socket_connect_timeout=None, + expected_orig_host_address=self.orig_host, + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_current_socket_timeout=None, + expected_current_peername=key1, + ) + Helpers.validate_in_use_connections_state( + in_use_connections[key2], + expected_state=MaintenanceState.MOVING, + expected_host_address=new_ip_2, + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=self.orig_host, + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_current_socket_timeout=self.config.relax_timeout, + expected_current_peername=key2, + ) + Helpers.validate_in_use_connections_state( + in_use_connections[key3], + expected_state=MaintenanceState.NONE, + expected_should_reconnect=False, + expected_host_address=self.orig_host, + expected_socket_timeout=None, + expected_socket_connect_timeout=None, + expected_orig_host_address=self.orig_host, + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_current_socket_timeout=None, + expected_current_peername=key3, + ) + # TODO validate free connections + + # sleep to expire the second MOVING events + sleep(1) + # validate all connections have lost their MOVING state + Helpers.validate_in_use_connections_state( + [ + *in_use_connections[key1], + *in_use_connections[key2], + *in_use_connections[key3], + ], + expected_state=MaintenanceState.NONE, + expected_should_reconnect="any", + expected_host_address=self.orig_host, + expected_socket_timeout=None, + expected_socket_connect_timeout=None, + expected_orig_host_address=self.orig_host, + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_current_socket_timeout=None, + expected_current_peername="any", + ) + # TODO validate free connections From 953b41aee62324ccffa540eac8b0aafed745d0c2 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Fri, 8 Aug 2025 17:59:52 +0300 Subject: [PATCH 17/28] Applying review comments. --- redis/asyncio/connection.py | 2 - redis/connection.py | 92 +++++++++++------------ redis/maintenance_events.py | 2 +- tests/test_maintenance_events_handling.py | 16 ++-- 4 files changed, 51 insertions(+), 61 deletions(-) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index fe86e4c36e..4efd868f6f 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -1308,8 +1308,6 @@ def __init__( ) self._condition = asyncio.Condition() self.timeout = timeout - self._in_maintenance = False - self._locked = False @deprecated_args( args_to_warn=["*"], diff --git a/redis/connection.py b/redis/connection.py index 0d8a3983e8..22a9845b23 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1906,6 +1906,20 @@ def re_auth_callback(self, token: TokenInterface): for conn in self._in_use_connections: conn.set_re_auth_token(token) + def should_update_connection( + self, + conn: "Connection", + address_type_to_match: Literal["connected", "configured"] = "connected", + matching_address: Optional[str] = None, + ) -> bool: + if address_type_to_match == "connected": + if matching_address and conn.getpeername() != matching_address: + return False + else: + if matching_address and conn.host != matching_address: + return False + return True + def set_maintenance_state_for_connections( self, state: "MaintenanceState", @@ -1913,21 +1927,15 @@ def set_maintenance_state_for_connections( address_type_to_match: Literal["connected", "configured"] = "connected", ): for conn in self._available_connections: - if address_type_to_match == "connected": - if matching_address and conn.getpeername() != matching_address: - continue - else: - if matching_address and conn.host != matching_address: - continue - conn.maintenance_state = state + if self.should_update_connection( + conn, address_type_to_match, matching_address + ): + conn.maintenance_state = state for conn in self._in_use_connections: - if address_type_to_match == "connected": - if matching_address and conn.getpeername() != matching_address: - continue - else: - if matching_address and conn.host != matching_address: - continue - conn.maintenance_state = state + if self.should_update_connection( + conn, address_type_to_match, matching_address + ): + conn.maintenance_state = state def set_maintenance_state_in_connection_kwargs(self, state: "MaintenanceState"): self.connection_kwargs["maintenance_state"] = state @@ -2091,23 +2099,17 @@ def update_connections_current_timeout( :param include_available_connections: Whether to include available connections in the update. """ for conn in self._in_use_connections: - if address_type_to_match == "connected": - if matching_address and conn.getpeername() != matching_address: - continue - else: - if matching_address and conn.host != matching_address: - continue - conn.update_current_socket_timeout(relax_timeout) + if self.should_update_connection( + conn, address_type_to_match, matching_address + ): + conn.update_current_socket_timeout(relax_timeout) if include_free_connections: for conn in self._available_connections: - if address_type_to_match == "connected": - if matching_address and conn.getpeername() != matching_address: - continue - else: - if matching_address and conn.host != matching_address: - continue - conn.update_current_socket_timeout(relax_timeout) + if self.should_update_connection( + conn, address_type_to_match, matching_address + ): + conn.update_current_socket_timeout(relax_timeout) def _update_connection_for_reconnect( self, @@ -2439,24 +2441,18 @@ def update_connections_current_timeout( """ if include_free_connections: for conn in tuple(self._connections): - if address_type_to_match == "connected": - if matching_address and conn.getpeername() != matching_address: - continue - else: - if matching_address and conn.host != matching_address: - continue - conn.update_current_socket_timeout(relax_timeout) + if self.should_update_connection( + conn, address_type_to_match, matching_address + ): + conn.update_current_socket_timeout(relax_timeout) else: connections_in_queue = {conn for conn in self.pool.queue if conn} for conn in self._connections: if conn not in connections_in_queue: - if address_type_to_match == "connected": - if matching_address and conn.getpeername() != matching_address: - continue - else: - if matching_address and conn.host != matching_address: - continue - conn.update_current_socket_timeout(relax_timeout) + if self.should_update_connection( + conn, address_type_to_match, matching_address + ): + conn.update_current_socket_timeout(relax_timeout) def _update_maintenance_events_config_for_connections( self, maintenance_events_config @@ -2508,11 +2504,7 @@ def set_maintenance_state_for_connections( address_type_to_match: Literal["connected", "configured"] = "connected", ): for conn in self._connections: - if address_type_to_match == "connected": - if matching_address and conn.getpeername() != matching_address: - continue - else: - if matching_address and conn.host != matching_address: - continue - - conn.maintenance_state = state + if self.should_update_connection( + conn, address_type_to_match, matching_address + ): + conn.maintenance_state = state diff --git a/redis/maintenance_events.py b/redis/maintenance_events.py index 3b83da9e02..25530d5674 100644 --- a/redis/maintenance_events.py +++ b/redis/maintenance_events.py @@ -406,7 +406,7 @@ def handle_node_moving_event(self, event: NodeMovingEvent): ) if getattr(self.pool, "set_in_maintenance", False): self.pool.set_in_maintenance(False) - print(f"Starting timer for {event} for {event.ttl} seconds") + threading.Timer( event.ttl, self.handle_node_moved_event, args=(event,) ).start() diff --git a/tests/test_maintenance_events_handling.py b/tests/test_maintenance_events_handling.py index 8ea5488aa8..c0c98ff330 100644 --- a/tests/test_maintenance_events_handling.py +++ b/tests/test_maintenance_events_handling.py @@ -73,7 +73,7 @@ def validate_in_use_connections_state( def validate_free_connections_state( pool, should_be_connected_count=0, - connected_to_tmp_addres=False, + connected_to_tmp_address=False, tmp_address=AFTER_MOVING_ADDRESS.split(":")[0], expected_state=MaintenanceState.MOVING, expected_host_address=DEFAULT_ADDRESS.split(":")[0], @@ -107,7 +107,7 @@ def validate_free_connections_state( assert connection.maintenance_state == expected_state if connection._sock is not None: assert connection._sock.connected is True - if connected_to_tmp_addres and tmp_address != "any": + if connected_to_tmp_address and tmp_address != "any": assert connection._sock.getpeername()[0] == tmp_address connected_count += 1 assert connected_count == should_be_connected_count @@ -870,7 +870,7 @@ def test_moving_related_events_handling_integration(self, pool_class): expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, should_be_connected_count=1, - connected_to_tmp_addres=True, + connected_to_tmp_address=True, ) # Wait for MOVING timeout to expire and the moving completed handler to run sleep(MOVING_TIMEOUT + 0.5) @@ -906,7 +906,7 @@ def test_moving_related_events_handling_integration(self, pool_class): expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, should_be_connected_count=1, - connected_to_tmp_addres=True, + connected_to_tmp_address=True, expected_state=MaintenanceState.NONE, ) finally: @@ -1227,7 +1227,7 @@ def test_overlapping_moving_events(self, pool_class): Helpers.validate_free_connections_state( pool=test_redis_client.connection_pool, should_be_connected_count=1, - connected_to_tmp_addres=True, + connected_to_tmp_address=True, expected_state=MaintenanceState.MOVING, expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], expected_socket_timeout=self.config.relax_timeout, @@ -1279,7 +1279,7 @@ def test_overlapping_moving_events(self, pool_class): Helpers.validate_free_connections_state( test_redis_client.connection_pool, should_be_connected_count=1, - connected_to_tmp_addres=True, + connected_to_tmp_address=True, tmp_address=second_moving_address.split(":")[0], expected_state=MaintenanceState.MOVING, expected_host_address=second_moving_address.split(":")[0], @@ -1401,7 +1401,7 @@ def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): Helpers.validate_free_connections_state( pool=pool, should_be_connected_count=0, - connected_to_tmp_addres=False, + connected_to_tmp_address=False, expected_state=MaintenanceState.MOVING, expected_host_address=tmp_address, expected_socket_timeout=self.config.relax_timeout, @@ -1465,7 +1465,7 @@ def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): Helpers.validate_free_connections_state( pool=pool, should_be_connected_count=0, - connected_to_tmp_addres=False, + connected_to_tmp_address=False, expected_state=MaintenanceState.NONE, expected_host_address=DEFAULT_ADDRESS.split(":")[0], expected_socket_timeout=None, From 2210fedfff424b9e919b0aaef6b2bd6258e789c8 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Mon, 11 Aug 2025 16:39:01 +0300 Subject: [PATCH 18/28] Refactor to have less methods in pool classes and made some of the existing ones more generic --- redis/connection.py | 307 +++++++++++++----------------------- redis/maintenance_events.py | 117 +++++++------- 2 files changed, 171 insertions(+), 253 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index 22a9845b23..b855c7dc8a 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1920,122 +1920,88 @@ def should_update_connection( return False return True - def set_maintenance_state_for_connections( + def update_connection_settings( self, - state: "MaintenanceState", - matching_address: Optional[str] = None, - address_type_to_match: Literal["connected", "configured"] = "connected", - ): - for conn in self._available_connections: - if self.should_update_connection( - conn, address_type_to_match, matching_address - ): - conn.maintenance_state = state - for conn in self._in_use_connections: - if self.should_update_connection( - conn, address_type_to_match, matching_address - ): - conn.maintenance_state = state - - def set_maintenance_state_in_connection_kwargs(self, state: "MaintenanceState"): - self.connection_kwargs["maintenance_state"] = state - - def add_tmp_config_to_connection_kwargs( - self, - tmp_host_address: str, - tmp_relax_timeout: Optional[float] = None, + conn: "Connection", + state: Optional["MaintenanceState"] = None, + relax_timeout: Optional[float] = None, + reset_host_address: bool = False, + reset_relax_timeout: bool = False, ): """ - Store original connection configuration and apply temporary settings. - - This method saves the current host, socket_timeout, and socket_connect_timeout values - in temporary storage fields (orig_*), then applies the provided temporary values - as the active connection configuration. - - This is used when a cluster node is rebound to a different address during - maintenance operations. New connections created after this call will use the - temporary configuration until remove_tmp_config_from_connection_kwargs() is called. - - When this method is called the pool will already be locked, so getting the pool - lock inside is not needed. - - :param tmp_host_address: The temporary host address to use for new connections. - This parameter is required and will replace the current host. - :param tmp_relax_timeout: The temporary timeout value to use for both socket_timeout - and socket_connect_timeout. If -1 is provided, the timeout - settings are not modified (relax timeout is disabled). + Update the settings for a single connection. """ - # Apply temporary values as active configuration - self.connection_kwargs.update({"host": tmp_host_address}) + if state: + conn.maintenance_state = state - if tmp_relax_timeout != -1: - self.connection_kwargs.update( - { - "socket_timeout": tmp_relax_timeout, - "socket_connect_timeout": tmp_relax_timeout, - } + if reset_relax_timeout or reset_host_address: + conn.reset_tmp_settings( + reset_host_address=reset_host_address, + reset_relax_timeout=reset_relax_timeout, ) - def remove_tmp_config_from_connection_kwargs(self): - """ - Remove temporary configuration from connection kwargs and restore original values. - - This method restores the original host address, socket timeout, and connect timeout - from their temporary storage back to the main connection kwargs, then clears the - temporary storage fields. - - This is typically called when a cluster node maintenance operation is complete - and the connection should revert to its original configuration. + conn.update_current_socket_timeout(relax_timeout) - When this method is called the pool will already be locked, so getting the pool - lock inside is not needed. - """ - orig_host = self.connection_kwargs.get("orig_host_address") - orig_socket_timeout = self.connection_kwargs.get("orig_socket_timeout") - orig_connect_timeout = self.connection_kwargs.get("orig_socket_connect_timeout") - - self.connection_kwargs.update( - { - "host": orig_host, - "socket_timeout": orig_socket_timeout, - "socket_connect_timeout": orig_connect_timeout, - } - ) - - def reset_connections_tmp_settings( + def update_connections_settings( self, - moving_address: Optional[str] = None, + state: Optional["MaintenanceState"] = None, + relax_timeout: Optional[float] = None, + matching_address: Optional[str] = None, + address_type_to_match: Literal["connected", "configured"] = "connected", reset_host_address: bool = False, reset_relax_timeout: bool = False, + include_free_connections: bool = True, ): """ - Restore original settings from temporary configuration for all connections in the pool. + Update the settings for all matching connections in the pool. - This method restores each connection's original host, socket_timeout, and socket_connect_timeout - values from their orig_* attributes back to the active connection configuration, then clears - the temporary storage attributes. + This method does not create new connections. + This method does not affect the connection kwargs. - This is used to restore connections to their original configuration after maintenance operations - that required temporary address/timeout changes are complete. - - When this method is called the pool will already be locked, so getting the pool lock inside is not needed. + :param state: The maintenance state to set for the connection. + :param relax_timeout: The relax timeout to set for the connection. + :param matching_address: The address to match for the connection. + :param address_type_to_match: The type of address to match. + :param reset_host_address: Whether to reset the host address to the original address. + :param reset_relax_timeout: Whether to reset the relax timeout to the original timeout. """ - with self._lock: - for conn in self._available_connections: - if moving_address and conn.host != moving_address: - continue - conn.reset_tmp_settings( - reset_host_address=reset_host_address, - reset_relax_timeout=reset_relax_timeout, - ) - for conn in self._in_use_connections: - if moving_address and conn.host != moving_address: - continue - conn.reset_tmp_settings( + for conn in self._in_use_connections: + if self.should_update_connection( + conn, address_type_to_match, matching_address + ): + self.update_connection_settings( + conn, + state=state, + relax_timeout=relax_timeout, reset_host_address=reset_host_address, reset_relax_timeout=reset_relax_timeout, ) + if include_free_connections: + for conn in self._available_connections: + if self.should_update_connection( + conn, address_type_to_match, matching_address + ): + self.update_connection_settings( + conn, + state=state, + relax_timeout=relax_timeout, + reset_host_address=reset_host_address, + reset_relax_timeout=reset_relax_timeout, + ) + + def update_connection_kwargs( + self, + **kwargs, + ): + """ + Update the connection kwargs for all future connections. + + This method updates the connection kwargs for all future connections created by the pool. + Existing connections are not affected. + """ + self.connection_kwargs.update(kwargs) + def update_active_connections_for_reconnect( self, tmp_host_address: str, @@ -2052,11 +2018,12 @@ def update_active_connections_for_reconnect( :param tmp_relax_timeout: The relax timeout to use for the connection. """ for conn in self._in_use_connections: - if moving_address_src and conn.getpeername() != moving_address_src: - continue - self._update_connection_for_reconnect( - conn, tmp_host_address, tmp_relax_timeout - ) + if self.should_update_connection( + conn, "connected", moving_address_src + ): + self._update_connection_for_reconnect( + conn, tmp_host_address, tmp_relax_timeout + ) def disconnect_and_reconfigure_free_connections( self, @@ -2075,41 +2042,12 @@ def disconnect_and_reconfigure_free_connections( """ for conn in self._available_connections: - if moving_address_src and conn.getpeername() != moving_address_src: - continue - self._disconnect_and_update_connection_for_reconnect( - conn, tmp_host_address, tmp_relax_timeout - ) - - def update_connections_current_timeout( - self, - relax_timeout: Optional[float], - matching_address: Optional[str] = None, - address_type_to_match: Literal["connected", "configured"] = "connected", - include_free_connections: bool = False, - ): - """ - Update the timeout either for all connections in the pool or just for the ones in use. - This is used when a cluster node is migrated to a different address. - - When this method is called the pool will already be locked, so getting the pool lock inside is not needed. - - :param relax_timeout: The relax timeout to use for the connection. - If -1 is provided - the relax timeout is disabled. - :param include_available_connections: Whether to include available connections in the update. - """ - for conn in self._in_use_connections: if self.should_update_connection( - conn, address_type_to_match, matching_address + conn, "connected", moving_address_src ): - conn.update_current_socket_timeout(relax_timeout) - - if include_free_connections: - for conn in self._available_connections: - if self.should_update_connection( - conn, address_type_to_match, matching_address - ): - conn.update_current_socket_timeout(relax_timeout) + self._disconnect_and_update_connection_for_reconnect( + conn, tmp_host_address, tmp_relax_timeout + ) def _update_connection_for_reconnect( self, @@ -2373,6 +2311,46 @@ def disconnect(self): pass self._locked = False + def update_connections_settings( + self, + state: Optional["MaintenanceState"] = None, + relax_timeout: Optional[float] = None, + matching_address: Optional[str] = None, + address_type_to_match: Literal["connected", "configured"] = "connected", + reset_host_address: bool = False, + reset_relax_timeout: bool = False, + include_free_connections: bool = True, + ): + """ + Override base class method to work with BlockingConnectionPool's structure. + """ + if include_free_connections: + for conn in tuple(self._connections): + if self.should_update_connection( + conn, address_type_to_match, matching_address + ): + self.update_connection_settings( + conn, + state=state, + relax_timeout=relax_timeout, + reset_host_address=reset_host_address, + reset_relax_timeout=reset_relax_timeout, + ) + else: + connections_in_queue = {conn for conn in self.pool.queue if conn} + for conn in self._connections: + if conn not in connections_in_queue: + if self.should_update_connection( + conn, address_type_to_match, matching_address + ): + self.update_connection_settings( + conn, + state=state, + relax_timeout=relax_timeout, + reset_host_address=reset_host_address, + reset_relax_timeout=reset_relax_timeout, + ) + def update_active_connections_for_reconnect( self, tmp_host_address: str, @@ -2423,37 +2401,6 @@ def disconnect_and_reconfigure_free_connections( conn, tmp_host_address, tmp_relax_timeout ) - def update_connections_current_timeout( - self, - relax_timeout: Optional[float] = None, - matching_address: Optional[str] = None, - address_type_to_match: Literal["connected", "configured"] = "connected", - include_free_connections: bool = False, - ): - """ - Update the timeout for the current socket. - This is used when a cluster node is migrated to a different address. - - When this method is called the pool will already be locked, so getting the pool lock inside is not needed. - - :param relax_timeout: The relax timeout to use for the connection. - :param include_free_connections: Whether to include available connections in the update. - """ - if include_free_connections: - for conn in tuple(self._connections): - if self.should_update_connection( - conn, address_type_to_match, matching_address - ): - conn.update_current_socket_timeout(relax_timeout) - else: - connections_in_queue = {conn for conn in self.pool.queue if conn} - for conn in self._connections: - if conn not in connections_in_queue: - if self.should_update_connection( - conn, address_type_to_match, matching_address - ): - conn.update_current_socket_timeout(relax_timeout) - def _update_maintenance_events_config_for_connections( self, maintenance_events_config ): @@ -2469,25 +2416,6 @@ def _update_maintenance_events_configs_for_connections( conn.set_maintenance_event_pool_handler(maintenance_events_pool_handler) conn.maintenance_events_config = maintenance_events_pool_handler.config - def reset_connections_tmp_settings( - self, - moving_address: Optional[str] = None, - reset_host_address: bool = False, - reset_relax_timeout: bool = False, - ): - """ - Override base class method to work with BlockingConnectionPool's structure. - - Restore original settings from temporary configuration for all connections in the pool. - """ - for conn in tuple(self._connections): - if moving_address and conn.host != moving_address: - continue - conn.reset_tmp_settings( - reset_host_address=reset_host_address, - reset_relax_timeout=reset_relax_timeout, - ) - def set_in_maintenance(self, in_maintenance: bool): """ Sets a flag that this Blocking ConnectionPool is in maintenance mode. @@ -2497,14 +2425,3 @@ def set_in_maintenance(self, in_maintenance: bool): """ self._in_maintenance = in_maintenance - def set_maintenance_state_for_connections( - self, - state: "MaintenanceState", - matching_address: Optional[str] = None, - address_type_to_match: Literal["connected", "configured"] = "connected", - ): - for conn in self._connections: - if self.should_update_connection( - conn, address_type_to_match, matching_address - ): - conn.maintenance_state = state diff --git a/redis/maintenance_events.py b/redis/maintenance_events.py index 25530d5674..ef5cfd3f03 100644 --- a/redis/maintenance_events.py +++ b/redis/maintenance_events.py @@ -368,44 +368,50 @@ def handle_node_moving_event(self, event: NodeMovingEvent): if getattr(self.pool, "set_in_maintenance", False): self.pool.set_in_maintenance(True) - # Set state to MOVING for all connections and in kwargs (inside pool lock, after set_in_maintenance) - self.pool.set_maintenance_state_for_connections( - MaintenanceState.MOVING, moving_address_src - ) - self.pool.set_maintenance_state_in_connection_kwargs( - MaintenanceState.MOVING - ) - # edit the config for new connections until the notification expires - # skip original data update if we are already in MOVING state - # as the original data is already stored in the connection kwargs - self.pool.add_tmp_config_to_connection_kwargs( - tmp_host_address=event.new_node_host, - tmp_relax_timeout=self.config.relax_timeout, + # Update connection settings for all connections + self.pool.update_connections_settings( + state=MaintenanceState.MOVING, + relax_timeout=self.config.relax_timeout, + matching_address=moving_address_src, + address_type_to_match="connected", + include_free_connections=True, ) - if self.config.is_relax_timeouts_enabled(): - # extend the timeout for all connections that are currently in use - self.pool.update_connections_current_timeout( - relax_timeout=self.config.relax_timeout, - matching_address=moving_address_src, - address_type_to_match="connected", - ) + if self.config.proactive_reconnect: - # take care for the active connections in the pool - # mark them for reconnect after they complete the current command - self.pool.update_active_connections_for_reconnect( - tmp_host_address=event.new_node_host, - tmp_relax_timeout=self.config.relax_timeout, - moving_address_src=moving_address_src, - ) - # take care for the inactive connections in the pool - # delete them and create new ones - self.pool.disconnect_and_reconfigure_free_connections( - tmp_host_address=event.new_node_host, - tmp_relax_timeout=self.config.relax_timeout, - moving_address_src=moving_address_src, + # take care for the active connections in the pool + # mark them for reconnect after they complete the current command + self.pool.update_active_connections_for_reconnect( + tmp_host_address=event.new_node_host, + tmp_relax_timeout=self.config.relax_timeout, + moving_address_src=moving_address_src, + ) + # take care for the inactive connections in the pool + # delete them and create new ones + self.pool.disconnect_and_reconfigure_free_connections( + tmp_host_address=event.new_node_host, + tmp_relax_timeout=self.config.relax_timeout, + moving_address_src=moving_address_src, + ) + + # Update config for new connections: + # Set state to MOVING + # update host + # if relax timeouts are enabled - update timeouts + kwargs: dict = { + "maintenance_state": MaintenanceState.MOVING, + "host": event.new_node_host, + } + if self.config.is_relax_timeouts_enabled(): + kwargs.update( + { + "socket_timeout": self.config.relax_timeout, + "socket_connect_timeout": self.config.relax_timeout, + } ) - if getattr(self.pool, "set_in_maintenance", False): - self.pool.set_in_maintenance(False) + self.pool.update_connection_kwargs(**kwargs) + + if getattr(self.pool, "set_in_maintenance", False): + self.pool.set_in_maintenance(False) threading.Timer( event.ttl, self.handle_node_moved_event, args=(event,) @@ -419,35 +425,30 @@ def handle_node_moved_event(self, event: NodeMovingEvent): # it means there has been a new moving event after this one # and we don't need to revert the kwargs if self.pool.connection_kwargs.get("host") == event.new_node_host: - self.pool.remove_tmp_config_from_connection_kwargs() - # Clear state to NONE in kwargs immediately after updating tmp kwargs - self.pool.set_maintenance_state_in_connection_kwargs( - MaintenanceState.NONE - ) + orig_host = self.pool.connection_kwargs.get("orig_host_address") + orig_socket_timeout = self.pool.connection_kwargs.get("orig_socket_timeout") + orig_connect_timeout = self.pool.connection_kwargs.get("orig_socket_connect_timeout") + kwargs: dict = { + "maintenance_state": MaintenanceState.NONE, + "host": orig_host, + "socket_timeout": orig_socket_timeout, + "socket_connect_timeout": orig_connect_timeout, + } + self.pool.update_connection_kwargs(**kwargs) + with self.pool._lock: moving_address = event.new_node_host - if self.config.is_relax_timeouts_enabled(): - self.pool.reset_connections_tmp_settings( - moving_address, reset_relax_timeout=True - ) - # reset the timeout for existing connections - self.pool.update_connections_current_timeout( - relax_timeout=-1, - matching_address=moving_address, - address_type_to_match="configured", - include_free_connections=True, - ) + reset_relax_timeout = self.config.is_relax_timeouts_enabled() + reset_host_address = self.config.proactive_reconnect - # Clear maintenance state to NONE for all matching connections - self.pool.set_maintenance_state_for_connections( + self.pool.update_connections_settings( + relax_timeout=-1, state=MaintenanceState.NONE, matching_address=moving_address, address_type_to_match="configured", - ) - # reset the host address after all other operations that - # compare against tmp host are completed - self.pool.reset_connections_tmp_settings( - moving_address, reset_host_address=True + reset_relax_timeout=reset_relax_timeout, + reset_host_address=reset_host_address, + include_free_connections=True, ) From 1427d9936a31cebd272eb5d76a553db7fd4391be Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Mon, 11 Aug 2025 16:49:13 +0300 Subject: [PATCH 19/28] Fixing lint errors --- redis/connection.py | 27 +++++++++++---------------- redis/maintenance_events.py | 36 ++++++++++++++++++++---------------- 2 files changed, 31 insertions(+), 32 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index b855c7dc8a..89152a6932 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -2018,9 +2018,7 @@ def update_active_connections_for_reconnect( :param tmp_relax_timeout: The relax timeout to use for the connection. """ for conn in self._in_use_connections: - if self.should_update_connection( - conn, "connected", moving_address_src - ): + if self.should_update_connection(conn, "connected", moving_address_src): self._update_connection_for_reconnect( conn, tmp_host_address, tmp_relax_timeout ) @@ -2042,9 +2040,7 @@ def disconnect_and_reconfigure_free_connections( """ for conn in self._available_connections: - if self.should_update_connection( - conn, "connected", moving_address_src - ): + if self.should_update_connection(conn, "connected", moving_address_src): self._disconnect_and_update_connection_for_reconnect( conn, tmp_host_address, tmp_relax_timeout ) @@ -2312,15 +2308,15 @@ def disconnect(self): self._locked = False def update_connections_settings( - self, - state: Optional["MaintenanceState"] = None, - relax_timeout: Optional[float] = None, - matching_address: Optional[str] = None, - address_type_to_match: Literal["connected", "configured"] = "connected", - reset_host_address: bool = False, - reset_relax_timeout: bool = False, - include_free_connections: bool = True, - ): + self, + state: Optional["MaintenanceState"] = None, + relax_timeout: Optional[float] = None, + matching_address: Optional[str] = None, + address_type_to_match: Literal["connected", "configured"] = "connected", + reset_host_address: bool = False, + reset_relax_timeout: bool = False, + include_free_connections: bool = True, + ): """ Override base class method to work with BlockingConnectionPool's structure. """ @@ -2424,4 +2420,3 @@ def set_in_maintenance(self, in_maintenance: bool): The pool will be in maintenance mode only when we are processing a MOVING event. """ self._in_maintenance = in_maintenance - diff --git a/redis/maintenance_events.py b/redis/maintenance_events.py index ef5cfd3f03..8c6c15c74a 100644 --- a/redis/maintenance_events.py +++ b/redis/maintenance_events.py @@ -378,20 +378,20 @@ def handle_node_moving_event(self, event: NodeMovingEvent): ) if self.config.proactive_reconnect: - # take care for the active connections in the pool - # mark them for reconnect after they complete the current command - self.pool.update_active_connections_for_reconnect( - tmp_host_address=event.new_node_host, - tmp_relax_timeout=self.config.relax_timeout, - moving_address_src=moving_address_src, - ) - # take care for the inactive connections in the pool - # delete them and create new ones - self.pool.disconnect_and_reconfigure_free_connections( - tmp_host_address=event.new_node_host, - tmp_relax_timeout=self.config.relax_timeout, - moving_address_src=moving_address_src, - ) + # take care for the active connections in the pool + # mark them for reconnect after they complete the current command + self.pool.update_active_connections_for_reconnect( + tmp_host_address=event.new_node_host, + tmp_relax_timeout=self.config.relax_timeout, + moving_address_src=moving_address_src, + ) + # take care for the inactive connections in the pool + # delete them and create new ones + self.pool.disconnect_and_reconfigure_free_connections( + tmp_host_address=event.new_node_host, + tmp_relax_timeout=self.config.relax_timeout, + moving_address_src=moving_address_src, + ) # Update config for new connections: # Set state to MOVING @@ -426,8 +426,12 @@ def handle_node_moved_event(self, event: NodeMovingEvent): # and we don't need to revert the kwargs if self.pool.connection_kwargs.get("host") == event.new_node_host: orig_host = self.pool.connection_kwargs.get("orig_host_address") - orig_socket_timeout = self.pool.connection_kwargs.get("orig_socket_timeout") - orig_connect_timeout = self.pool.connection_kwargs.get("orig_socket_connect_timeout") + orig_socket_timeout = self.pool.connection_kwargs.get( + "orig_socket_timeout" + ) + orig_connect_timeout = self.pool.connection_kwargs.get( + "orig_socket_connect_timeout" + ) kwargs: dict = { "maintenance_state": MaintenanceState.NONE, "host": orig_host, From a2744f3fe9f6f5bef695061fffadb86cd509b49a Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Mon, 11 Aug 2025 17:35:52 +0300 Subject: [PATCH 20/28] Fixing tests --- tests/test_maintenance_events.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_maintenance_events.py b/tests/test_maintenance_events.py index 3eb648f079..c90fa5db4f 100644 --- a/tests/test_maintenance_events.py +++ b/tests/test_maintenance_events.py @@ -443,7 +443,7 @@ def test_handle_node_moving_event_success(self): assert event in self.handler._processed_events # Verify pool methods were called - self.mock_pool.add_tmp_config_to_connection_kwargs.assert_called_once() + self.mock_pool.update_connections_settings.assert_called_once() def test_handle_node_moved_event(self): """Test handling of node moved event (cleanup).""" @@ -454,7 +454,7 @@ def test_handle_node_moved_event(self): self.handler.handle_node_moved_event(event) # Verify cleanup methods were called - self.mock_pool.remove_tmp_config_from_connection_kwargs.assert_called_once() + self.mock_pool.update_connections_settings.assert_called_once() class TestMaintenanceEventConnectionHandler: From 260b34e2ed15fdb5107ad28ea3b0e90b6ac59452 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Thu, 14 Aug 2025 10:55:09 +0300 Subject: [PATCH 21/28] Fixing the docs of some of the new methods in connection pools. Handle better retry_on_error handling on connection initialization. --- redis/connection.py | 42 ++++++++++++++++++++++++++++++++---------- 1 file changed, 32 insertions(+), 10 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index 89152a6932..effe447161 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -8,7 +8,18 @@ from abc import abstractmethod from itertools import chain from queue import Empty, Full, LifoQueue -from typing import Any, Callable, Dict, List, Literal, Optional, Type, TypeVar, Union +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Literal, + Optional, + Type, + TypeVar, + Union, +) from urllib.parse import parse_qs, unquote, urlparse from redis.cache import ( @@ -311,7 +322,7 @@ def __init__( socket_timeout: Optional[float] = None, socket_connect_timeout: Optional[float] = None, retry_on_timeout: bool = False, - retry_on_error=SENTINEL, + retry_on_error: Union[Iterable[Type[Exception]], object] = SENTINEL, encoding: str = "utf-8", encoding_errors: str = "strict", decode_responses: bool = False, @@ -367,19 +378,22 @@ def __init__( self.socket_connect_timeout = socket_connect_timeout self.retry_on_timeout = retry_on_timeout if retry_on_error is SENTINEL: - retry_on_error = [] + retry_on_errors_list = [] + else: + retry_on_errors_list = list(retry_on_error) if retry_on_timeout: # Add TimeoutError to the errors list to retry on - retry_on_error.append(TimeoutError) - self.retry_on_error = retry_on_error - if retry or retry_on_error: + retry_on_errors_list.append(TimeoutError) + self.retry_on_error = retry_on_errors_list + if retry or self.retry_on_error: if retry is None: self.retry = Retry(NoBackoff(), 1) else: # deep-copy the Retry object as it is mutable self.retry = copy.deepcopy(retry) - # Update the retry's supported errors with the specified errors - self.retry.update_supported_errors(retry_on_error) + if self.retry_on_error: + # Update the retry's supported errors with the specified errors + self.retry.update_supported_errors(self.retry_on_error) else: self.retry = Retry(NoBackoff(), 0) self.health_check_interval = health_check_interval @@ -1912,6 +1926,9 @@ def should_update_connection( address_type_to_match: Literal["connected", "configured"] = "connected", matching_address: Optional[str] = None, ) -> bool: + """ + Check if the connection should be updated based on the matching address. + """ if address_type_to_match == "connected": if matching_address and conn.getpeername() != matching_address: return False @@ -1964,6 +1981,7 @@ def update_connections_settings( :param address_type_to_match: The type of address to match. :param reset_host_address: Whether to reset the host address to the original address. :param reset_relax_timeout: Whether to reset the relax timeout to the original timeout. + :param include_free_connections: Whether to include free/available connections. """ for conn in self._in_use_connections: if self.should_update_connection( @@ -2016,6 +2034,7 @@ def update_active_connections_for_reconnect( :param tmp_host_address: The temporary host address to use for the connection. :param tmp_relax_timeout: The relax timeout to use for the connection. + :param moving_address_src: The address of the node that is being moved. """ for conn in self._in_use_connections: if self.should_update_connection(conn, "connected", moving_address_src): @@ -2035,8 +2054,9 @@ def disconnect_and_reconfigure_free_connections( When this method is called the pool will already be locked, so getting the pool lock inside is not needed. - :param orig_host_address: The temporary host address to use for the connection. - :param orig_relax_timeout: The relax timeout to use for the connection. + :param tmp_host_address: The temporary host address to use for the connection. + :param tmp_relax_timeout: The relax timeout to use for the connection. + :param moving_address_src: The address of the node that is being moved. """ for conn in self._available_connections: @@ -2361,6 +2381,7 @@ def update_active_connections_for_reconnect( :param tmp_host_address: The temporary host address to use for the connection. :param tmp_relax_timeout: The relax timeout to use for the connection. + :param moving_address_src: The address of the node that is being moved. """ with self._lock: connections_in_queue = {conn for conn in self.pool.queue if conn} @@ -2386,6 +2407,7 @@ def disconnect_and_reconfigure_free_connections( :param tmp_host_address: The temporary host address to use for the connection. :param tmp_relax_timeout: The relax timeout to use for the connection. + :param moving_address_src: The address of the node that is being moved. """ existing_connections = self.pool.queue From 4c6eb445832bb4ae9b90626355e4928c01168320 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Fri, 15 Aug 2025 16:06:49 +0300 Subject: [PATCH 22/28] Applying review comments --- redis/_parsers/base.py | 6 +- redis/client.py | 6 +- redis/connection.py | 114 ++++++++++++---------- tests/test_maintenance_events_handling.py | 2 +- 4 files changed, 70 insertions(+), 58 deletions(-) diff --git a/redis/_parsers/base.py b/redis/_parsers/base.py index c3d4c136d2..dd2d8b9de0 100644 --- a/redis/_parsers/base.py +++ b/redis/_parsers/base.py @@ -1,7 +1,7 @@ import sys from abc import ABC from asyncio import IncompleteReadError, StreamReader, TimeoutError -from typing import Callable, List, Optional, Protocol, Union +from typing import Awaitable, Callable, List, Optional, Protocol, Union from redis.maintenance_events import ( NodeMigratedEvent, @@ -243,8 +243,8 @@ class AsyncPushNotificationsParser(Protocol): pubsub_push_handler_func: Callable invalidation_push_handler_func: Optional[Callable] = None - node_moving_push_handler_func: Optional[Callable] = None - maintenance_push_handler_func: Optional[Callable] = None + node_moving_push_handler_func: Optional[Callable[..., Awaitable[None]]] = None + maintenance_push_handler_func: Optional[Callable[..., Awaitable[None]]] = None async def handle_pubsub_push_response(self, response): """Handle pubsub push responses asynchronously""" diff --git a/redis/client.py b/redis/client.py index a6c96c3882..26837b673b 100755 --- a/redis/client.py +++ b/redis/client.py @@ -954,12 +954,13 @@ def _execute(self, conn, command, *args, **kwargs): patterns we were previously listening to """ + if conn.should_reconnect(): + self._reconnect(conn) + response = conn.retry.call_with_retry( lambda: command(*args, **kwargs), lambda _: self._reconnect(conn), ) - if conn.should_reconnect(): - self._reconnect(conn) return response @@ -1172,6 +1173,7 @@ def get_message( return None response = self.parse_response(block=(timeout is None), timeout=timeout) + if response: return self.handle_message(response, ignore_subscribe_messages) return None diff --git a/redis/connection.py b/redis/connection.py index effe447161..1389f77476 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -445,6 +445,8 @@ def __init__( if orig_socket_connect_timeout else self.socket_connect_timeout ) + else: + self._maintenance_event_connection_handler = None self._should_reconnect = False self.maintenance_state = maintenance_state @@ -511,8 +513,8 @@ def set_maintenance_event_pool_handler( maintenance_event_pool_handler.handle_event ) - # Initialize maintenance event connection handler if it doesn't exist - if not hasattr(self, "_maintenance_event_connection_handler"): + # Update maintenance event connection handler if it doesn't exist + if not self._maintenance_event_connection_handler: self._maintenance_event_connection_handler = ( MaintenanceEventConnectionHandler( self, maintenance_event_pool_handler.config @@ -521,6 +523,10 @@ def set_maintenance_event_pool_handler( self._parser.set_maintenance_push_handler( self._maintenance_event_connection_handler.handle_event ) + else: + self._maintenance_event_connection_handler.config = ( + maintenance_event_pool_handler.config + ) def connect(self): "Connects to the Redis server if not already connected" @@ -1983,20 +1989,8 @@ def update_connections_settings( :param reset_relax_timeout: Whether to reset the relax timeout to the original timeout. :param include_free_connections: Whether to include free/available connections. """ - for conn in self._in_use_connections: - if self.should_update_connection( - conn, address_type_to_match, matching_address - ): - self.update_connection_settings( - conn, - state=state, - relax_timeout=relax_timeout, - reset_host_address=reset_host_address, - reset_relax_timeout=reset_relax_timeout, - ) - - if include_free_connections: - for conn in self._available_connections: + with self._lock: + for conn in self._in_use_connections: if self.should_update_connection( conn, address_type_to_match, matching_address ): @@ -2008,6 +2002,19 @@ def update_connections_settings( reset_relax_timeout=reset_relax_timeout, ) + if include_free_connections: + for conn in self._available_connections: + if self.should_update_connection( + conn, address_type_to_match, matching_address + ): + self.update_connection_settings( + conn, + state=state, + relax_timeout=relax_timeout, + reset_host_address=reset_host_address, + reset_relax_timeout=reset_relax_timeout, + ) + def update_connection_kwargs( self, **kwargs, @@ -2036,11 +2043,12 @@ def update_active_connections_for_reconnect( :param tmp_relax_timeout: The relax timeout to use for the connection. :param moving_address_src: The address of the node that is being moved. """ - for conn in self._in_use_connections: - if self.should_update_connection(conn, "connected", moving_address_src): - self._update_connection_for_reconnect( - conn, tmp_host_address, tmp_relax_timeout - ) + with self._lock: + for conn in self._in_use_connections: + if self.should_update_connection(conn, "connected", moving_address_src): + self._update_connection_for_reconnect( + conn, tmp_host_address, tmp_relax_timeout + ) def disconnect_and_reconfigure_free_connections( self, @@ -2058,12 +2066,12 @@ def disconnect_and_reconfigure_free_connections( :param tmp_relax_timeout: The relax timeout to use for the connection. :param moving_address_src: The address of the node that is being moved. """ - - for conn in self._available_connections: - if self.should_update_connection(conn, "connected", moving_address_src): - self._disconnect_and_update_connection_for_reconnect( - conn, tmp_host_address, tmp_relax_timeout - ) + with self._lock: + for conn in self._available_connections: + if self.should_update_connection(conn, "connected", moving_address_src): + self._disconnect_and_update_connection_for_reconnect( + conn, tmp_host_address, tmp_relax_timeout + ) def _update_connection_for_reconnect( self, @@ -2340,22 +2348,9 @@ def update_connections_settings( """ Override base class method to work with BlockingConnectionPool's structure. """ - if include_free_connections: - for conn in tuple(self._connections): - if self.should_update_connection( - conn, address_type_to_match, matching_address - ): - self.update_connection_settings( - conn, - state=state, - relax_timeout=relax_timeout, - reset_host_address=reset_host_address, - reset_relax_timeout=reset_relax_timeout, - ) - else: - connections_in_queue = {conn for conn in self.pool.queue if conn} - for conn in self._connections: - if conn not in connections_in_queue: + with self._lock: + if include_free_connections: + for conn in tuple(self._connections): if self.should_update_connection( conn, address_type_to_match, matching_address ): @@ -2366,6 +2361,20 @@ def update_connections_settings( reset_host_address=reset_host_address, reset_relax_timeout=reset_relax_timeout, ) + else: + connections_in_queue = {conn for conn in self.pool.queue if conn} + for conn in self._connections: + if conn not in connections_in_queue: + if self.should_update_connection( + conn, address_type_to_match, matching_address + ): + self.update_connection_settings( + conn, + state=state, + relax_timeout=relax_timeout, + reset_host_address=reset_host_address, + reset_relax_timeout=reset_relax_timeout, + ) def update_active_connections_for_reconnect( self, @@ -2409,15 +2418,16 @@ def disconnect_and_reconfigure_free_connections( :param tmp_relax_timeout: The relax timeout to use for the connection. :param moving_address_src: The address of the node that is being moved. """ - existing_connections = self.pool.queue - - for conn in existing_connections: - if conn: - if moving_address_src and conn.getpeername() != moving_address_src: - continue - self._disconnect_and_update_connection_for_reconnect( - conn, tmp_host_address, tmp_relax_timeout - ) + with self._lock: + existing_connections = self.pool.queue + + for conn in existing_connections: + if conn: + if moving_address_src and conn.getpeername() != moving_address_src: + continue + self._disconnect_and_update_connection_for_reconnect( + conn, tmp_host_address, tmp_relax_timeout + ) def _update_maintenance_events_config_for_connections( self, maintenance_events_config diff --git a/tests/test_maintenance_events_handling.py b/tests/test_maintenance_events_handling.py index c0c98ff330..8db8d182a7 100644 --- a/tests/test_maintenance_events_handling.py +++ b/tests/test_maintenance_events_handling.py @@ -503,7 +503,7 @@ def test_maint_handler_init_for_existing_connections(self): # Verify that maintenance events are initially disabled assert existing_conn._parser.node_moving_push_handler_func is None - assert not hasattr(existing_conn, "_maintenance_event_connection_handler") + assert existing_conn._maintenance_event_connection_handler is None assert existing_conn._parser.maintenance_push_handler_func is None # Create a new enabled configuration and set up pool handler From 10ded3498e6bc7d660d5b5c249e30d56bed949d2 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Thu, 24 Jul 2025 18:05:59 +0300 Subject: [PATCH 23/28] Adding handling of FAILING_OVER and FAILED_OVER events/push notifications --- redis/maintenance_events.py | 116 ++++++++++++++- tests/test_maintenance_events.py | 171 +++++++++++++++++++--- tests/test_maintenance_events_handling.py | 124 +++++++++++++++- 3 files changed, 379 insertions(+), 32 deletions(-) diff --git a/redis/maintenance_events.py b/redis/maintenance_events.py index 8c6c15c74a..dbf7fe3ebb 100644 --- a/redis/maintenance_events.py +++ b/redis/maintenance_events.py @@ -12,6 +12,7 @@ class MaintenanceState(enum.Enum): NONE = "none" MOVING = "moving" MIGRATING = "migrating" + FAILING_OVER = "failing_over" if TYPE_CHECKING: @@ -261,6 +262,105 @@ def __hash__(self) -> int: return hash((self.__class__, self.id)) +class NodeFailingOverEvent(MaintenanceEvent): + """ + Event for when a Redis cluster node is in the process of failing over. + + This event is received when a node starts a failover process during + cluster maintenance operations or when handling node failures. + + Args: + id (int): Unique identifier for this event + ttl (int): Time-to-live in seconds for this notification + """ + + def __init__(self, id: int, ttl: int): + super().__init__(id, ttl) + + def __repr__(self) -> str: + expiry_time = self.creation_time + self.ttl + remaining = max(0, expiry_time - time.monotonic()) + return ( + f"{self.__class__.__name__}(" + f"id={self.id}, " + f"ttl={self.ttl}, " + f"creation_time={self.creation_time}, " + f"expires_at={expiry_time}, " + f"remaining={remaining:.1f}s, " + f"expired={self.is_expired()}" + f")" + ) + + def __eq__(self, other) -> bool: + """ + Two NodeFailingOverEvent events are considered equal if they have the same + id and are of the same type. + """ + if not isinstance(other, NodeFailingOverEvent): + return False + return self.id == other.id and type(self) is type(other) + + def __hash__(self) -> int: + """ + Return a hash value for the event to allow + instances to be used in sets and as dictionary keys. + + Returns: + int: Hash value based on event type and id + """ + return hash((self.__class__, self.id)) + + +class NodeFailedOverEvent(MaintenanceEvent): + """ + Event for when a Redis cluster node has completed a failover. + + This event is received when a node has finished the failover process + during cluster maintenance operations or after handling node failures. + + Args: + id (int): Unique identifier for this event + """ + + DEFAULT_TTL = 5 + + def __init__(self, id: int): + super().__init__(id, NodeFailedOverEvent.DEFAULT_TTL) + + def __repr__(self) -> str: + expiry_time = self.creation_time + self.ttl + remaining = max(0, expiry_time - time.monotonic()) + return ( + f"{self.__class__.__name__}(" + f"id={self.id}, " + f"ttl={self.ttl}, " + f"creation_time={self.creation_time}, " + f"expires_at={expiry_time}, " + f"remaining={remaining:.1f}s, " + f"expired={self.is_expired()}" + f")" + ) + + def __eq__(self, other) -> bool: + """ + Two NodeFailedOverEvent events are considered equal if they have the same + id and are of the same type. + """ + if not isinstance(other, NodeFailedOverEvent): + return False + return self.id == other.id and type(self) is type(other) + + def __hash__(self) -> int: + """ + Return a hash value for the event to allow + instances to be used in sets and as dictionary keys. + + Returns: + int: Hash value based on event type and id + """ + return hash((self.__class__, self.id)) + + class MaintenanceEventsConfig: """ Configuration class for maintenance events handling behaviour. Events are received through @@ -465,24 +565,28 @@ def __init__( def handle_event(self, event: MaintenanceEvent): if isinstance(event, NodeMigratingEvent): - return self.handle_migrating_event(event) + return self.handle_maintenance_start_event(MaintenanceState.MIGRATING) elif isinstance(event, NodeMigratedEvent): - return self.handle_migration_completed_event(event) + return self.handle_maintenance_completed_event() + elif isinstance(event, NodeFailingOverEvent): + return self.handle_maintenance_start_event(MaintenanceState.FAILING_OVER) + elif isinstance(event, NodeFailedOverEvent): + return self.handle_maintenance_completed_event() else: logging.error(f"Unhandled event type: {event}") - def handle_migrating_event(self, notification: NodeMigratingEvent): + def handle_maintenance_start_event(self, maintenance_state: MaintenanceState): if ( self.connection.maintenance_state == MaintenanceState.MOVING or not self.config.is_relax_timeouts_enabled() ): return - self.connection.maintenance_state = MaintenanceState.MIGRATING + self.connection.maintenance_state = maintenance_state self.connection.set_tmp_settings(tmp_relax_timeout=self.config.relax_timeout) # extend the timeout for all created connections self.connection.update_current_socket_timeout(self.config.relax_timeout) - def handle_migration_completed_event(self, notification: "NodeMigratedEvent"): + def handle_maintenance_completed_event(self): # Only reset timeouts if state is not MOVING and relax timeouts are enabled if ( self.connection.maintenance_state == MaintenanceState.MOVING @@ -490,7 +594,7 @@ def handle_migration_completed_event(self, notification: "NodeMigratedEvent"): ): return self.connection.reset_tmp_settings(reset_relax_timeout=True) - # Node migration completed - reset the connection + # Maintenance completed - reset the connection # timeouts by providing -1 as the relax timeout self.connection.update_current_socket_timeout(-1) self.connection.maintenance_state = MaintenanceState.NONE diff --git a/tests/test_maintenance_events.py b/tests/test_maintenance_events.py index c90fa5db4f..badcc6da6c 100644 --- a/tests/test_maintenance_events.py +++ b/tests/test_maintenance_events.py @@ -7,9 +7,12 @@ NodeMovingEvent, NodeMigratingEvent, NodeMigratedEvent, + NodeFailingOverEvent, + NodeFailedOverEvent, MaintenanceEventsConfig, MaintenanceEventPoolHandler, MaintenanceEventConnectionHandler, + MaintenanceState, ) @@ -281,6 +284,84 @@ def test_equality_and_hash(self): assert hash(event1) != hash(event3) +class TestNodeFailingOverEvent: + """Test the NodeFailingOverEvent class.""" + + def test_init(self): + """Test NodeFailingOverEvent initialization.""" + with patch("time.monotonic", return_value=1000): + event = NodeFailingOverEvent(id=1, ttl=5) + assert event.id == 1 + assert event.ttl == 5 + assert event.creation_time == 1000 + + def test_repr(self): + """Test NodeFailingOverEvent string representation.""" + with patch("time.monotonic", return_value=1000): + event = NodeFailingOverEvent(id=1, ttl=5) + + with patch("time.monotonic", return_value=1002): # 2 seconds later + repr_str = repr(event) + assert "NodeFailingOverEvent" in repr_str + assert "id=1" in repr_str + assert "ttl=5" in repr_str + assert "remaining=3.0s" in repr_str + assert "expired=False" in repr_str + + def test_equality_and_hash(self): + """Test equality and hash for NodeFailingOverEvent.""" + event1 = NodeFailingOverEvent(id=1, ttl=5) + event2 = NodeFailingOverEvent(id=1, ttl=10) # Same id, different ttl + event3 = NodeFailingOverEvent(id=2, ttl=5) # Different id + + assert event1 == event2 + assert event1 != event3 + assert hash(event1) == hash(event2) + assert hash(event1) != hash(event3) + + +class TestNodeFailedOverEvent: + """Test the NodeFailedOverEvent class.""" + + def test_init(self): + """Test NodeFailedOverEvent initialization.""" + with patch("time.monotonic", return_value=1000): + event = NodeFailedOverEvent(id=1) + assert event.id == 1 + assert event.ttl == NodeFailedOverEvent.DEFAULT_TTL + assert event.creation_time == 1000 + + def test_default_ttl(self): + """Test that DEFAULT_TTL is used correctly.""" + assert NodeFailedOverEvent.DEFAULT_TTL == 5 + event = NodeFailedOverEvent(id=1) + assert event.ttl == 5 + + def test_repr(self): + """Test NodeFailedOverEvent string representation.""" + with patch("time.monotonic", return_value=1000): + event = NodeFailedOverEvent(id=1) + + with patch("time.monotonic", return_value=1001): # 1 second later + repr_str = repr(event) + assert "NodeFailedOverEvent" in repr_str + assert "id=1" in repr_str + assert "ttl=5" in repr_str + assert "remaining=4.0s" in repr_str + assert "expired=False" in repr_str + + def test_equality_and_hash(self): + """Test equality and hash for NodeFailedOverEvent.""" + event1 = NodeFailedOverEvent(id=1) + event2 = NodeFailedOverEvent(id=1) # Same id + event3 = NodeFailedOverEvent(id=2) # Different id + + assert event1 == event2 + assert event1 != event3 + assert hash(event1) == hash(event2) + assert hash(event1) != hash(event3) + + class TestMaintenanceEventsConfig: """Test the MaintenanceEventsConfig class.""" @@ -477,19 +558,41 @@ def test_handle_event_migrating(self): """Test handling of NodeMigratingEvent.""" event = NodeMigratingEvent(id=1, ttl=5) - with patch.object(self.handler, "handle_migrating_event") as mock_handle: + with patch.object( + self.handler, "handle_maintenance_start_event" + ) as mock_handle: self.handler.handle_event(event) - mock_handle.assert_called_once_with(event) + mock_handle.assert_called_once_with(MaintenanceState.MIGRATING) def test_handle_event_migrated(self): """Test handling of NodeMigratedEvent.""" event = NodeMigratedEvent(id=1) with patch.object( - self.handler, "handle_migration_completed_event" + self.handler, "handle_maintenance_completed_event" ) as mock_handle: self.handler.handle_event(event) - mock_handle.assert_called_once_with(event) + mock_handle.assert_called_once_with() + + def test_handle_event_failing_over(self): + """Test handling of NodeFailingOverEvent.""" + event = NodeFailingOverEvent(id=1, ttl=5) + + with patch.object( + self.handler, "handle_maintenance_start_event" + ) as mock_handle: + self.handler.handle_event(event) + mock_handle.assert_called_once_with(MaintenanceState.FAILING_OVER) + + def test_handle_event_failed_over(self): + """Test handling of NodeFailedOverEvent.""" + event = NodeFailedOverEvent(id=1) + + with patch.object( + self.handler, "handle_maintenance_completed_event" + ) as mock_handle: + self.handler.handle_event(event) + mock_handle.assert_called_once_with() def test_handle_event_unknown_type(self): """Test handling of unknown event type.""" @@ -500,43 +603,71 @@ def test_handle_event_unknown_type(self): result = self.handler.handle_event(event) assert result is None - def test_handle_migrating_event_disabled(self): - """Test migrating event handling when relax timeouts are disabled.""" + def test_handle_maintenance_start_event_disabled(self): + """Test maintenance start event handling when relax timeouts are disabled.""" config = MaintenanceEventsConfig(relax_timeout=-1) handler = MaintenanceEventConnectionHandler(self.mock_connection, config) - event = NodeMigratingEvent(id=1, ttl=5) - result = handler.handle_migrating_event(event) + result = handler.handle_maintenance_start_event(MaintenanceState.MIGRATING) assert result is None self.mock_connection.update_current_socket_timeout.assert_not_called() - def test_handle_migrating_event_success(self): - """Test successful migrating event handling.""" - event = NodeMigratingEvent(id=1, ttl=5) + def test_handle_maintenance_start_event_moving_state(self): + """Test maintenance start event handling when connection is in MOVING state.""" + self.mock_connection.maintenance_state = MaintenanceState.MOVING - self.handler.handle_migrating_event(event) + result = self.handler.handle_maintenance_start_event(MaintenanceState.MIGRATING) + assert result is None + self.mock_connection.update_current_socket_timeout.assert_not_called() + def test_handle_maintenance_start_event_migrating_success(self): + """Test successful maintenance start event handling for migrating.""" + self.mock_connection.maintenance_state = MaintenanceState.NONE + + self.handler.handle_maintenance_start_event(MaintenanceState.MIGRATING) + + assert self.mock_connection.maintenance_state == MaintenanceState.MIGRATING self.mock_connection.update_current_socket_timeout.assert_called_once_with(20) self.mock_connection.set_tmp_settings.assert_called_once_with( tmp_relax_timeout=20 ) - def test_handle_migration_completed_event_disabled(self): - """Test migration completed event handling when relax timeouts are disabled.""" + def test_handle_maintenance_start_event_failing_over_success(self): + """Test successful maintenance start event handling for failing over.""" + self.mock_connection.maintenance_state = MaintenanceState.NONE + + self.handler.handle_maintenance_start_event(MaintenanceState.FAILING_OVER) + + assert self.mock_connection.maintenance_state == MaintenanceState.FAILING_OVER + self.mock_connection.update_current_socket_timeout.assert_called_once_with(20) + self.mock_connection.set_tmp_settings.assert_called_once_with( + tmp_relax_timeout=20 + ) + + def test_handle_maintenance_completed_event_disabled(self): + """Test maintenance completed event handling when relax timeouts are disabled.""" config = MaintenanceEventsConfig(relax_timeout=-1) handler = MaintenanceEventConnectionHandler(self.mock_connection, config) - event = NodeMigratedEvent(id=1) - result = handler.handle_migration_completed_event(event) + result = handler.handle_maintenance_completed_event() assert result is None self.mock_connection.update_current_socket_timeout.assert_not_called() - def test_handle_migration_completed_event_success(self): - """Test successful migration completed event handling.""" - event = NodeMigratedEvent(id=1) + def test_handle_maintenance_completed_event_moving_state(self): + """Test maintenance completed event handling when connection is in MOVING state.""" + self.mock_connection.maintenance_state = MaintenanceState.MOVING + + result = self.handler.handle_maintenance_completed_event() + assert result is None + self.mock_connection.update_current_socket_timeout.assert_not_called() + + def test_handle_maintenance_completed_event_success(self): + """Test successful maintenance completed event handling.""" + self.mock_connection.maintenance_state = MaintenanceState.MIGRATING - self.handler.handle_migration_completed_event(event) + self.handler.handle_maintenance_completed_event() + assert self.mock_connection.maintenance_state == MaintenanceState.NONE self.mock_connection.update_current_socket_timeout.assert_called_once_with(-1) self.mock_connection.reset_tmp_settings.assert_called_once_with( reset_relax_timeout=True diff --git a/tests/test_maintenance_events_handling.py b/tests/test_maintenance_events_handling.py index 8db8d182a7..cecd99b000 100644 --- a/tests/test_maintenance_events_handling.py +++ b/tests/test_maintenance_events_handling.py @@ -16,9 +16,11 @@ from redis.maintenance_events import ( MaintenanceEventsConfig, NodeMigratingEvent, + NodeMigratedEvent, + NodeFailingOverEvent, + NodeFailedOverEvent, MaintenanceEventPoolHandler, NodeMovingEvent, - NodeMigratedEvent, ) @@ -189,6 +191,22 @@ def send(self, data): # Format: >1\r\n$8\r\nMIGRATED\r\n (1 element: MIGRATED) migrated_push = ">1\r\n$8\r\nMIGRATED\r\n" response = migrated_push.encode() + response + elif ( + b"key_receive_failing_over_" in data + or b"key_receive_failing_over" in data + ): + # FAILING_OVER push message before SET key_receive_failing_over_X response + # Format: >2\r\n$12\r\nFAILING_OVER\r\n:10\r\n (2 elements: FAILING_OVER, ttl) + failing_over_push = ">2\r\n$12\r\nFAILING_OVER\r\n:10\r\n" + response = failing_over_push.encode() + response + elif ( + b"key_receive_failed_over_" in data + or b"key_receive_failed_over" in data + ): + # FAILED_OVER push message before SET key_receive_failed_over_X response + # Format: >1\r\n$11\r\nFAILED_OVER\r\n (1 element: FAILED_OVER) + failed_over_push = ">1\r\n$11\r\nFAILED_OVER\r\n" + response = failed_over_push.encode() + response elif b"key_receive_moving_" in data: # MOVING push message before SET key_receive_moving_X response # Format: >3\r\n$6\r\nMOVING\r\n:15\r\n+localhost:6379\r\n (3 elements: MOVING, ttl, host:port) @@ -211,6 +229,10 @@ def send(self, data): self.pending_responses.append(b"$6\r\nvalue2\r\n") elif b"key_receive_migrated" in data: self.pending_responses.append(b"$6\r\nvalue3\r\n") + elif b"key_receive_failing_over" in data: + self.pending_responses.append(b"$6\r\nvalue4\r\n") + elif b"key_receive_failed_over" in data: + self.pending_responses.append(b"$6\r\nvalue5\r\n") elif b"key1" in data: self.pending_responses.append(b"$6\r\nvalue1\r\n") else: @@ -727,13 +749,14 @@ def test_migration_related_events_handling_integration(self, pool_class): @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) def test_migrating_event_with_disabled_relax_timeout(self, pool_class): """ - Test migrating event handling when relax timeout is disabled. + Test maintenance events handling when relax timeout is disabled. This test validates that when relax_timeout is disabled (-1): - 1. MIGRATING events are received and processed + 1. MIGRATING, MIGRATED, FAILING_OVER, and FAILED_OVER events are received and processed 2. No timeout updates are applied to connections - 3. Socket timeouts remain unchanged during migration events + 3. Socket timeouts remain unchanged during all maintenance events 4. Tests both ConnectionPool and BlockingConnectionPool implementations + 5. Tests the complete lifecycle: MIGRATING -> MIGRATED -> FAILING_OVER -> FAILED_OVER """ # Create config with disabled relax timeout disabled_config = MaintenanceEventsConfig( @@ -776,6 +799,57 @@ def test_migrating_event_with_disabled_relax_timeout(self, pool_class): f"Command 3 (GET key1) failed. Expected: {expected_value3}, Got: {result3}" ) + # Command 4: This SET command will receive MIGRATED push message before response + key_migrated = "key_receive_migrated" + value_migrated = "value3" + result4 = test_redis_client.set(key_migrated, value_migrated) + + # Validate Command 4 result + assert result4 is True, "Command 4 (SET key_receive_migrated) failed" + + # Validate timeout is still NOT updated after MIGRATED (relax is disabled) + self._validate_current_timeout(None) + + # Command 5: This SET command will receive FAILING_OVER push message before response + key_failing_over = "key_receive_failing_over" + value_failing_over = "value4" + result5 = test_redis_client.set(key_failing_over, value_failing_over) + + # Validate Command 5 result + assert result5 is True, "Command 5 (SET key_receive_failing_over) failed" + + # Validate timeout is still NOT updated after FAILING_OVER (relax is disabled) + self._validate_current_timeout(None) + + # Command 6: Another command to verify timeout remains unchanged during failover + result6 = test_redis_client.get(key_failing_over) + + # Validate Command 6 result + expected_value6 = value_failing_over.encode() + assert result6 == expected_value6, ( + f"Command 6 (GET key_receive_failing_over) failed. Expected: {expected_value6}, Got: {result6}" + ) + + # Command 7: This SET command will receive FAILED_OVER push message before response + key_failed_over = "key_receive_failed_over" + value_failed_over = "value5" + result7 = test_redis_client.set(key_failed_over, value_failed_over) + + # Validate Command 7 result + assert result7 is True, "Command 7 (SET key_receive_failed_over) failed" + + # Validate timeout is still NOT updated after FAILED_OVER (relax is disabled) + self._validate_current_timeout(None) + + # Command 8: Final command to verify timeout remains unchanged after all events + result8 = test_redis_client.get(key_failed_over) + + # Validate Command 8 result + expected_value8 = value_failed_over.encode() + assert result8 == expected_value8, ( + f"Command 8 (GET key_receive_failed_over) failed. Expected: {expected_value8}, Got: {result8}" + ) + # Verify maintenance events were processed correctly # The key is that we have at least 1 socket and all operations succeeded assert len(self.mock_sockets) >= 1, ( @@ -1357,7 +1431,7 @@ def worker(idx): def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): """ Test moving configs are not lost if the per connection events get picked up after moving is handled. - MOVING → MIGRATING → MIGRATED → MOVED + MOVING → MIGRATING → MIGRATED → FAILING_OVER → FAILED_OVER → MOVED Checks the state after each event for all connections and for new connections created during each state. """ # Setup @@ -1448,7 +1522,45 @@ def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): expected_current_peername=DEFAULT_ADDRESS.split(":")[0], ) - # 4. MOVED event (simulate timer expiry) + # 4. FAILING_OVER event (simulate direct connection handler call) + for conn in in_use_connections: + conn._maintenance_event_connection_handler.handle_event( + NodeFailingOverEvent(id=3, ttl=1) + ) + # State should not change for connections that are in MOVING state + self._validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.MOVING, + expected_host_address=tmp_address, + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_current_socket_timeout=self.config.relax_timeout, + expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], + ) + + # 5. FAILED_OVER event (simulate direct connection handler call) + for conn in in_use_connections: + conn._maintenance_event_connection_handler.handle_event( + NodeFailedOverEvent(id=3) + ) + # State should not change for connections that are in MOVING state + self._validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.MOVING, + expected_host_address=tmp_address, + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_current_socket_timeout=self.config.relax_timeout, + expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], + ) + + # 6. MOVED event (simulate timer expiry) pool_handler.handle_node_moved_event(moving_event) Helpers.validate_in_use_connections_state( in_use_connections, From b9afaf06839496d6d5c031366c40716b995c5a07 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Mon, 18 Aug 2025 11:50:44 +0300 Subject: [PATCH 24/28] Fixing tests after merging with feature branch --- tests/test_maintenance_events_handling.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_maintenance_events_handling.py b/tests/test_maintenance_events_handling.py index cecd99b000..ad4963a86f 100644 --- a/tests/test_maintenance_events_handling.py +++ b/tests/test_maintenance_events_handling.py @@ -1528,17 +1528,17 @@ def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): NodeFailingOverEvent(id=3, ttl=1) ) # State should not change for connections that are in MOVING state - self._validate_in_use_connections_state( + Helpers.validate_in_use_connections_state( in_use_connections, expected_state=MaintenanceState.MOVING, expected_host_address=tmp_address, expected_socket_timeout=self.config.relax_timeout, expected_socket_connect_timeout=self.config.relax_timeout, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, expected_current_socket_timeout=self.config.relax_timeout, - expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_current_peername=DEFAULT_ADDRESS.split(":")[0], ) # 5. FAILED_OVER event (simulate direct connection handler call) @@ -1547,17 +1547,17 @@ def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): NodeFailedOverEvent(id=3) ) # State should not change for connections that are in MOVING state - self._validate_in_use_connections_state( + Helpers.validate_in_use_connections_state( in_use_connections, expected_state=MaintenanceState.MOVING, expected_host_address=tmp_address, expected_socket_timeout=self.config.relax_timeout, expected_socket_connect_timeout=self.config.relax_timeout, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, expected_current_socket_timeout=self.config.relax_timeout, - expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_current_peername=DEFAULT_ADDRESS.split(":")[0], ) # 6. MOVED event (simulate timer expiry) From 058be2c1ead0b573f5c7b00e8f1e6ffa09023c0f Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Mon, 18 Aug 2025 11:56:20 +0300 Subject: [PATCH 25/28] Fixing lint errors. --- redis/maintenance_events.py | 2 -- tests/test_maintenance_events.py | 3 ++- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/redis/maintenance_events.py b/redis/maintenance_events.py index bf15a3c1c5..958948ee33 100644 --- a/redis/maintenance_events.py +++ b/redis/maintenance_events.py @@ -15,7 +15,6 @@ class MaintenanceState(enum.Enum): FAILING_OVER = "failing_over" - if TYPE_CHECKING: from redis.connection import ( BlockingConnectionPool, @@ -588,7 +587,6 @@ def handle_maintenance_start_event(self, maintenance_state: MaintenanceState): # extend the timeout for all created connections self.connection.update_current_socket_timeout(self.config.relax_timeout) - def handle_maintenance_completed_event(self): # Only reset timeouts if state is not MOVING and relax timeouts are enabled if ( diff --git a/tests/test_maintenance_events.py b/tests/test_maintenance_events.py index 2be284d89e..b28a9dc29f 100644 --- a/tests/test_maintenance_events.py +++ b/tests/test_maintenance_events.py @@ -283,6 +283,7 @@ def test_equality_and_hash(self): assert hash(event1) == hash(event2) assert hash(event1) != hash(event3) + class TestNodeFailingOverEvent: """Test the NodeFailingOverEvent class.""" @@ -360,6 +361,7 @@ def test_equality_and_hash(self): assert hash(event1) == hash(event2) assert hash(event1) != hash(event3) + class TestMaintenanceEventsConfig: """Test the MaintenanceEventsConfig class.""" @@ -562,7 +564,6 @@ def test_handle_event_migrating(self): self.handler.handle_event(event) mock_handle.assert_called_once_with(MaintenanceState.MIGRATING) - def test_handle_event_migrated(self): """Test handling of NodeMigratedEvent.""" event = NodeMigratedEvent(id=1) From e4a86460e40e73f2b60cab13d7be92a9d1d6104c Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Mon, 18 Aug 2025 11:57:53 +0300 Subject: [PATCH 26/28] Update tests/test_maintenance_events_handling.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- tests/test_maintenance_events_handling.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_maintenance_events_handling.py b/tests/test_maintenance_events_handling.py index ad4963a86f..341ed06ef7 100644 --- a/tests/test_maintenance_events_handling.py +++ b/tests/test_maintenance_events_handling.py @@ -1431,7 +1431,8 @@ def worker(idx): def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): """ Test moving configs are not lost if the per connection events get picked up after moving is handled. - MOVING → MIGRATING → MIGRATED → FAILING_OVER → FAILED_OVER → MOVED + Sequence of events: MOVING, MIGRATING, MIGRATED, FAILING_OVER, FAILED_OVER, MOVED. + Note: FAILING_OVER and FAILED_OVER events do not change the connection state when already in MOVING state. Checks the state after each event for all connections and for new connections created during each state. """ # Setup From 51d24ba98c24fca9e0ecd1e6dfc7717276916a60 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Mon, 18 Aug 2025 16:03:35 +0300 Subject: [PATCH 27/28] Applying review comments --- redis/maintenance_events.py | 29 ++++++++++++++-------- tests/test_maintenance_events.py | 30 ++++++++--------------- tests/test_maintenance_events_handling.py | 2 +- 3 files changed, 30 insertions(+), 31 deletions(-) diff --git a/redis/maintenance_events.py b/redis/maintenance_events.py index 958948ee33..951d64e717 100644 --- a/redis/maintenance_events.py +++ b/redis/maintenance_events.py @@ -11,7 +11,7 @@ class MaintenanceState(enum.Enum): NONE = "none" MOVING = "moving" - MIGRATING = "migrating" + MAINTENANCE = "maintenance" FAILING_OVER = "failing_over" @@ -557,6 +557,14 @@ def handle_node_moved_event(self, event: NodeMovingEvent): class MaintenanceEventConnectionHandler: + # 1 = "starting maintenance" events, 0 = "completed maintenance" events + _EVENT_TYPES: dict[type["MaintenanceEvent"], int] = { + NodeMigratingEvent: 1, + NodeFailingOverEvent: 1, + NodeMigratedEvent: 0, + NodeFailedOverEvent: 0, + } + def __init__( self, connection: "ConnectionInterface", config: MaintenanceEventsConfig ) -> None: @@ -564,16 +572,17 @@ def __init__( self.config = config def handle_event(self, event: MaintenanceEvent): - if isinstance(event, NodeMigratingEvent): - return self.handle_maintenance_start_event(MaintenanceState.MIGRATING) - elif isinstance(event, NodeMigratedEvent): - return self.handle_maintenance_completed_event() - elif isinstance(event, NodeFailingOverEvent): - return self.handle_maintenance_start_event(MaintenanceState.FAILING_OVER) - elif isinstance(event, NodeFailedOverEvent): - return self.handle_maintenance_completed_event() - else: + # get the event type by checking its class in the _EVENT_TYPES dict + event_type = self._EVENT_TYPES.get(event.__class__) + + if event_type is None: logging.error(f"Unhandled event type: {event}") + return + + if event_type: + self.handle_maintenance_start_event(MaintenanceState.MAINTENANCE) + else: + self.handle_maintenance_completed_event() def handle_maintenance_start_event(self, maintenance_state: MaintenanceState): if ( diff --git a/tests/test_maintenance_events.py b/tests/test_maintenance_events.py index b28a9dc29f..30169615cf 100644 --- a/tests/test_maintenance_events.py +++ b/tests/test_maintenance_events.py @@ -562,7 +562,7 @@ def test_handle_event_migrating(self): self.handler, "handle_maintenance_start_event" ) as mock_handle: self.handler.handle_event(event) - mock_handle.assert_called_once_with(MaintenanceState.MIGRATING) + mock_handle.assert_called_once_with(MaintenanceState.MAINTENANCE) def test_handle_event_migrated(self): """Test handling of NodeMigratedEvent.""" @@ -582,7 +582,7 @@ def test_handle_event_failing_over(self): self.handler, "handle_maintenance_start_event" ) as mock_handle: self.handler.handle_event(event) - mock_handle.assert_called_once_with(MaintenanceState.FAILING_OVER) + mock_handle.assert_called_once_with(MaintenanceState.MAINTENANCE) def test_handle_event_failed_over(self): """Test handling of NodeFailedOverEvent.""" @@ -608,7 +608,7 @@ def test_handle_maintenance_start_event_disabled(self): config = MaintenanceEventsConfig(relax_timeout=-1) handler = MaintenanceEventConnectionHandler(self.mock_connection, config) - result = handler.handle_maintenance_start_event(MaintenanceState.MIGRATING) + result = handler.handle_maintenance_start_event(MaintenanceState.MAINTENANCE) assert result is None self.mock_connection.update_current_socket_timeout.assert_not_called() @@ -616,29 +616,19 @@ def test_handle_maintenance_start_event_moving_state(self): """Test maintenance start event handling when connection is in MOVING state.""" self.mock_connection.maintenance_state = MaintenanceState.MOVING - result = self.handler.handle_maintenance_start_event(MaintenanceState.MIGRATING) + result = self.handler.handle_maintenance_start_event( + MaintenanceState.MAINTENANCE + ) assert result is None self.mock_connection.update_current_socket_timeout.assert_not_called() - def test_handle_maintenance_start_event_migrating_success(self): + def test_handle_maintenance_start_event_success(self): """Test successful maintenance start event handling for migrating.""" self.mock_connection.maintenance_state = MaintenanceState.NONE - self.handler.handle_maintenance_start_event(MaintenanceState.MIGRATING) - - assert self.mock_connection.maintenance_state == MaintenanceState.MIGRATING - self.mock_connection.update_current_socket_timeout.assert_called_once_with(20) - self.mock_connection.set_tmp_settings.assert_called_once_with( - tmp_relax_timeout=20 - ) - - def test_handle_maintenance_start_event_failing_over_success(self): - """Test successful maintenance start event handling for failing over.""" - self.mock_connection.maintenance_state = MaintenanceState.NONE - - self.handler.handle_maintenance_start_event(MaintenanceState.FAILING_OVER) + self.handler.handle_maintenance_start_event(MaintenanceState.MAINTENANCE) - assert self.mock_connection.maintenance_state == MaintenanceState.FAILING_OVER + assert self.mock_connection.maintenance_state == MaintenanceState.MAINTENANCE self.mock_connection.update_current_socket_timeout.assert_called_once_with(20) self.mock_connection.set_tmp_settings.assert_called_once_with( tmp_relax_timeout=20 @@ -663,7 +653,7 @@ def test_handle_maintenance_completed_event_moving_state(self): def test_handle_maintenance_completed_event_success(self): """Test successful maintenance completed event handling.""" - self.mock_connection.maintenance_state = MaintenanceState.MIGRATING + self.mock_connection.maintenance_state = MaintenanceState.MAINTENANCE self.handler.handle_maintenance_completed_event() diff --git a/tests/test_maintenance_events_handling.py b/tests/test_maintenance_events_handling.py index 341ed06ef7..ea0021c8a5 100644 --- a/tests/test_maintenance_events_handling.py +++ b/tests/test_maintenance_events_handling.py @@ -1808,7 +1808,7 @@ def test_migrating_after_moving_multiple_proxies(self, pool_class): conn_event_handler = conn._maintenance_event_connection_handler conn_event_handler.handle_event(NodeMigratingEvent(id=3, ttl=1)) # validate connection is in MIGRATING state - assert conn.maintenance_state == MaintenanceState.MIGRATING + assert conn.maintenance_state == MaintenanceState.MAINTENANCE assert conn.socket_timeout == self.config.relax_timeout # Send MIGRATED event to con with ip = key3 From 66c1fe0c54affaedfb20c8a3fc05df8744d13036 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Mon, 18 Aug 2025 16:23:58 +0300 Subject: [PATCH 28/28] Applying review comments --- redis/maintenance_events.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/redis/maintenance_events.py b/redis/maintenance_events.py index 951d64e717..09c767be63 100644 --- a/redis/maintenance_events.py +++ b/redis/maintenance_events.py @@ -12,7 +12,6 @@ class MaintenanceState(enum.Enum): NONE = "none" MOVING = "moving" MAINTENANCE = "maintenance" - FAILING_OVER = "failing_over" if TYPE_CHECKING: @@ -573,7 +572,7 @@ def __init__( def handle_event(self, event: MaintenanceEvent): # get the event type by checking its class in the _EVENT_TYPES dict - event_type = self._EVENT_TYPES.get(event.__class__) + event_type = self._EVENT_TYPES.get(event.__class__, None) if event_type is None: logging.error(f"Unhandled event type: {event}")