From d088089c6ca28e754865210b202f3c86dd544547 Mon Sep 17 00:00:00 2001 From: Nick Harris Date: Tue, 12 May 2020 14:42:03 -0500 Subject: [PATCH 01/10] moved Request _form, _body, and _json to scope --- .gitignore | 1 + starlette/requests.py | 33 +++++++++++++++------------------ tests/test_requests.py | 30 ++++++++++++++++++++++++++++++ 3 files changed, 46 insertions(+), 18 deletions(-) diff --git a/.gitignore b/.gitignore index 7b5d4318c..f718a3db8 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ test.db .coverage .pytest_cache/ +.python-version .mypy_cache/ starlette.egg-info/ venv/ diff --git a/starlette/requests.py b/starlette/requests.py index 56a0c5a9a..26cf861f5 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -191,8 +191,8 @@ def receive(self) -> Receive: return self._receive async def stream(self) -> typing.AsyncGenerator[bytes, None]: - if hasattr(self, "_body"): - yield self._body + if "body" in self.scope: + yield self.scope["body"] yield b"" return @@ -214,21 +214,18 @@ async def stream(self) -> typing.AsyncGenerator[bytes, None]: yield b"" async def body(self) -> bytes: - if not hasattr(self, "_body"): - chunks = [] - async for chunk in self.stream(): - chunks.append(chunk) - self._body = b"".join(chunks) - return self._body + if "body" not in self.scope: + self.scope["body"] = b"".join([chunk async for chunk in self.stream()]) + return self.scope["body"] async def json(self) -> typing.Any: - if not hasattr(self, "_json"): + if "json" not in self.scope: body = await self.body() - self._json = json.loads(body) - return self._json + self.scope["json"] = json.loads(body) + return self.scope["json"] async def form(self) -> FormData: - if not hasattr(self, "_form"): + if "form" not in self.scope: assert ( parse_options_header is not None ), "The `python-multipart` library must be installed to use form parsing." @@ -236,17 +233,17 @@ async def form(self) -> FormData: content_type, options = parse_options_header(content_type_header) if content_type == b"multipart/form-data": multipart_parser = MultiPartParser(self.headers, self.stream()) - self._form = await multipart_parser.parse() + self.scope["form"] = await multipart_parser.parse() elif content_type == b"application/x-www-form-urlencoded": form_parser = FormParser(self.headers, self.stream()) - self._form = await form_parser.parse() + self.scope["form"] = await form_parser.parse() else: - self._form = FormData() - return self._form + self.scope["form"] = FormData() + return self.scope["form"] async def close(self) -> None: - if hasattr(self, "_form"): - await self._form.close() + if "form" in self.scope: + await self.scope["form"].close() async def is_disconnected(self) -> bool: if not self._is_disconnected: diff --git a/tests/test_requests.py b/tests/test_requests.py index c5e50edbe..a5a12c8c1 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -155,6 +155,36 @@ async def app(scope, receive, send): assert response.json() == {"body": "", "stream": "abc"} +def test_request_body_then_request_body(): + async def app(scope, receive, send): + request = Request(scope, receive) + body = await request.body() + request2 = Request(scope, request.receive) + body2 = await request2.body() + response = JSONResponse({"body": body.decode(), "body2": body2.decode()}) + await response(scope, receive, send) + + client = TestClient(app) + + response = client.post("/", data="abc") + assert response.json() == {"body": "abc", "body2": "abc"} + + +def test_request_form_then_request_form(): + async def app(scope, receive, send): + request = Request(scope, receive) + form = await request.form() + request2 = Request(scope, request.receive) + form2 = await request2.form() + response = JSONResponse({"form": dict(form), "form2": dict(form2)}) + await response(scope, receive, send) + + client = TestClient(app) + + response = client.post("/", data={"abc": "123 @"}) + assert response.json() == {"form": {"abc": "123 @"}, "form2": {"abc": "123 @"}} + + def test_request_json(): async def app(scope, receive, send): request = Request(scope, receive) From 41a07ee8414da4dccbd41fa1b56bd17a674e8c73 Mon Sep 17 00:00:00 2001 From: Nick Harris Date: Sun, 30 Jan 2022 20:36:50 -0600 Subject: [PATCH 02/10] removed json scope. moved cached body and form to properties to clean up direct references to scope --- starlette/requests.py | 60 +++++++++++++++++++++++++++++++------------ 1 file changed, 44 insertions(+), 16 deletions(-) diff --git a/starlette/requests.py b/starlette/requests.py index d1145ca33..c3e1cd5f7 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -202,9 +202,24 @@ def method(self) -> str: def receive(self) -> Receive: return self._receive + @property + def _body(self) -> bytes: + try: + return self.scope["body"] + except KeyError: + raise AttributeError("_body") + + @_body.setter + def _body(self, bytes): + self.scope["body"] = bytes + + @_body.deleter + def _body(self): + del self.scope["body"] + async def stream(self) -> typing.AsyncGenerator[bytes, None]: - if "body" in self.scope: - yield self.scope["body"] + if hasattr(self, "_body"): + yield self._body yield b"" return @@ -226,18 +241,31 @@ async def stream(self) -> typing.AsyncGenerator[bytes, None]: yield b"" async def body(self) -> bytes: - if "body" not in self.scope: - self.scope["body"] = b"".join([chunk async for chunk in self.stream()]) - return self.scope["body"] + if not hasattr(self, "_body"): + self._body = b"".join([chunk async for chunk in self.stream()]) + return self._body async def json(self) -> typing.Any: - if "json" not in self.scope: - body = await self.body() - self.scope["json"] = json.loads(body) - return self.scope["json"] + body = await self.body() + return json.loads(body) + + @property + def _form(self) -> bytes: + try: + return self.scope["form"] + except KeyError: + raise AttributeError("_form") + + @_form.setter + def _form(self, bytes): + self.scope["form"] = bytes + + @_form.deleter + def _form(self): + del self.scope["form"] async def form(self) -> FormData: - if "form" not in self.scope: + if not hasattr(self, "_form"): assert ( parse_options_header is not None ), "The `python-multipart` library must be installed to use form parsing." @@ -245,17 +273,17 @@ async def form(self) -> FormData: content_type, options = parse_options_header(content_type_header) if content_type == b"multipart/form-data": multipart_parser = MultiPartParser(self.headers, self.stream()) - self.scope["form"] = await multipart_parser.parse() + self._form = await multipart_parser.parse() elif content_type == b"application/x-www-form-urlencoded": form_parser = FormParser(self.headers, self.stream()) - self.scope["form"] = await form_parser.parse() + self._form = await form_parser.parse() else: - self.scope["form"] = FormData() - return self.scope["form"] + self._form = FormData() + return self._form async def close(self) -> None: - if "form" in self.scope: - await self.scope["form"].close() + if hasattr(self, "_form"): + await self._form.close() async def is_disconnected(self) -> bool: if not self._is_disconnected: From 3b72b773c77213d036ccf6c2e90be7c7ea68755a Mon Sep 17 00:00:00 2001 From: Nick Harris Date: Sun, 30 Jan 2022 20:40:49 -0600 Subject: [PATCH 03/10] reverted unnecessary change to body stream iterator --- starlette/requests.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/starlette/requests.py b/starlette/requests.py index c3e1cd5f7..643e0d15c 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -242,7 +242,10 @@ async def stream(self) -> typing.AsyncGenerator[bytes, None]: async def body(self) -> bytes: if not hasattr(self, "_body"): - self._body = b"".join([chunk async for chunk in self.stream()]) + chunks = [] + async for chunk in self.stream(): + chunks.append(chunk) + self._body = b"".join(chunks) return self._body async def json(self) -> typing.Any: From dfeb14fd563b1bb6cf10d9a1e7f4ac63de5b8dc5 Mon Sep 17 00:00:00 2001 From: Nick Harris Date: Sun, 30 Jan 2022 21:41:03 -0600 Subject: [PATCH 04/10] replaced body and form cache with raw stream cache --- starlette/requests.py | 84 ++++++++++++++++-------------------------- tests/test_requests.py | 7 +--- 2 files changed, 34 insertions(+), 57 deletions(-) diff --git a/starlette/requests.py b/starlette/requests.py index 643e0d15c..abbe53f9b 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -191,7 +191,6 @@ def __init__( assert scope["type"] == "http" self._receive = receive self._send = send - self._stream_consumed = False self._is_disconnected = False @property @@ -203,85 +202,66 @@ def receive(self) -> Receive: return self._receive @property - def _body(self) -> bytes: + def _stream(self) -> bytes: try: - return self.scope["body"] + return self.scope["stream"] except KeyError: - raise AttributeError("_body") + raise AttributeError("_stream") - @_body.setter - def _body(self, bytes): - self.scope["body"] = bytes + @_stream.setter + def _stream(self, bytes): + self.scope["stream"] = bytes - @_body.deleter - def _body(self): - del self.scope["body"] + @_stream.deleter + def _stream(self): + del self.scope["stream"] async def stream(self) -> typing.AsyncGenerator[bytes, None]: - if hasattr(self, "_body"): - yield self._body - yield b"" + if hasattr(self, "_stream"): + for chunk in self._stream: + yield chunk return - if self._stream_consumed: - raise RuntimeError("Stream consumed") - - self._stream_consumed = True + self._stream = [] while True: message = await self._receive() if message["type"] == "http.request": body = message.get("body", b"") if body: + self._stream.append(body) yield body if not message.get("more_body", False): break elif message["type"] == "http.disconnect": self._is_disconnected = True raise ClientDisconnect() + self._stream.append(b"") yield b"" async def body(self) -> bytes: - if not hasattr(self, "_body"): - chunks = [] - async for chunk in self.stream(): - chunks.append(chunk) - self._body = b"".join(chunks) - return self._body + chunks = [] + async for chunk in self.stream(): + chunks.append(chunk) + return b"".join(chunks) async def json(self) -> typing.Any: body = await self.body() return json.loads(body) - @property - def _form(self) -> bytes: - try: - return self.scope["form"] - except KeyError: - raise AttributeError("_form") - - @_form.setter - def _form(self, bytes): - self.scope["form"] = bytes - - @_form.deleter - def _form(self): - del self.scope["form"] - async def form(self) -> FormData: - if not hasattr(self, "_form"): - assert ( - parse_options_header is not None - ), "The `python-multipart` library must be installed to use form parsing." - content_type_header = self.headers.get("Content-Type") - content_type, options = parse_options_header(content_type_header) - if content_type == b"multipart/form-data": - multipart_parser = MultiPartParser(self.headers, self.stream()) - self._form = await multipart_parser.parse() - elif content_type == b"application/x-www-form-urlencoded": - form_parser = FormParser(self.headers, self.stream()) - self._form = await form_parser.parse() - else: - self._form = FormData() + assert ( + parse_options_header is not None + ), "The `python-multipart` library must be installed to use form parsing." + content_type_header = self.headers.get("Content-Type") + content_type, options = parse_options_header(content_type_header) + if content_type == b"multipart/form-data": + multipart_parser = MultiPartParser(self.headers, self.stream()) + self._form = await multipart_parser.parse() + elif content_type == b"application/x-www-form-urlencoded": + form_parser = FormParser(self.headers, self.stream()) + self._form = await form_parser.parse() + else: + self._form = FormData() return self._form async def close(self) -> None: diff --git a/tests/test_requests.py b/tests/test_requests.py index 7c67cce3a..9d6ce8344 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -140,17 +140,14 @@ async def app(scope, receive, send): chunks = b"" async for chunk in request.stream(): chunks += chunk - try: - body = await request.body() - except RuntimeError: - body = b"" + body = await request.body() response = JSONResponse({"body": body.decode(), "stream": chunks.decode()}) await response(scope, receive, send) client = test_client_factory(app) response = client.post("/", data="abc") - assert response.json() == {"body": "", "stream": "abc"} + assert response.json() == {"body": "abc", "stream": "abc"} def test_request_body_then_request_body(test_client_factory): From febcbc97dbfb2db9c1a6172950f47e57ed7b00a9 Mon Sep 17 00:00:00 2001 From: Nick Harris Date: Sun, 30 Jan 2022 21:50:09 -0600 Subject: [PATCH 05/10] moved scope stream to extensions --- starlette/requests.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/starlette/requests.py b/starlette/requests.py index abbe53f9b..ea8ba9c4c 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -204,17 +204,24 @@ def receive(self) -> Receive: @property def _stream(self) -> bytes: try: - return self.scope["stream"] + return self.scope["extensions"]["starlette"]["stream"] except KeyError: raise AttributeError("_stream") @_stream.setter def _stream(self, bytes): - self.scope["stream"] = bytes + if "extensions" not in self.scope: + self.scope["extensions"] = {} + if "starlette" not in self.scope["extensions"]: + self.scope["extensions"]["starlette"] = {} + self.scope["extensions"]["starlette"]["stream"] = bytes @_stream.deleter def _stream(self): - del self.scope["stream"] + try: + del self.scope["extensions"]["starlette"]["stream"] + except KeyError: + raise AttributeError("_stream") async def stream(self) -> typing.AsyncGenerator[bytes, None]: if hasattr(self, "_stream"): From 196d8db4c059984f70ac77d5476a7f0202e4b7b7 Mon Sep 17 00:00:00 2001 From: Nick Harris Date: Sun, 30 Jan 2022 22:12:49 -0600 Subject: [PATCH 06/10] fixed linting and coverage --- starlette/requests.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/starlette/requests.py b/starlette/requests.py index ea8ba9c4c..979aea1cf 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -202,26 +202,19 @@ def receive(self) -> Receive: return self._receive @property - def _stream(self) -> bytes: + def _stream(self) -> typing.List[bytes]: try: return self.scope["extensions"]["starlette"]["stream"] except KeyError: raise AttributeError("_stream") @_stream.setter - def _stream(self, bytes): + def _stream(self, chunks: typing.List[bytes]) -> None: if "extensions" not in self.scope: self.scope["extensions"] = {} if "starlette" not in self.scope["extensions"]: self.scope["extensions"]["starlette"] = {} - self.scope["extensions"]["starlette"]["stream"] = bytes - - @_stream.deleter - def _stream(self): - try: - del self.scope["extensions"]["starlette"]["stream"] - except KeyError: - raise AttributeError("_stream") + self.scope["extensions"]["starlette"]["stream"] = chunks async def stream(self) -> typing.AsyncGenerator[bytes, None]: if hasattr(self, "_stream"): From 16f0fcfd62c5ea3afb4ca96c2b6ed7433cabaf18 Mon Sep 17 00:00:00 2001 From: Nick Harris Date: Mon, 31 Jan 2022 08:22:41 -0600 Subject: [PATCH 07/10] reverted stream caching to prevent large streams exhausting memory --- starlette/requests.py | 26 ++++++++++++++++++-------- tests/test_requests.py | 7 +++++-- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/starlette/requests.py b/starlette/requests.py index 979aea1cf..b93a1dec2 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -191,6 +191,7 @@ def __init__( assert scope["type"] == "http" self._receive = receive self._send = send + self._stream_consumed = False self._is_disconnected = False @property @@ -210,37 +211,46 @@ def _stream(self) -> typing.List[bytes]: @_stream.setter def _stream(self, chunks: typing.List[bytes]) -> None: - if "extensions" not in self.scope: - self.scope["extensions"] = {} if "starlette" not in self.scope["extensions"]: self.scope["extensions"]["starlette"] = {} self.scope["extensions"]["starlette"]["stream"] = chunks + async def _wrap_stream(self) -> typing.AsyncGenerator[bytes, None]: + _stream: typing.List[bytes] = [] + + async for chunk in self.stream(): + _stream.append(chunk) + yield chunk + + if not hasattr(self, "_stream"): + self._stream = _stream + async def stream(self) -> typing.AsyncGenerator[bytes, None]: if hasattr(self, "_stream"): for chunk in self._stream: yield chunk return - self._stream = [] + if self._stream_consumed: + raise RuntimeError("Stream consumed") + + self._stream_consumed = True while True: message = await self._receive() if message["type"] == "http.request": body = message.get("body", b"") if body: - self._stream.append(body) yield body if not message.get("more_body", False): break elif message["type"] == "http.disconnect": self._is_disconnected = True raise ClientDisconnect() - self._stream.append(b"") yield b"" async def body(self) -> bytes: chunks = [] - async for chunk in self.stream(): + async for chunk in self._wrap_stream(): chunks.append(chunk) return b"".join(chunks) @@ -255,10 +265,10 @@ async def form(self) -> FormData: content_type_header = self.headers.get("Content-Type") content_type, options = parse_options_header(content_type_header) if content_type == b"multipart/form-data": - multipart_parser = MultiPartParser(self.headers, self.stream()) + multipart_parser = MultiPartParser(self.headers, self._wrap_stream()) self._form = await multipart_parser.parse() elif content_type == b"application/x-www-form-urlencoded": - form_parser = FormParser(self.headers, self.stream()) + form_parser = FormParser(self.headers, self._wrap_stream()) self._form = await form_parser.parse() else: self._form = FormData() diff --git a/tests/test_requests.py b/tests/test_requests.py index 9d6ce8344..7c67cce3a 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -140,14 +140,17 @@ async def app(scope, receive, send): chunks = b"" async for chunk in request.stream(): chunks += chunk - body = await request.body() + try: + body = await request.body() + except RuntimeError: + body = b"" response = JSONResponse({"body": body.decode(), "stream": chunks.decode()}) await response(scope, receive, send) client = test_client_factory(app) response = client.post("/", data="abc") - assert response.json() == {"body": "abc", "stream": "abc"} + assert response.json() == {"body": "", "stream": "abc"} def test_request_body_then_request_body(test_client_factory): From d8dc80ded313841a7d09306dcc054daf222455cb Mon Sep 17 00:00:00 2001 From: Nick Harris Date: Mon, 31 Jan 2022 08:23:59 -0600 Subject: [PATCH 08/10] renamed private stream cache wrapper --- starlette/requests.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/starlette/requests.py b/starlette/requests.py index b93a1dec2..c9a7c7841 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -215,7 +215,7 @@ def _stream(self, chunks: typing.List[bytes]) -> None: self.scope["extensions"]["starlette"] = {} self.scope["extensions"]["starlette"]["stream"] = chunks - async def _wrap_stream(self) -> typing.AsyncGenerator[bytes, None]: + async def _cache_stream(self) -> typing.AsyncGenerator[bytes, None]: _stream: typing.List[bytes] = [] async for chunk in self.stream(): @@ -250,7 +250,7 @@ async def stream(self) -> typing.AsyncGenerator[bytes, None]: async def body(self) -> bytes: chunks = [] - async for chunk in self._wrap_stream(): + async for chunk in self._cache_stream(): chunks.append(chunk) return b"".join(chunks) @@ -265,10 +265,10 @@ async def form(self) -> FormData: content_type_header = self.headers.get("Content-Type") content_type, options = parse_options_header(content_type_header) if content_type == b"multipart/form-data": - multipart_parser = MultiPartParser(self.headers, self._wrap_stream()) + multipart_parser = MultiPartParser(self.headers, self._cache_stream()) self._form = await multipart_parser.parse() elif content_type == b"application/x-www-form-urlencoded": - form_parser = FormParser(self.headers, self._wrap_stream()) + form_parser = FormParser(self.headers, self._cache_stream()) self._form = await form_parser.parse() else: self._form = FormData() From c73a078687b9dfe3b5071259fdf19d6533d176c0 Mon Sep 17 00:00:00 2001 From: Nick Harris <3432064+nikordaris@users.noreply.github.com> Date: Mon, 14 Feb 2022 13:25:16 -0600 Subject: [PATCH 09/10] Apply suggestions from code review Co-authored-by: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> --- .gitignore | 1 - 1 file changed, 1 deletion(-) diff --git a/.gitignore b/.gitignore index 1bb300c91..bff8fa258 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,6 @@ test.db .coverage .pytest_cache/ -.python-version .mypy_cache/ __pycache__/ htmlcov/ From 510db9abdf407b2de410135311f2cc1a3f360136 Mon Sep 17 00:00:00 2001 From: Nick Harris Date: Mon, 14 Feb 2022 13:37:21 -0600 Subject: [PATCH 10/10] reverted Request form caching --- starlette/requests.py | 27 ++++++++++++++------------- tests/test_requests.py | 15 --------------- 2 files changed, 14 insertions(+), 28 deletions(-) diff --git a/starlette/requests.py b/starlette/requests.py index c9a7c7841..61d4a651b 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -259,19 +259,20 @@ async def json(self) -> typing.Any: return json.loads(body) async def form(self) -> FormData: - assert ( - parse_options_header is not None - ), "The `python-multipart` library must be installed to use form parsing." - content_type_header = self.headers.get("Content-Type") - content_type, options = parse_options_header(content_type_header) - if content_type == b"multipart/form-data": - multipart_parser = MultiPartParser(self.headers, self._cache_stream()) - self._form = await multipart_parser.parse() - elif content_type == b"application/x-www-form-urlencoded": - form_parser = FormParser(self.headers, self._cache_stream()) - self._form = await form_parser.parse() - else: - self._form = FormData() + if not hasattr(self, "_form"): + assert ( + parse_options_header is not None + ), "The `python-multipart` library must be installed to use form parsing." + content_type_header = self.headers.get("Content-Type") + content_type, options = parse_options_header(content_type_header) + if content_type == b"multipart/form-data": + multipart_parser = MultiPartParser(self.headers, self.stream()) + self._form = await multipart_parser.parse() + elif content_type == b"application/x-www-form-urlencoded": + form_parser = FormParser(self.headers, self.stream()) + self._form = await form_parser.parse() + else: + self._form = FormData() return self._form async def close(self) -> None: diff --git a/tests/test_requests.py b/tests/test_requests.py index 7c67cce3a..8eb4a1631 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -168,21 +168,6 @@ async def app(scope, receive, send): assert response.json() == {"body": "abc", "body2": "abc"} -def test_request_form_then_request_form(test_client_factory): - async def app(scope, receive, send): - request = Request(scope, receive) - form = await request.form() - request2 = Request(scope, request.receive) - form2 = await request2.form() - response = JSONResponse({"form": dict(form), "form2": dict(form2)}) - await response(scope, receive, send) - - client = test_client_factory(app) - - response = client.post("/", data={"abc": "123 @"}) - assert response.json() == {"form": {"abc": "123 @"}, "form2": {"abc": "123 @"}} - - def test_request_json(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive)