Skip to content

Commit

Permalink
AsyncIO Race Condition Fix (#2639)
Browse files Browse the repository at this point in the history
  • Loading branch information
chayim authored Mar 22, 2023
1 parent 54a1dce commit 7b48b1b
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 8 deletions.
6 changes: 4 additions & 2 deletions .github/workflows/integration.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@ jobs:
invoke linters
run-tests:
runs-on: ubuntu-latest
runs-on: ubuntu-20.04
timeout-minutes: 30
strategy:
max-parallel: 15
fail-fast: false
matrix:
python-version: ['3.6', '3.7', '3.8', '3.9', '3.10', 'pypy-3.7']
test-type: ['standalone', 'cluster']
Expand Down Expand Up @@ -79,8 +80,9 @@ jobs:
install_package_from_commit:
name: Install package from commit hash
runs-on: ubuntu-latest
runs-on: ubuntu-20.04
strategy:
fail-fast: false
matrix:
python-version: ['3.6', '3.7', '3.8', '3.9', '3.10', 'pypy-3.7']
steps:
Expand Down
12 changes: 9 additions & 3 deletions redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1349,10 +1349,16 @@ async def execute(self, raise_on_error: bool = True):
conn = cast(Connection, conn)

try:
return await conn.retry.call_with_retry(
lambda: execute(conn, stack, raise_on_error),
lambda error: self._disconnect_raise_reset(conn, error),
return await asyncio.shield(
conn.retry.call_with_retry(
lambda: execute(conn, stack, raise_on_error),
lambda error: self._disconnect_raise_reset(conn, error),
)
)
except asyncio.CancelledError:
# not supposed to be possible, yet here we are
await conn.disconnect(nowait=True)
raise
finally:
await self.reset()

Expand Down
12 changes: 10 additions & 2 deletions redis/asyncio/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,10 +879,18 @@ async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
await connection.send_packed_command(connection.pack_command(*args), False)

# Read response
return await asyncio.shield(
self._parse_and_release(connection, args[0], **kwargs)
)

async def _parse_and_release(self, connection, *args, **kwargs):
try:
return await self.parse_response(connection, args[0], **kwargs)
return await self.parse_response(connection, *args, **kwargs)
except asyncio.CancelledError:
# should not be possible
await connection.disconnect(nowait=True)
raise
finally:
# Release connection
self._free.append(connection)

async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
long_description_content_type="text/markdown",
keywords=["Redis", "key-value store", "database"],
license="MIT",
version="4.3.5",
version="4.3.6",
packages=find_packages(
include=[
"redis",
Expand Down
17 changes: 17 additions & 0 deletions tests/test_asyncio/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,23 @@ async def test_execute_command_node_flag_random(self, r: RedisCluster) -> None:
called_count += 1
assert called_count == 1

async def test_asynckills(self, r) -> None:

await r.set("foo", "foo")
await r.set("bar", "bar")

t = asyncio.create_task(r.get("foo"))
await asyncio.sleep(1)
t.cancel()
try:
await t
except asyncio.CancelledError:
pytest.fail("connection is left open with unread response")

assert await r.get("bar") == b"bar"
assert await r.ping()
assert await r.get("foo") == b"foo"

async def test_execute_command_default_node(self, r: RedisCluster) -> None:
"""
Test command execution without node flag is being executed on the
Expand Down
23 changes: 23 additions & 0 deletions tests/test_asyncio/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,29 @@ async def test_invalid_response(create_redis):
assert str(cm.value) == f"Protocol Error: {raw!r}"


@pytest.mark.onlynoncluster
async def test_asynckills():
from redis.asyncio.client import Redis

for b in [True, False]:
r = Redis(single_connection_client=b)

await r.set("foo", "foo")
await r.set("bar", "bar")

t = asyncio.create_task(r.get("foo"))
await asyncio.sleep(1)
t.cancel()
try:
await t
except asyncio.CancelledError:
pytest.fail("connection left open with unread response")

assert await r.get("bar") == b"bar"
assert await r.ping()
assert await r.get("foo") == b"foo"


@skip_if_server_version_lt("4.0.0")
@pytest.mark.redismod
@pytest.mark.onlynoncluster
Expand Down

0 comments on commit 7b48b1b

Please sign in to comment.