diff --git a/redis/maintenance_events.py b/redis/maintenance_events.py index 8c6c15c74a..09c767be63 100644 --- a/redis/maintenance_events.py +++ b/redis/maintenance_events.py @@ -11,7 +11,7 @@ class MaintenanceState(enum.Enum): NONE = "none" MOVING = "moving" - MIGRATING = "migrating" + MAINTENANCE = "maintenance" if TYPE_CHECKING: @@ -261,6 +261,105 @@ def __hash__(self) -> int: return hash((self.__class__, self.id)) +class NodeFailingOverEvent(MaintenanceEvent): + """ + Event for when a Redis cluster node is in the process of failing over. + + This event is received when a node starts a failover process during + cluster maintenance operations or when handling node failures. + + Args: + id (int): Unique identifier for this event + ttl (int): Time-to-live in seconds for this notification + """ + + def __init__(self, id: int, ttl: int): + super().__init__(id, ttl) + + def __repr__(self) -> str: + expiry_time = self.creation_time + self.ttl + remaining = max(0, expiry_time - time.monotonic()) + return ( + f"{self.__class__.__name__}(" + f"id={self.id}, " + f"ttl={self.ttl}, " + f"creation_time={self.creation_time}, " + f"expires_at={expiry_time}, " + f"remaining={remaining:.1f}s, " + f"expired={self.is_expired()}" + f")" + ) + + def __eq__(self, other) -> bool: + """ + Two NodeFailingOverEvent events are considered equal if they have the same + id and are of the same type. + """ + if not isinstance(other, NodeFailingOverEvent): + return False + return self.id == other.id and type(self) is type(other) + + def __hash__(self) -> int: + """ + Return a hash value for the event to allow + instances to be used in sets and as dictionary keys. + + Returns: + int: Hash value based on event type and id + """ + return hash((self.__class__, self.id)) + + +class NodeFailedOverEvent(MaintenanceEvent): + """ + Event for when a Redis cluster node has completed a failover. + + This event is received when a node has finished the failover process + during cluster maintenance operations or after handling node failures. + + Args: + id (int): Unique identifier for this event + """ + + DEFAULT_TTL = 5 + + def __init__(self, id: int): + super().__init__(id, NodeFailedOverEvent.DEFAULT_TTL) + + def __repr__(self) -> str: + expiry_time = self.creation_time + self.ttl + remaining = max(0, expiry_time - time.monotonic()) + return ( + f"{self.__class__.__name__}(" + f"id={self.id}, " + f"ttl={self.ttl}, " + f"creation_time={self.creation_time}, " + f"expires_at={expiry_time}, " + f"remaining={remaining:.1f}s, " + f"expired={self.is_expired()}" + f")" + ) + + def __eq__(self, other) -> bool: + """ + Two NodeFailedOverEvent events are considered equal if they have the same + id and are of the same type. + """ + if not isinstance(other, NodeFailedOverEvent): + return False + return self.id == other.id and type(self) is type(other) + + def __hash__(self) -> int: + """ + Return a hash value for the event to allow + instances to be used in sets and as dictionary keys. + + Returns: + int: Hash value based on event type and id + """ + return hash((self.__class__, self.id)) + + class MaintenanceEventsConfig: """ Configuration class for maintenance events handling behaviour. Events are received through @@ -457,6 +556,14 @@ def handle_node_moved_event(self, event: NodeMovingEvent): class MaintenanceEventConnectionHandler: + # 1 = "starting maintenance" events, 0 = "completed maintenance" events + _EVENT_TYPES: dict[type["MaintenanceEvent"], int] = { + NodeMigratingEvent: 1, + NodeFailingOverEvent: 1, + NodeMigratedEvent: 0, + NodeFailedOverEvent: 0, + } + def __init__( self, connection: "ConnectionInterface", config: MaintenanceEventsConfig ) -> None: @@ -464,25 +571,31 @@ def __init__( self.config = config def handle_event(self, event: MaintenanceEvent): - if isinstance(event, NodeMigratingEvent): - return self.handle_migrating_event(event) - elif isinstance(event, NodeMigratedEvent): - return self.handle_migration_completed_event(event) - else: + # get the event type by checking its class in the _EVENT_TYPES dict + event_type = self._EVENT_TYPES.get(event.__class__, None) + + if event_type is None: logging.error(f"Unhandled event type: {event}") + return - def handle_migrating_event(self, notification: NodeMigratingEvent): + if event_type: + self.handle_maintenance_start_event(MaintenanceState.MAINTENANCE) + else: + self.handle_maintenance_completed_event() + + def handle_maintenance_start_event(self, maintenance_state: MaintenanceState): if ( self.connection.maintenance_state == MaintenanceState.MOVING or not self.config.is_relax_timeouts_enabled() ): return - self.connection.maintenance_state = MaintenanceState.MIGRATING + + self.connection.maintenance_state = maintenance_state self.connection.set_tmp_settings(tmp_relax_timeout=self.config.relax_timeout) # extend the timeout for all created connections self.connection.update_current_socket_timeout(self.config.relax_timeout) - def handle_migration_completed_event(self, notification: "NodeMigratedEvent"): + def handle_maintenance_completed_event(self): # Only reset timeouts if state is not MOVING and relax timeouts are enabled if ( self.connection.maintenance_state == MaintenanceState.MOVING @@ -490,7 +603,7 @@ def handle_migration_completed_event(self, notification: "NodeMigratedEvent"): ): return self.connection.reset_tmp_settings(reset_relax_timeout=True) - # Node migration completed - reset the connection + # Maintenance completed - reset the connection # timeouts by providing -1 as the relax timeout self.connection.update_current_socket_timeout(-1) self.connection.maintenance_state = MaintenanceState.NONE diff --git a/tests/test_maintenance_events.py b/tests/test_maintenance_events.py index c90fa5db4f..30169615cf 100644 --- a/tests/test_maintenance_events.py +++ b/tests/test_maintenance_events.py @@ -7,9 +7,12 @@ NodeMovingEvent, NodeMigratingEvent, NodeMigratedEvent, + NodeFailingOverEvent, + NodeFailedOverEvent, MaintenanceEventsConfig, MaintenanceEventPoolHandler, MaintenanceEventConnectionHandler, + MaintenanceState, ) @@ -281,6 +284,84 @@ def test_equality_and_hash(self): assert hash(event1) != hash(event3) +class TestNodeFailingOverEvent: + """Test the NodeFailingOverEvent class.""" + + def test_init(self): + """Test NodeFailingOverEvent initialization.""" + with patch("time.monotonic", return_value=1000): + event = NodeFailingOverEvent(id=1, ttl=5) + assert event.id == 1 + assert event.ttl == 5 + assert event.creation_time == 1000 + + def test_repr(self): + """Test NodeFailingOverEvent string representation.""" + with patch("time.monotonic", return_value=1000): + event = NodeFailingOverEvent(id=1, ttl=5) + + with patch("time.monotonic", return_value=1002): # 2 seconds later + repr_str = repr(event) + assert "NodeFailingOverEvent" in repr_str + assert "id=1" in repr_str + assert "ttl=5" in repr_str + assert "remaining=3.0s" in repr_str + assert "expired=False" in repr_str + + def test_equality_and_hash(self): + """Test equality and hash for NodeFailingOverEvent.""" + event1 = NodeFailingOverEvent(id=1, ttl=5) + event2 = NodeFailingOverEvent(id=1, ttl=10) # Same id, different ttl + event3 = NodeFailingOverEvent(id=2, ttl=5) # Different id + + assert event1 == event2 + assert event1 != event3 + assert hash(event1) == hash(event2) + assert hash(event1) != hash(event3) + + +class TestNodeFailedOverEvent: + """Test the NodeFailedOverEvent class.""" + + def test_init(self): + """Test NodeFailedOverEvent initialization.""" + with patch("time.monotonic", return_value=1000): + event = NodeFailedOverEvent(id=1) + assert event.id == 1 + assert event.ttl == NodeFailedOverEvent.DEFAULT_TTL + assert event.creation_time == 1000 + + def test_default_ttl(self): + """Test that DEFAULT_TTL is used correctly.""" + assert NodeFailedOverEvent.DEFAULT_TTL == 5 + event = NodeFailedOverEvent(id=1) + assert event.ttl == 5 + + def test_repr(self): + """Test NodeFailedOverEvent string representation.""" + with patch("time.monotonic", return_value=1000): + event = NodeFailedOverEvent(id=1) + + with patch("time.monotonic", return_value=1001): # 1 second later + repr_str = repr(event) + assert "NodeFailedOverEvent" in repr_str + assert "id=1" in repr_str + assert "ttl=5" in repr_str + assert "remaining=4.0s" in repr_str + assert "expired=False" in repr_str + + def test_equality_and_hash(self): + """Test equality and hash for NodeFailedOverEvent.""" + event1 = NodeFailedOverEvent(id=1) + event2 = NodeFailedOverEvent(id=1) # Same id + event3 = NodeFailedOverEvent(id=2) # Different id + + assert event1 == event2 + assert event1 != event3 + assert hash(event1) == hash(event2) + assert hash(event1) != hash(event3) + + class TestMaintenanceEventsConfig: """Test the MaintenanceEventsConfig class.""" @@ -477,19 +558,41 @@ def test_handle_event_migrating(self): """Test handling of NodeMigratingEvent.""" event = NodeMigratingEvent(id=1, ttl=5) - with patch.object(self.handler, "handle_migrating_event") as mock_handle: + with patch.object( + self.handler, "handle_maintenance_start_event" + ) as mock_handle: self.handler.handle_event(event) - mock_handle.assert_called_once_with(event) + mock_handle.assert_called_once_with(MaintenanceState.MAINTENANCE) def test_handle_event_migrated(self): """Test handling of NodeMigratedEvent.""" event = NodeMigratedEvent(id=1) with patch.object( - self.handler, "handle_migration_completed_event" + self.handler, "handle_maintenance_completed_event" ) as mock_handle: self.handler.handle_event(event) - mock_handle.assert_called_once_with(event) + mock_handle.assert_called_once_with() + + def test_handle_event_failing_over(self): + """Test handling of NodeFailingOverEvent.""" + event = NodeFailingOverEvent(id=1, ttl=5) + + with patch.object( + self.handler, "handle_maintenance_start_event" + ) as mock_handle: + self.handler.handle_event(event) + mock_handle.assert_called_once_with(MaintenanceState.MAINTENANCE) + + def test_handle_event_failed_over(self): + """Test handling of NodeFailedOverEvent.""" + event = NodeFailedOverEvent(id=1) + + with patch.object( + self.handler, "handle_maintenance_completed_event" + ) as mock_handle: + self.handler.handle_event(event) + mock_handle.assert_called_once_with() def test_handle_event_unknown_type(self): """Test handling of unknown event type.""" @@ -500,42 +603,61 @@ def test_handle_event_unknown_type(self): result = self.handler.handle_event(event) assert result is None - def test_handle_migrating_event_disabled(self): - """Test migrating event handling when relax timeouts are disabled.""" + def test_handle_maintenance_start_event_disabled(self): + """Test maintenance start event handling when relax timeouts are disabled.""" config = MaintenanceEventsConfig(relax_timeout=-1) handler = MaintenanceEventConnectionHandler(self.mock_connection, config) - event = NodeMigratingEvent(id=1, ttl=5) - result = handler.handle_migrating_event(event) + result = handler.handle_maintenance_start_event(MaintenanceState.MAINTENANCE) assert result is None self.mock_connection.update_current_socket_timeout.assert_not_called() - def test_handle_migrating_event_success(self): - """Test successful migrating event handling.""" - event = NodeMigratingEvent(id=1, ttl=5) + def test_handle_maintenance_start_event_moving_state(self): + """Test maintenance start event handling when connection is in MOVING state.""" + self.mock_connection.maintenance_state = MaintenanceState.MOVING + + result = self.handler.handle_maintenance_start_event( + MaintenanceState.MAINTENANCE + ) + assert result is None + self.mock_connection.update_current_socket_timeout.assert_not_called() + + def test_handle_maintenance_start_event_success(self): + """Test successful maintenance start event handling for migrating.""" + self.mock_connection.maintenance_state = MaintenanceState.NONE - self.handler.handle_migrating_event(event) + self.handler.handle_maintenance_start_event(MaintenanceState.MAINTENANCE) + assert self.mock_connection.maintenance_state == MaintenanceState.MAINTENANCE self.mock_connection.update_current_socket_timeout.assert_called_once_with(20) self.mock_connection.set_tmp_settings.assert_called_once_with( tmp_relax_timeout=20 ) - def test_handle_migration_completed_event_disabled(self): - """Test migration completed event handling when relax timeouts are disabled.""" + def test_handle_maintenance_completed_event_disabled(self): + """Test maintenance completed event handling when relax timeouts are disabled.""" config = MaintenanceEventsConfig(relax_timeout=-1) handler = MaintenanceEventConnectionHandler(self.mock_connection, config) - event = NodeMigratedEvent(id=1) - result = handler.handle_migration_completed_event(event) + result = handler.handle_maintenance_completed_event() assert result is None self.mock_connection.update_current_socket_timeout.assert_not_called() - def test_handle_migration_completed_event_success(self): - """Test successful migration completed event handling.""" - event = NodeMigratedEvent(id=1) + def test_handle_maintenance_completed_event_moving_state(self): + """Test maintenance completed event handling when connection is in MOVING state.""" + self.mock_connection.maintenance_state = MaintenanceState.MOVING + + result = self.handler.handle_maintenance_completed_event() + assert result is None + self.mock_connection.update_current_socket_timeout.assert_not_called() + + def test_handle_maintenance_completed_event_success(self): + """Test successful maintenance completed event handling.""" + self.mock_connection.maintenance_state = MaintenanceState.MAINTENANCE + + self.handler.handle_maintenance_completed_event() - self.handler.handle_migration_completed_event(event) + assert self.mock_connection.maintenance_state == MaintenanceState.NONE self.mock_connection.update_current_socket_timeout.assert_called_once_with(-1) self.mock_connection.reset_tmp_settings.assert_called_once_with( diff --git a/tests/test_maintenance_events_handling.py b/tests/test_maintenance_events_handling.py index 8db8d182a7..ea0021c8a5 100644 --- a/tests/test_maintenance_events_handling.py +++ b/tests/test_maintenance_events_handling.py @@ -16,9 +16,11 @@ from redis.maintenance_events import ( MaintenanceEventsConfig, NodeMigratingEvent, + NodeMigratedEvent, + NodeFailingOverEvent, + NodeFailedOverEvent, MaintenanceEventPoolHandler, NodeMovingEvent, - NodeMigratedEvent, ) @@ -189,6 +191,22 @@ def send(self, data): # Format: >1\r\n$8\r\nMIGRATED\r\n (1 element: MIGRATED) migrated_push = ">1\r\n$8\r\nMIGRATED\r\n" response = migrated_push.encode() + response + elif ( + b"key_receive_failing_over_" in data + or b"key_receive_failing_over" in data + ): + # FAILING_OVER push message before SET key_receive_failing_over_X response + # Format: >2\r\n$12\r\nFAILING_OVER\r\n:10\r\n (2 elements: FAILING_OVER, ttl) + failing_over_push = ">2\r\n$12\r\nFAILING_OVER\r\n:10\r\n" + response = failing_over_push.encode() + response + elif ( + b"key_receive_failed_over_" in data + or b"key_receive_failed_over" in data + ): + # FAILED_OVER push message before SET key_receive_failed_over_X response + # Format: >1\r\n$11\r\nFAILED_OVER\r\n (1 element: FAILED_OVER) + failed_over_push = ">1\r\n$11\r\nFAILED_OVER\r\n" + response = failed_over_push.encode() + response elif b"key_receive_moving_" in data: # MOVING push message before SET key_receive_moving_X response # Format: >3\r\n$6\r\nMOVING\r\n:15\r\n+localhost:6379\r\n (3 elements: MOVING, ttl, host:port) @@ -211,6 +229,10 @@ def send(self, data): self.pending_responses.append(b"$6\r\nvalue2\r\n") elif b"key_receive_migrated" in data: self.pending_responses.append(b"$6\r\nvalue3\r\n") + elif b"key_receive_failing_over" in data: + self.pending_responses.append(b"$6\r\nvalue4\r\n") + elif b"key_receive_failed_over" in data: + self.pending_responses.append(b"$6\r\nvalue5\r\n") elif b"key1" in data: self.pending_responses.append(b"$6\r\nvalue1\r\n") else: @@ -727,13 +749,14 @@ def test_migration_related_events_handling_integration(self, pool_class): @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) def test_migrating_event_with_disabled_relax_timeout(self, pool_class): """ - Test migrating event handling when relax timeout is disabled. + Test maintenance events handling when relax timeout is disabled. This test validates that when relax_timeout is disabled (-1): - 1. MIGRATING events are received and processed + 1. MIGRATING, MIGRATED, FAILING_OVER, and FAILED_OVER events are received and processed 2. No timeout updates are applied to connections - 3. Socket timeouts remain unchanged during migration events + 3. Socket timeouts remain unchanged during all maintenance events 4. Tests both ConnectionPool and BlockingConnectionPool implementations + 5. Tests the complete lifecycle: MIGRATING -> MIGRATED -> FAILING_OVER -> FAILED_OVER """ # Create config with disabled relax timeout disabled_config = MaintenanceEventsConfig( @@ -776,6 +799,57 @@ def test_migrating_event_with_disabled_relax_timeout(self, pool_class): f"Command 3 (GET key1) failed. Expected: {expected_value3}, Got: {result3}" ) + # Command 4: This SET command will receive MIGRATED push message before response + key_migrated = "key_receive_migrated" + value_migrated = "value3" + result4 = test_redis_client.set(key_migrated, value_migrated) + + # Validate Command 4 result + assert result4 is True, "Command 4 (SET key_receive_migrated) failed" + + # Validate timeout is still NOT updated after MIGRATED (relax is disabled) + self._validate_current_timeout(None) + + # Command 5: This SET command will receive FAILING_OVER push message before response + key_failing_over = "key_receive_failing_over" + value_failing_over = "value4" + result5 = test_redis_client.set(key_failing_over, value_failing_over) + + # Validate Command 5 result + assert result5 is True, "Command 5 (SET key_receive_failing_over) failed" + + # Validate timeout is still NOT updated after FAILING_OVER (relax is disabled) + self._validate_current_timeout(None) + + # Command 6: Another command to verify timeout remains unchanged during failover + result6 = test_redis_client.get(key_failing_over) + + # Validate Command 6 result + expected_value6 = value_failing_over.encode() + assert result6 == expected_value6, ( + f"Command 6 (GET key_receive_failing_over) failed. Expected: {expected_value6}, Got: {result6}" + ) + + # Command 7: This SET command will receive FAILED_OVER push message before response + key_failed_over = "key_receive_failed_over" + value_failed_over = "value5" + result7 = test_redis_client.set(key_failed_over, value_failed_over) + + # Validate Command 7 result + assert result7 is True, "Command 7 (SET key_receive_failed_over) failed" + + # Validate timeout is still NOT updated after FAILED_OVER (relax is disabled) + self._validate_current_timeout(None) + + # Command 8: Final command to verify timeout remains unchanged after all events + result8 = test_redis_client.get(key_failed_over) + + # Validate Command 8 result + expected_value8 = value_failed_over.encode() + assert result8 == expected_value8, ( + f"Command 8 (GET key_receive_failed_over) failed. Expected: {expected_value8}, Got: {result8}" + ) + # Verify maintenance events were processed correctly # The key is that we have at least 1 socket and all operations succeeded assert len(self.mock_sockets) >= 1, ( @@ -1357,7 +1431,8 @@ def worker(idx): def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): """ Test moving configs are not lost if the per connection events get picked up after moving is handled. - MOVING → MIGRATING → MIGRATED → MOVED + Sequence of events: MOVING, MIGRATING, MIGRATED, FAILING_OVER, FAILED_OVER, MOVED. + Note: FAILING_OVER and FAILED_OVER events do not change the connection state when already in MOVING state. Checks the state after each event for all connections and for new connections created during each state. """ # Setup @@ -1448,7 +1523,45 @@ def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): expected_current_peername=DEFAULT_ADDRESS.split(":")[0], ) - # 4. MOVED event (simulate timer expiry) + # 4. FAILING_OVER event (simulate direct connection handler call) + for conn in in_use_connections: + conn._maintenance_event_connection_handler.handle_event( + NodeFailingOverEvent(id=3, ttl=1) + ) + # State should not change for connections that are in MOVING state + Helpers.validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.MOVING, + expected_host_address=tmp_address, + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_current_socket_timeout=self.config.relax_timeout, + expected_current_peername=DEFAULT_ADDRESS.split(":")[0], + ) + + # 5. FAILED_OVER event (simulate direct connection handler call) + for conn in in_use_connections: + conn._maintenance_event_connection_handler.handle_event( + NodeFailedOverEvent(id=3) + ) + # State should not change for connections that are in MOVING state + Helpers.validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.MOVING, + expected_host_address=tmp_address, + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_current_socket_timeout=self.config.relax_timeout, + expected_current_peername=DEFAULT_ADDRESS.split(":")[0], + ) + + # 6. MOVED event (simulate timer expiry) pool_handler.handle_node_moved_event(moving_event) Helpers.validate_in_use_connections_state( in_use_connections, @@ -1695,7 +1808,7 @@ def test_migrating_after_moving_multiple_proxies(self, pool_class): conn_event_handler = conn._maintenance_event_connection_handler conn_event_handler.handle_event(NodeMigratingEvent(id=3, ttl=1)) # validate connection is in MIGRATING state - assert conn.maintenance_state == MaintenanceState.MIGRATING + assert conn.maintenance_state == MaintenanceState.MAINTENANCE assert conn.socket_timeout == self.config.relax_timeout # Send MIGRATED event to con with ip = key3