Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race conditions, support Python 3.11 #295

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 25 additions & 8 deletions async_timeout/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __exit__(
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> Optional[bool]:
self._do_exit(exc_type)
self._do_exit(exc_type, exc_val)
return None

async def __aenter__(self) -> "Timeout":
Expand All @@ -126,7 +126,7 @@ async def __aexit__(
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> Optional[bool]:
self._do_exit(exc_type)
self._do_exit(exc_type, exc_val)
return None

@property
Expand Down Expand Up @@ -206,17 +206,34 @@ def _do_enter(self) -> None:
self._state = _State.ENTER
self._reschedule()

def _do_exit(self, exc_type: Optional[Type[BaseException]]) -> None:
def _do_exit(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
) -> None:
if exc_type is asyncio.CancelledError and self._state == _State.TIMEOUT:
self._timeout_handler = None
raise asyncio.TimeoutError
# timeout has not expired
self._state = _State.EXIT
skip = False
if sys.version_info >= (3, 9):
# Analyse msg
assert exc_val is not None
if not exc_val.args or exc_val.args[0] != id(self):
skip = True
if not skip:
if sys.version_info >= (3, 11):
asyncio.current_task().uncancel()
raise asyncio.TimeoutError
Comment on lines +215 to +224
Copy link
Member

@Dreamsorcerer Dreamsorcerer Feb 22, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Following what Guido was saying, I think this needs to check to check the cancelled count, maybe something like:

Suggested change
skip = False
if sys.version_info >= (3, 9):
# Analyse msg
assert exc_val is not None
if not exc_val.args or exc_val.args[0] != id(self):
skip = True
if not skip:
if sys.version_info >= (3, 11):
asyncio.current_task().uncancel()
raise asyncio.TimeoutError
self._timeout_handler = None
if sys.version_info >= (3, 11):
if asyncio.current_task().uncancel() == 0:
self._timeout_handler = None
raise asyncio.TimeoutError()
else:
raise asyncio.TimeoutError()

I think this is all that's needed, as the == _State.TIMEOUT tells us that we initiated a cancel, so there's no need for the cancel message anymore. Would be good to test this out against those previous tests before wrapping up the cpython discussion.

# state is EXIT if not timed out previously
if self._state != _State.TIMEOUT:
self._state = _State.EXIT
self._reject()
return None

def _on_timeout(self, task: "asyncio.Task[None]") -> None:
task.cancel()
# Note: the second '.cancel()' call is ignored on Python 3.11
if sys.version_info >= (3, 9):
task.cancel(id(self))
else:
task.cancel()
self._state = _State.TIMEOUT
# drop the reference early
self._timeout_handler = None
37 changes: 37 additions & 0 deletions tests/test_timeout.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,3 +361,40 @@ async def test_deprecated_with() -> None:
with pytest.warns(DeprecationWarning):
with timeout(1):
await asyncio.sleep(0)


@pytest.mark.asyncio
async def test_double_timeouts() -> None:
with pytest.raises(asyncio.TimeoutError):
async with timeout(0.1) as cm1:
async with timeout(0.1) as cm2:
await asyncio.sleep(10)

assert cm1.expired
assert cm2.expired


@pytest.mark.asyncio
async def test_timeout_with_cancelled_task() -> None:

event = asyncio.Event()

async def coro() -> None:
event.set()
async with timeout_cm:
await asyncio.sleep(5)

async def main() -> str:
task = asyncio.create_task(coro())
await event.wait()
loop = asyncio.get_running_loop()
timeout_cm.update(loop.time()) # reschedule to the next loop iteration
task.cancel()
with pytest.raises(asyncio.CancelledError):
await task
return "ok"

timeout_cm = timeout(3600) # reschedule just before the usage
task2 = asyncio.create_task(main())
assert "ok" == await task2
assert timeout_cm.expired
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
assert timeout_cm.expired
assert timeout_cm.expired
async def race_condition(offset: float = 0) -> List[str]:
"""Common code for below race condition tests."""
async def test_task(deadline: float, loop: asyncio.AbstractEventLoop) -> None:
# We need the internal Timeout class to specify the deadline (not delay).
# This is needed to create the precise timing to reproduce the race condition.
with pytest.warns(DeprecationWarning):
with Timeout(deadline, loop):
await asyncio.sleep(10)
call_order: List[str] = []
f_exit = log_func(Timeout._do_exit, "exit", call_order)
Timeout._do_exit = f_exit # type: ignore[assignment]
f_timeout = log_func(Timeout._on_timeout, "timeout", call_order)
Timeout._on_timeout = f_timeout # type: ignore[assignment]
loop = asyncio.get_running_loop()
deadline = loop.time() + 1
t = asyncio.create_task(test_task(deadline, loop))
loop.call_at(deadline + offset, log_func(t.cancel, "cancel", call_order))
# If we get a TimeoutError, then the code is broken.
with pytest.raises(asyncio.CancelledError):
await t
return call_order
@pytest.mark.skipif(sys.version_info < (3, 11), reason="Only fixed in 3.11+")
@pytest.mark.asyncio
async def test_race_condition_cancel_before() -> None:
"""Test race condition when cancelling before timeout.
If cancel happens immediately before the timeout, then
the timeout may overrule the cancellation, making it
impossible to cancel some tasks.
"""
call_order = await race_condition()
# This test is very timing dependant, so we check the order that calls
# happened to be sure the test itself ran correctly.
assert call_order == ["cancel", "timeout", "exit"]
@pytest.mark.skipif(sys.version_info < (3, 11), reason="Only fixed 3.11+.")
@pytest.mark.asyncio
async def test_race_condition_cancel_after() -> None:
"""Test race condition when cancelling after timeout.
Similarly to the previous test, if a cancel happens
immediately after the timeout (but before the __exit__),
then the explicit cancel can get overruled again.
"""
call_order = await race_condition(0.000001)
# This test is very timing dependant, so we check the order that calls
# happened to be sure the test itself ran correctly.
assert call_order == ["timeout", "cancel", "exit"]