diff --git a/starlette/requests.py b/starlette/requests.py index 726abddcc..db4f1af2d 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -194,9 +194,32 @@ def __init__( assert scope["type"] == "http" self._receive = receive self._send = send - self._stream_consumed = False self._is_disconnected = False + # Not particularly graceful, but we store state around reading the request + # body in the ASGI scope, under the following... + # + # ['extensions']['starlette']['body'] + # ['extensions']['starlette']['stream_consumed'] + # + # This allows usages such as ASGI middleware to call the recieve and + # access the request body, and have that state persisted. + # + # Bit of an abuse of ASGI to take this approach. An alternate take would be + # that if you're going to use ASGI middleware it might be better to just + # accept the constraint that you *don't* get access to the request body in + # that context. + def _get_request_state(self, name: str, default: typing.Any = None) -> typing.Any: + return self.scope.get("extensions", {}).get("starlette", {}).get(name, default) + + def _set_request_state(self, name: str, value: typing.Any) -> None: + if "extensions" not in self.scope: + self.scope["extensions"] = {"starlette": {name: value}} + elif "starlette" not in self.scope["extensions"]: + self.scope["extensions"]["starlette"] = {name: value} + else: + self.scope["extensions"]["starlette"][name] = value + @property def method(self) -> str: return self.scope["method"] @@ -206,15 +229,17 @@ def receive(self) -> Receive: return self._receive async def stream(self) -> typing.AsyncGenerator[bytes, None]: - if hasattr(self, "_body"): - yield self._body + body = self._get_request_state("body") + if body is not None: + yield body yield b"" return - if self._stream_consumed: + stream_consumed = self._get_request_state("stream_consumed", default=False) + if stream_consumed: raise RuntimeError("Stream consumed") - self._stream_consumed = True + self._set_request_state("stream_consumed", True) while True: message = await self._receive() if message["type"] == "http.request": @@ -229,12 +254,14 @@ async def stream(self) -> typing.AsyncGenerator[bytes, None]: yield b"" async def body(self) -> bytes: - if not hasattr(self, "_body"): + body = self._get_request_state("body") + if body is None: chunks: "typing.List[bytes]" = [] async for chunk in self.stream(): chunks.append(chunk) - self._body = b"".join(chunks) - return self._body + body = b"".join(chunks) + self._set_request_state("body", body) + return body async def json(self) -> typing.Any: if not hasattr(self, "_json"): diff --git a/tests/test_requests.py b/tests/test_requests.py index 033df1e6a..b222fae11 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -493,3 +493,65 @@ async def app(scope, receive, send): client = test_client_factory(app) response = client.get("/") assert response.json() == {"json": "Send channel not available"} + + +def test_request_body_then_request_body(test_client_factory): + # If the request body is read, then ensure that instantiating a + # request a second time can return the content again. + async def app(scope, receive, send): + request = Request(scope, receive) + body = await request.body() + request2 = Request(scope, receive) + body2 = await request2.body() + response = JSONResponse({"body": body.decode(), "body2": body2.decode()}) + await response(scope, receive, send) + + client = test_client_factory(app) + + response = client.post("/", data="abc") + assert response.json() == {"body": "abc", "body2": "abc"} + + +def test_request_stream_then_request_body(test_client_factory): + # If the request has been streamed, then ensure that instantiating a + # request a second time raises an exception when attempting to read content. + async def app(scope, receive, send): + request = Request(scope, receive) + chunks = b"" + async for chunk in request.stream(): + chunks += chunk + + request2 = Request(scope, receive) + try: + body = await request2.body() + except RuntimeError: + body = b"" + + response = JSONResponse({"body": body.decode(), "stream": chunks.decode()}) + await response(scope, receive, send) + + client = test_client_factory(app) + + response = client.post("/", data="abc") + assert response.json() == {"body": "", "stream": "abc"} + + +def test_request_body_then_request_stream(test_client_factory): + # If the request body is read, then ensure that instantiating a + # request a second time can stream the content. + async def app(scope, receive, send): + request = Request(scope, receive) + body = await request.body() + + request2 = Request(scope, receive) + chunks = b"" + async for chunk in request2.stream(): + chunks += chunk + + response = JSONResponse({"body": body.decode(), "stream": chunks.decode()}) + await response(scope, receive, send) + + client = test_client_factory(app) + + response = client.post("/", data="abc") + assert response.json() == {"body": "abc", "stream": "abc"}