diff --git a/docs/extensions.md b/docs/extensions.md index e7bb9a7d..f8176e83 100644 --- a/docs/extensions.md +++ b/docs/extensions.md @@ -275,3 +275,23 @@ with httpcore.stream("GET", "https://www.example.com") as response: ssl_object = network_stream.get_extra_info("ssl_object") print("TLS version", ssl_object.version()) ``` + +### `"trailing_headers"` + +Trailing headers are a rarely used feature of HTTP, where supplementary headers may be sent at the end of the response data. + +The `trailing_headers` response extension is implemented as a list of `(bytes, bytes)` tuples containing any [trailing headers](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Trailer#chunked_transfer_encoding_using_a_trailing_header) sent at the end of the response. This list is only populated once the response is complete, and will be empty while streaming the response data. + +```python +# The "TE: trailers" header should be used in order to indicate that we're +# willing to accept trailing headers. This isn't required by the `httpcore` +# package itself, but is mandated by the HTTP spec, and might be required +# by some servers or proxies. +response = httpcore.request("GET", "https://www.example.com", headers={"TE": "trailers"}) + +# Show the standard response headers. +print(response.headers) + +# Show any trailing headers sent at the end of the response. 
+print(response.extensions['trailing_headers']) +``` diff --git a/httpcore/_async/http11.py b/httpcore/_async/http11.py index de965699..7aa9d707 100644 --- a/httpcore/_async/http11.py +++ b/httpcore/_async/http11.py @@ -96,14 +96,16 @@ async def handle_async_request(self, request: Request) -> Response: headers, ) + trailing_headers: List[Tuple[bytes, bytes]] = [] return Response( status=status, headers=headers, - content=HTTP11ConnectionByteStream(self, request), + content=HTTP11ConnectionByteStream(self, request, trailing_headers), extensions={ "http_version": http_version, "reason_phrase": reason_phrase, "network_stream": self._network_stream, + "trailing_headers": trailing_headers, }, ) except BaseException as exc: @@ -164,15 +166,28 @@ async def _receive_response_headers( return http_version, event.status_code, event.reason, headers - async def _receive_response_body(self, request: Request) -> AsyncIterator[bytes]: + async def _receive_response_body( + self, request: Request, trailing_headers: List[Tuple[bytes, bytes]] + ) -> AsyncIterator[bytes]: timeouts = request.extensions.get("timeout", {}) timeout = timeouts.get("read", None) while True: event = await self._receive_event(timeout=timeout) if isinstance(event, h11.Data): + # Each response will have zero, one, or more data events, + # containing the body of the response. yield bytes(event.data) - elif isinstance(event, (h11.EndOfMessage, h11.PAUSED)): + elif isinstance(event, h11.EndOfMessage): + # Once we get an EndOfMessage event, the response data has finished. + if event.headers: + trailing_headers.extend(event.headers.raw_items()) + break + elif isinstance(event, h11.PAUSED): + # This can occur here on a successful CONNECT or Upgrade + # response, where it is returned rather than EndOfMessage. 
+ # + # See https://h11.readthedocs.io/en/latest/api.html#flow-control break async def _receive_event( @@ -291,16 +306,24 @@ async def __aexit__( class HTTP11ConnectionByteStream: - def __init__(self, connection: AsyncHTTP11Connection, request: Request) -> None: + def __init__( + self, + connection: AsyncHTTP11Connection, + request: Request, + trailing_headers: List[Tuple[bytes, bytes]], + ) -> None: self._connection = connection self._request = request + self._trailing_headers = trailing_headers self._closed = False async def __aiter__(self) -> AsyncIterator[bytes]: - kwargs = {"request": self._request} + kwargs = {"request": self._request, "trailing_headers": self._trailing_headers} try: async with Trace("http11.receive_response_body", self._request, kwargs): - async for chunk in self._connection._receive_response_body(**kwargs): + async for chunk in self._connection._receive_response_body( + request=self._request, trailing_headers=self._trailing_headers + ): yield chunk except BaseException as exc: # If we get an exception while streaming the response, diff --git a/httpcore/_sync/http11.py b/httpcore/_sync/http11.py index 26f7d0cb..3cf5ac63 100644 --- a/httpcore/_sync/http11.py +++ b/httpcore/_sync/http11.py @@ -96,14 +96,16 @@ def handle_request(self, request: Request) -> Response: headers, ) + trailing_headers: List[Tuple[bytes, bytes]] = [] return Response( status=status, headers=headers, - content=HTTP11ConnectionByteStream(self, request), + content=HTTP11ConnectionByteStream(self, request, trailing_headers), extensions={ "http_version": http_version, "reason_phrase": reason_phrase, "network_stream": self._network_stream, + "trailing_headers": trailing_headers, }, ) except BaseException as exc: @@ -164,15 +166,28 @@ def _receive_response_headers( return http_version, event.status_code, event.reason, headers - def _receive_response_body(self, request: Request) -> Iterator[bytes]: + def _receive_response_body( + self, request: Request, trailing_headers: 
List[Tuple[bytes, bytes]] + ) -> Iterator[bytes]: timeouts = request.extensions.get("timeout", {}) timeout = timeouts.get("read", None) while True: event = self._receive_event(timeout=timeout) if isinstance(event, h11.Data): + # Each response will have zero, one, or more data events, + # containing the body of the response. yield bytes(event.data) - elif isinstance(event, (h11.EndOfMessage, h11.PAUSED)): + elif isinstance(event, h11.EndOfMessage): + # Once we get an EndOfMessage event, the response data has finished. + if event.headers: + trailing_headers.extend(event.headers.raw_items()) + break + elif isinstance(event, h11.PAUSED): + # This can occur here on a successful CONNECT or Upgrade + # response, where it is returned rather than EndOfMessage. + # + # See https://h11.readthedocs.io/en/latest/api.html#flow-control break def _receive_event( @@ -291,16 +306,24 @@ def __exit__( class HTTP11ConnectionByteStream: - def __init__(self, connection: HTTP11Connection, request: Request) -> None: + def __init__( + self, + connection: HTTP11Connection, + request: Request, + trailing_headers: List[Tuple[bytes, bytes]], + ) -> None: self._connection = connection self._request = request + self._trailing_headers = trailing_headers self._closed = False def __iter__(self) -> Iterator[bytes]: - kwargs = {"request": self._request} + kwargs = {"request": self._request, "trailing_headers": self._trailing_headers} try: with Trace("http11.receive_response_body", self._request, kwargs): - for chunk in self._connection._receive_response_body(**kwargs): + for chunk in self._connection._receive_response_body( + request=self._request, trailing_headers=self._trailing_headers + ): yield chunk except BaseException as exc: # If we get an exception while streaming the response, diff --git a/tests/_async/test_http11.py b/tests/_async/test_http11.py index 3a5ea54d..6bbeedb7 100644 --- a/tests/_async/test_http11.py +++ b/tests/_async/test_http11.py @@ -39,6 +39,80 @@ async def 
test_http11_connection(): ) +@pytest.mark.anyio +async def test_http11_connection_chunked_response(): + origin = Origin(b"https", b"example.com", 443) + stream = AsyncMockStream( + [ + b"HTTP/1.1 200 OK\r\n", + b"Content-Type: plain/text\r\n", + b"Transfer-Encoding: chunked\r\n", + b"\r\n", + b"3\r\n", + b"Hel\r\n", + b"4\r\n", + b"lo, \r\n", + b"6\r\n", + b"world!\r\n", + b"0\r\n", + b"\r\n", + ] + ) + async with AsyncHTTP11Connection( + origin=origin, stream=stream, keepalive_expiry=5.0 + ) as conn: + response = await conn.request("GET", "https://example.com/") + assert response.status == 200 + assert response.content == b"Hello, world!" + + assert conn.is_idle() + assert not conn.is_closed() + assert conn.is_available() + assert not conn.has_expired() + assert ( + repr(conn) + == "" + ) + + +@pytest.mark.anyio +async def test_http11_connection_trailing_headers_response(): + origin = Origin(b"https", b"example.com", 443) + stream = AsyncMockStream( + [ + b"HTTP/1.1 200 OK\r\n", + b"Content-Type: plain/text\r\n", + b"Transfer-Encoding: chunked\r\n", + b"Trailer: Surprise\r\n", + b"\r\n", + b"3\r\n", + b"Hel\r\n", + b"4\r\n", + b"lo, \r\n", + b"6\r\n", + b"world!\r\n", + b"0\r\n", + b"Surprise: You thought we were done here?\r\n", + b"\r\n", + ] + ) + async with AsyncHTTP11Connection( + origin=origin, stream=stream, keepalive_expiry=5.0 + ) as conn: + response = await conn.request( + "GET", "https://example.com/", headers={"TE": "trailers"} + ) + assert response.status == 200 + assert response.content == b"Hello, world!" 
+ assert response.headers == [ + (b"Content-Type", b"plain/text"), + (b"Transfer-Encoding", b"chunked"), + (b"Trailer", b"Surprise"), + ] + trailing_headers = response.extensions["trailing_headers"] + assert trailing_headers == [(b"Surprise", b"You thought we were done here?")] + + @pytest.mark.anyio async def test_http11_connection_unread_response(): """ diff --git a/tests/_sync/test_http11.py b/tests/_sync/test_http11.py index dd26f0c4..0e0b1a9a 100644 --- a/tests/_sync/test_http11.py +++ b/tests/_sync/test_http11.py @@ -40,6 +40,80 @@ def test_http11_connection(): +def test_http11_connection_chunked_response(): + origin = Origin(b"https", b"example.com", 443) + stream = MockStream( + [ + b"HTTP/1.1 200 OK\r\n", + b"Content-Type: plain/text\r\n", + b"Transfer-Encoding: chunked\r\n", + b"\r\n", + b"3\r\n", + b"Hel\r\n", + b"4\r\n", + b"lo, \r\n", + b"6\r\n", + b"world!\r\n", + b"0\r\n", + b"\r\n", + ] + ) + with HTTP11Connection( + origin=origin, stream=stream, keepalive_expiry=5.0 + ) as conn: + response = conn.request("GET", "https://example.com/") + assert response.status == 200 + assert response.content == b"Hello, world!" + + assert conn.is_idle() + assert not conn.is_closed() + assert conn.is_available() + assert not conn.has_expired() + assert ( + repr(conn) + == "" + ) + + + +def test_http11_connection_trailing_headers_response(): + origin = Origin(b"https", b"example.com", 443) + stream = MockStream( + [ + b"HTTP/1.1 200 OK\r\n", + b"Content-Type: plain/text\r\n", + b"Transfer-Encoding: chunked\r\n", + b"Trailer: Surprise\r\n", + b"\r\n", + b"3\r\n", + b"Hel\r\n", + b"4\r\n", + b"lo, \r\n", + b"6\r\n", + b"world!\r\n", + b"0\r\n", + b"Surprise: You thought we were done here?\r\n", + b"\r\n", + ] + ) + with HTTP11Connection( + origin=origin, stream=stream, keepalive_expiry=5.0 + ) as conn: + response = conn.request( + "GET", "https://example.com/", headers={"TE": "trailers"} + ) + assert response.status == 200 + assert response.content == b"Hello, world!" 
+ assert response.headers == [ + (b"Content-Type", b"plain/text"), + (b"Transfer-Encoding", b"chunked"), + (b"Trailer", b"Surprise"), + ] + trailing_headers = response.extensions["trailing_headers"] + assert trailing_headers == [(b"Surprise", b"You thought we were done here?")] + + + def test_http11_connection_unread_response(): """ If the client releases the response without reading it to termination,