Skip to content

Commit cb94533

Browse files
Ensure writer is always reset on completion (#7815) (#7826)
(cherry picked from commit 8f2f048)
1 parent c0f9017 commit cb94533

File tree

5 files changed

+79
-38
lines changed

5 files changed

+79
-38
lines changed

Diff for: CHANGES/7815.bugfix

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fixed an issue where the client could go into an infinite loop. -- by :user:`Dreamsorcerer`

Diff for: aiohttp/client_reqrep.py

+49-25
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,13 @@
5353
reify,
5454
set_result,
5555
)
56-
from .http import SERVER_SOFTWARE, HttpVersion10, HttpVersion11, StreamWriter
56+
from .http import (
57+
SERVER_SOFTWARE,
58+
HttpVersion,
59+
HttpVersion10,
60+
HttpVersion11,
61+
StreamWriter,
62+
)
5763
from .log import client_logger
5864
from .streams import StreamReader
5965
from .typedefs import (
@@ -241,7 +247,7 @@ class ClientRequest:
241247
auth = None
242248
response = None
243249

244-
_writer = None # async task for streaming data
250+
__writer = None # async task for streaming data
245251
_continue = None # waiter future for '100 Continue' response
246252

247253
# N.B.
@@ -332,6 +338,21 @@ def __init__(
332338
traces = []
333339
self._traces = traces
334340

341+
def __reset_writer(self, _: object = None) -> None:
342+
self.__writer = None
343+
344+
@property
345+
def _writer(self) -> Optional["asyncio.Task[None]"]:
346+
return self.__writer
347+
348+
@_writer.setter
349+
def _writer(self, writer: Optional["asyncio.Task[None]"]) -> None:
350+
if self.__writer is not None:
351+
self.__writer.remove_done_callback(self.__reset_writer)
352+
self.__writer = writer
353+
if writer is not None:
354+
writer.add_done_callback(self.__reset_writer)
355+
335356
def is_ssl(self) -> bool:
336357
return self.url.scheme in ("https", "wss")
337358

@@ -625,8 +646,6 @@ async def write_bytes(
625646
else:
626647
await writer.write_eof()
627648
protocol.start_timeout()
628-
finally:
629-
self._writer = None
630649

631650
async def send(self, conn: "Connection") -> "ClientResponse":
632651
# Specify request target:
@@ -711,16 +730,14 @@ async def send(self, conn: "Connection") -> "ClientResponse":
711730

712731
async def close(self) -> None:
713732
if self._writer is not None:
714-
try:
715-
with contextlib.suppress(asyncio.CancelledError):
716-
await self._writer
717-
finally:
718-
self._writer = None
733+
with contextlib.suppress(asyncio.CancelledError):
734+
await self._writer
719735

720736
def terminate(self) -> None:
721737
if self._writer is not None:
722738
if not self.loop.is_closed():
723739
self._writer.cancel()
740+
self._writer.remove_done_callback(self.__reset_writer)
724741
self._writer = None
725742

726743
async def _on_chunk_request_sent(self, method: str, url: URL, chunk: bytes) -> None:
@@ -740,9 +757,9 @@ class ClientResponse(HeadersMixin):
740757
# but will be set by the start() method.
741758
# As the end user will likely never see the None values, we cheat the types below.
742759
# from the Status-Line of the response
743-
version = None # HTTP-Version
744-
status: int = None # type: ignore[assignment] # Status-Code
745-
reason = None # Reason-Phrase
760+
version: Optional[HttpVersion] = None # HTTP-Version
761+
status: int = None # type: ignore[assignment] # Status-Code
762+
reason: Optional[str] = None # Reason-Phrase
746763

747764
content: StreamReader = None # type: ignore[assignment] # Payload stream
748765
_headers: CIMultiDictProxy[str] = None # type: ignore[assignment]
@@ -754,6 +771,7 @@ class ClientResponse(HeadersMixin):
754771
# post-init stage allows to not change ctor signature
755772
_closed = True # to allow __del__ for non-initialized properly response
756773
_released = False
774+
__writer = None
757775

758776
def __init__(
759777
self,
@@ -799,6 +817,21 @@ def __init__(
799817
if loop.get_debug():
800818
self._source_traceback = traceback.extract_stack(sys._getframe(1))
801819

820+
def __reset_writer(self, _: object = None) -> None:
821+
self.__writer = None
822+
823+
@property
824+
def _writer(self) -> Optional["asyncio.Task[None]"]:
825+
return self.__writer
826+
827+
@_writer.setter
828+
def _writer(self, writer: Optional["asyncio.Task[None]"]) -> None:
829+
if self.__writer is not None:
830+
self.__writer.remove_done_callback(self.__reset_writer)
831+
self.__writer = writer
832+
if writer is not None:
833+
writer.add_done_callback(self.__reset_writer)
834+
802835
@reify
803836
def url(self) -> URL:
804837
return self._url
@@ -863,7 +896,7 @@ def __repr__(self) -> str:
863896
"ascii", "backslashreplace"
864897
).decode("ascii")
865898
else:
866-
ascii_encodable_reason = self.reason
899+
ascii_encodable_reason = "None"
867900
print(
868901
"<ClientResponse({}) [{} {}]>".format(
869902
ascii_encodable_url, self.status, ascii_encodable_reason
@@ -1044,18 +1077,12 @@ def _release_connection(self) -> None:
10441077

10451078
async def _wait_released(self) -> None:
10461079
if self._writer is not None:
1047-
try:
1048-
await self._writer
1049-
finally:
1050-
self._writer = None
1080+
await self._writer
10511081
self._release_connection()
10521082

10531083
def _cleanup_writer(self) -> None:
10541084
if self._writer is not None:
1055-
if self._writer.done():
1056-
self._writer = None
1057-
else:
1058-
self._writer.cancel()
1085+
self._writer.cancel()
10591086
self._session = None
10601087

10611088
def _notify_content(self) -> None:
@@ -1066,10 +1093,7 @@ def _notify_content(self) -> None:
10661093

10671094
async def wait_for_close(self) -> None:
10681095
if self._writer is not None:
1069-
try:
1070-
await self._writer
1071-
finally:
1072-
self._writer = None
1096+
await self._writer
10731097
self.release()
10741098

10751099
async def read(self) -> bytes:

Diff for: tests/test_client_request.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import urllib.parse
66
import zlib
77
from http.cookies import BaseCookie, Morsel, SimpleCookie
8-
from typing import Any, Dict, Optional
8+
from typing import Any, Callable, Dict, Optional
99
from unittest import mock
1010

1111
import pytest
@@ -24,6 +24,17 @@
2424
from aiohttp.test_utils import make_mocked_coro
2525

2626

27+
class WriterMock(mock.AsyncMock):
28+
def __await__(self) -> None:
29+
return self().__await__()
30+
31+
def add_done_callback(self, cb: Callable[[], None]) -> None:
32+
"""Dummy method."""
33+
34+
def remove_done_callback(self, cb: Callable[[], None]) -> None:
35+
"""Dummy method."""
36+
37+
2738
@pytest.fixture
2839
def make_request(loop):
2940
request = None
@@ -1167,7 +1178,7 @@ def read(self, decode=False):
11671178
async def test_oserror_on_write_bytes(loop, conn) -> None:
11681179
req = ClientRequest("POST", URL("http://python.org/"), loop=loop)
11691180

1170-
writer = mock.Mock()
1181+
writer = WriterMock()
11711182
writer.write.side_effect = OSError
11721183

11731184
await req.write_bytes(writer, conn)
@@ -1183,7 +1194,8 @@ async def test_terminate(loop, conn) -> None:
11831194
req = ClientRequest("get", URL("http://python.org"), loop=loop)
11841195
resp = await req.send(conn)
11851196
assert req._writer is not None
1186-
writer = req._writer = mock.Mock()
1197+
writer = req._writer = WriterMock()
1198+
writer.cancel = mock.Mock()
11871199

11881200
req.terminate()
11891201
assert req._writer is None
@@ -1201,7 +1213,7 @@ async def go():
12011213
req = ClientRequest("get", URL("http://python.org"))
12021214
resp = await req.send(conn)
12031215
assert req._writer is not None
1204-
writer = req._writer = mock.Mock()
1216+
writer = req._writer = WriterMock()
12051217

12061218
await asyncio.sleep(0.05)
12071219

Diff for: tests/test_client_response.py

+4
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import gc
44
import sys
5+
from typing import Callable
56
from unittest import mock
67

78
import pytest
@@ -19,6 +20,9 @@ class WriterMock(mock.AsyncMock):
1920
def __await__(self) -> None:
2021
return self().__await__()
2122

23+
def add_done_callback(self, cb: Callable[[], None]) -> None:
24+
cb()
25+
2226
def done(self) -> bool:
2327
return True
2428

Diff for: tests/test_proxy.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def test_proxy_server_hostname_default(self, ClientRequestMock) -> None:
202202
"get",
203203
URL("http://proxy.example.com"),
204204
request_info=mock.Mock(),
205-
writer=mock.Mock(),
205+
writer=None,
206206
continue100=None,
207207
timer=TimerNoop(),
208208
traces=[],
@@ -264,7 +264,7 @@ def test_proxy_server_hostname_override(self, ClientRequestMock) -> None:
264264
"get",
265265
URL("http://proxy.example.com"),
266266
request_info=mock.Mock(),
267-
writer=mock.Mock(),
267+
writer=None,
268268
continue100=None,
269269
timer=TimerNoop(),
270270
traces=[],
@@ -326,7 +326,7 @@ def test_https_connect(self, ClientRequestMock) -> None:
326326
"get",
327327
URL("http://proxy.example.com"),
328328
request_info=mock.Mock(),
329-
writer=mock.Mock(),
329+
writer=None,
330330
continue100=None,
331331
timer=TimerNoop(),
332332
traces=[],
@@ -386,7 +386,7 @@ def test_https_connect_certificate_error(self, ClientRequestMock) -> None:
386386
"get",
387387
URL("http://proxy.example.com"),
388388
request_info=mock.Mock(),
389-
writer=mock.Mock(),
389+
writer=None,
390390
continue100=None,
391391
timer=TimerNoop(),
392392
traces=[],
@@ -440,7 +440,7 @@ def test_https_connect_ssl_error(self, ClientRequestMock) -> None:
440440
"get",
441441
URL("http://proxy.example.com"),
442442
request_info=mock.Mock(),
443-
writer=mock.Mock(),
443+
writer=None,
444444
continue100=None,
445445
timer=TimerNoop(),
446446
traces=[],
@@ -496,7 +496,7 @@ def test_https_connect_http_proxy_error(self, ClientRequestMock) -> None:
496496
"get",
497497
URL("http://proxy.example.com"),
498498
request_info=mock.Mock(),
499-
writer=mock.Mock(),
499+
writer=None,
500500
continue100=None,
501501
timer=TimerNoop(),
502502
traces=[],
@@ -555,7 +555,7 @@ def test_https_connect_resp_start_error(self, ClientRequestMock) -> None:
555555
"get",
556556
URL("http://proxy.example.com"),
557557
request_info=mock.Mock(),
558-
writer=mock.Mock(),
558+
writer=None,
559559
continue100=None,
560560
timer=TimerNoop(),
561561
traces=[],
@@ -666,7 +666,7 @@ def test_https_connect_pass_ssl_context(self, ClientRequestMock) -> None:
666666
"get",
667667
URL("http://proxy.example.com"),
668668
request_info=mock.Mock(),
669-
writer=mock.Mock(),
669+
writer=None,
670670
continue100=None,
671671
timer=TimerNoop(),
672672
traces=[],
@@ -737,7 +737,7 @@ def test_https_auth(self, ClientRequestMock) -> None:
737737
"get",
738738
URL("http://proxy.example.com"),
739739
request_info=mock.Mock(),
740-
writer=mock.Mock(),
740+
writer=None,
741741
continue100=None,
742742
timer=TimerNoop(),
743743
traces=[],

0 commit comments

Comments
 (0)