diff --git a/starlette/requests.py b/starlette/requests.py index e3c91e284..649e4cb94 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -205,10 +205,33 @@ def method(self) -> str: def receive(self) -> Receive: return self._receive + @property + def _stream(self) -> typing.List[bytes]: + try: + return self.scope["extensions"]["starlette"]["stream"] + except KeyError: + raise AttributeError("_stream") + + @_stream.setter + def _stream(self, chunks: typing.List[bytes]) -> None: + if "starlette" not in self.scope["extensions"]: + self.scope["extensions"]["starlette"] = {} + self.scope["extensions"]["starlette"]["stream"] = chunks + + async def _cache_stream(self) -> typing.AsyncGenerator[bytes, None]: + _stream: typing.List[bytes] = [] + + async for chunk in self.stream(): + _stream.append(chunk) + yield chunk + + if not hasattr(self, "_stream"): + self._stream = _stream + async def stream(self) -> typing.AsyncGenerator[bytes, None]: - if hasattr(self, "_body"): - yield self._body - yield b"" + if hasattr(self, "_stream"): + for chunk in self._stream: + yield chunk return if self._stream_consumed: @@ -229,18 +252,14 @@ async def stream(self) -> typing.AsyncGenerator[bytes, None]: yield b"" async def body(self) -> bytes: - if not hasattr(self, "_body"): - chunks = [] - async for chunk in self.stream(): - chunks.append(chunk) - self._body = b"".join(chunks) - return self._body + chunks = [] + async for chunk in self._cache_stream(): + chunks.append(chunk) + return b"".join(chunks) async def json(self) -> typing.Any: - if not hasattr(self, "_json"): - body = await self.body() - self._json = json.loads(body) - return self._json + body = await self.body() + return json.loads(body) async def form(self) -> FormData: if not hasattr(self, "_form"): diff --git a/tests/test_requests.py b/tests/test_requests.py index 799e61f80..36cf75b42 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -158,6 +158,21 @@ async def app(scope, receive, send): assert response.json() == {"body": "", "stream": "abc"} +def test_request_body_then_request_body(test_client_factory): + async def app(scope, receive, send): + request = Request(scope, receive) + body = await request.body() + request2 = Request(scope, request.receive) + body2 = await request2.body() + response = JSONResponse({"body": body.decode(), "body2": body2.decode()}) + await response(scope, receive, send) + + client = test_client_factory(app) + + response = client.post("/", data="abc") + assert response.json() == {"body": "abc", "body2": "abc"} + + def test_request_json(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive)