Skip to content

Commit

Permalink
Fix behaviour of async PythonParser to match RedisParser as for issue #…
Browse files Browse the repository at this point in the history
…2349 (#2582)

* Allow data to drain from PythonParser after connection close.

* Add Changes
  • Loading branch information
kristjanvalur authored Mar 16, 2023
1 parent 7d474f9 commit 1b2f408
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 15 deletions.
1 change: 1 addition & 0 deletions CHANGES
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
* Allow data to drain from async PythonParser when reading during a disconnect()
* Use asyncio.timeout() instead of async_timeout.timeout() for python >= 3.11 (#2602)
* Add test and fix async HiredisParser when reading during a disconnect() (#2349)
* Use hiredis-py pack_command if available.
Expand Down
24 changes: 11 additions & 13 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def decode(self, value: EncodableT, force=False) -> EncodableT:
class BaseParser:
"""Plain Python parsing class"""

__slots__ = "_stream", "_read_size"
__slots__ = "_stream", "_read_size", "_connected"

EXCEPTION_CLASSES: ExceptionMappingT = {
"ERR": {
Expand Down Expand Up @@ -177,6 +177,7 @@ class BaseParser:
def __init__(self, socket_read_size: int):
self._stream: Optional[asyncio.StreamReader] = None
self._read_size = socket_read_size
self._connected = False

def __del__(self):
try:
Expand Down Expand Up @@ -213,7 +214,7 @@ async def read_response(
class PythonParser(BaseParser):
"""Plain Python parsing class"""

__slots__ = BaseParser.__slots__ + ("encoder", "_buffer", "_pos", "_chunks")
__slots__ = ("encoder", "_buffer", "_pos", "_chunks")

def __init__(self, socket_read_size: int):
super().__init__(socket_read_size)
Expand All @@ -231,28 +232,28 @@ def on_connect(self, connection: "Connection"):
self._stream = connection._reader
if self._stream is None:
raise RedisError("Buffer is closed.")

self.encoder = connection.encoder
self._clear()
self._connected = True

def on_disconnect(self):
"""Called when the stream disconnects"""
if self._stream is not None:
self._stream = None
self.encoder = None
self._clear()
self._connected = False

async def can_read_destructive(self) -> bool:
if not self._connected:
raise RedisError("Buffer is closed.")
if self._buffer:
return True
if self._stream is None:
raise RedisError("Buffer is closed.")
try:
async with async_timeout(0):
return await self._stream.read(1)
except asyncio.TimeoutError:
return False

async def read_response(self, disable_decoding: bool = False):
if not self._connected:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
if self._chunks:
# augment parsing buffer with previously read data
self._buffer += b"".join(self._chunks)
Expand All @@ -266,8 +267,6 @@ async def read_response(self, disable_decoding: bool = False):
async def _read_response(
self, disable_decoding: bool = False
) -> Union[EncodableT, ResponseError, None]:
if not self._stream or not self.encoder:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
raw = await self._readline()
response: Any
byte, response = raw[:1], raw[1:]
Expand Down Expand Up @@ -354,14 +353,13 @@ async def _readline(self) -> bytes:
class HiredisParser(BaseParser):
"""Parser class for connections using Hiredis"""

__slots__ = BaseParser.__slots__ + ("_reader", "_connected")
__slots__ = ("_reader",)

def __init__(self, socket_read_size: int):
if not HIREDIS_AVAILABLE:
raise RedisError("Hiredis is not available.")
super().__init__(socket_read_size=socket_read_size)
self._reader: Optional[hiredis.Reader] = None
self._connected: bool = False

def on_connect(self, connection: "Connection"):
self._stream = connection._reader
Expand Down
2 changes: 0 additions & 2 deletions tests/test_asyncio/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,6 @@ async def test_connection_disconect_race(parser_class):
This test verifies that a read in progress can finish even
if the `disconnect()` method is called.
"""
if parser_class == PythonParser:
pytest.xfail("doesn't work yet with PythonParser")
if parser_class == HiredisParser and not HIREDIS_AVAILABLE:
pytest.skip("Hiredis not available")

Expand Down

0 comments on commit 1b2f408

Please sign in to comment.