From cebe526b9c34dc3a3da9140409db63014bc4cf19 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Sun, 7 Apr 2024 13:19:31 +0100 Subject: [PATCH] Fix handling of multipart/form-data (#8280) (#8302) https://datatracker.ietf.org/doc/html/rfc7578 (cherry picked from commit 7d0be3fee540a3d4161ac7dc76422f1f5ea60104) --- CHANGES/8280.bugfix.rst | 1 + CHANGES/8280.deprecation.rst | 2 + aiohttp/formdata.py | 12 +++- aiohttp/multipart.py | 121 +++++++++++++++++++++----------- tests/test_client_functional.py | 44 +----------- tests/test_multipart.py | 68 ++++++++++++++---- tests/test_web_functional.py | 27 ++----- 7 files changed, 155 insertions(+), 120 deletions(-) create mode 100644 CHANGES/8280.bugfix.rst create mode 100644 CHANGES/8280.deprecation.rst diff --git a/CHANGES/8280.bugfix.rst b/CHANGES/8280.bugfix.rst new file mode 100644 index 00000000000..3aebe36fe9e --- /dev/null +++ b/CHANGES/8280.bugfix.rst @@ -0,0 +1 @@ +Fixed ``multipart/form-data`` compliance with :rfc:`7578` -- by :user:`Dreamsorcerer`. diff --git a/CHANGES/8280.deprecation.rst b/CHANGES/8280.deprecation.rst new file mode 100644 index 00000000000..302dbb2fe2a --- /dev/null +++ b/CHANGES/8280.deprecation.rst @@ -0,0 +1,2 @@ +Deprecated ``content_transfer_encoding`` parameter in :py:meth:`FormData.add_field() +` -- by :user:`Dreamsorcerer`. diff --git a/aiohttp/formdata.py b/aiohttp/formdata.py index e7cd24ca9f7..2b75b3de72c 100644 --- a/aiohttp/formdata.py +++ b/aiohttp/formdata.py @@ -1,4 +1,5 @@ import io +import warnings from typing import Any, Iterable, List, Optional from urllib.parse import urlencode @@ -53,7 +54,12 @@ def add_field( if isinstance(value, io.IOBase): self._is_multipart = True elif isinstance(value, (bytes, bytearray, memoryview)): + msg = ( + "In v4, passing bytes will no longer create a file field. " + "Please explicitly use the filename parameter or pass a BytesIO object." + ) if filename is None and content_transfer_encoding is None: + warnings.warn(msg, DeprecationWarning) filename = name type_options: MultiDict[str] = MultiDict({"name": name}) @@ -81,7 +87,11 @@ def add_field( "content_transfer_encoding must be an instance" " of str. Got: %s" % content_transfer_encoding ) - headers[hdrs.CONTENT_TRANSFER_ENCODING] = content_transfer_encoding + msg = ( + "content_transfer_encoding is deprecated. " + "To maintain compatibility with v4 please pass a BytesPayload." + ) + warnings.warn(msg, DeprecationWarning) self._is_multipart = True self._fields.append((type_options, headers, value)) diff --git a/aiohttp/multipart.py b/aiohttp/multipart.py index 4471dd4bb7e..a43ec545713 100644 --- a/aiohttp/multipart.py +++ b/aiohttp/multipart.py @@ -256,13 +256,22 @@ class BodyPartReader: chunk_size = 8192 def __init__( - self, boundary: bytes, headers: "CIMultiDictProxy[str]", content: StreamReader + self, + boundary: bytes, + headers: "CIMultiDictProxy[str]", + content: StreamReader, + *, + subtype: str = "mixed", + default_charset: Optional[str] = None, ) -> None: self.headers = headers self._boundary = boundary self._content = content + self._default_charset = default_charset self._at_eof = False - length = self.headers.get(CONTENT_LENGTH, None) + self._is_form_data = subtype == "form-data" + # https://datatracker.ietf.org/doc/html/rfc7578#section-4.8 + length = None if self._is_form_data else self.headers.get(CONTENT_LENGTH, None) self._length = int(length) if length is not None else None self._read_bytes = 0 self._unread: Deque[bytes] = deque() @@ -329,6 +338,8 @@ async def _read_chunk_from_length(self, size: int) -> bytes: assert self._length is not None, "Content-Length required for chunked read" chunk_size = min(size, self._length - self._read_bytes) chunk = await self._content.read(chunk_size) + if self._content.at_eof(): + self._at_eof = True return chunk async def _read_chunk_from_stream(self, size: int) -> bytes: @@ -449,7 +460,8 @@ def decode(self, data: bytes) -> bytes: """ if CONTENT_TRANSFER_ENCODING in self.headers: data = self._decode_content_transfer(data) - if CONTENT_ENCODING in self.headers: + # https://datatracker.ietf.org/doc/html/rfc7578#section-4.8 + if not self._is_form_data and CONTENT_ENCODING in self.headers: return self._decode_content(data) return data @@ -483,7 +495,7 @@ def get_charset(self, default: str) -> str: """Returns charset parameter from Content-Type header or default.""" ctype = self.headers.get(CONTENT_TYPE, "") mimetype = parse_mimetype(ctype) - return mimetype.parameters.get("charset", default) + return mimetype.parameters.get("charset", self._default_charset or default) @reify def name(self) -> Optional[str]: @@ -538,9 +550,17 @@ class MultipartReader: part_reader_cls = BodyPartReader def __init__(self, headers: Mapping[str, str], content: StreamReader) -> None: + self._mimetype = parse_mimetype(headers[CONTENT_TYPE]) + assert self._mimetype.type == "multipart", "multipart/* content type expected" + if "boundary" not in self._mimetype.parameters: + raise ValueError( + "boundary missed for Content-Type: %s" % headers[CONTENT_TYPE] + ) + self.headers = headers self._boundary = ("--" + self._get_boundary()).encode() self._content = content + self._default_charset: Optional[str] = None self._last_part: Optional[Union["MultipartReader", BodyPartReader]] = None self._at_eof = False self._at_bof = True @@ -592,7 +612,24 @@ async def next( await self._read_boundary() if self._at_eof: # we just read the last boundary, nothing to do there return None - self._last_part = await self.fetch_next_part() + + part = await self.fetch_next_part() + # https://datatracker.ietf.org/doc/html/rfc7578#section-4.6 + if ( + self._last_part is None + and self._mimetype.subtype == "form-data" + and isinstance(part, BodyPartReader) + ): + _, params = parse_content_disposition(part.headers.get(CONTENT_DISPOSITION)) + if params.get("name") == "_charset_": + # Longest encoding in https://encoding.spec.whatwg.org/encodings.json + # is 19 characters, so 32 should be more than enough for any valid encoding. + charset = await part.read_chunk(32) + if len(charset) > 31: + raise RuntimeError("Invalid default charset") + self._default_charset = charset.strip().decode() + part = await self.fetch_next_part() + self._last_part = part return self._last_part async def release(self) -> None: @@ -628,19 +665,16 @@ def _get_part_reader( return type(self)(headers, self._content) return self.multipart_reader_cls(headers, self._content) else: - return self.part_reader_cls(self._boundary, headers, self._content) - - def _get_boundary(self) -> str: - mimetype = parse_mimetype(self.headers[CONTENT_TYPE]) - - assert mimetype.type == "multipart", "multipart/* content type expected" - - if "boundary" not in mimetype.parameters: - raise ValueError( - "boundary missed for Content-Type: %s" % self.headers[CONTENT_TYPE] + return self.part_reader_cls( + self._boundary, + headers, + self._content, + subtype=self._mimetype.subtype, + default_charset=self._default_charset, ) - boundary = mimetype.parameters["boundary"] + def _get_boundary(self) -> str: + boundary = self._mimetype.parameters["boundary"] if len(boundary) > 70: raise ValueError("boundary %r is too long (70 chars max)" % boundary) @@ -731,6 +765,7 @@ def __init__(self, subtype: str = "mixed", boundary: Optional[str] = None) -> No super().__init__(None, content_type=ctype) self._parts: List[_Part] = [] + self._is_form_data = subtype == "form-data" def __enter__(self) -> "MultipartWriter": return self @@ -808,32 +843,36 @@ def append(self, obj: Any, headers: Optional[Mapping[str, str]] = None) -> Paylo def append_payload(self, payload: Payload) -> Payload: """Adds a new body part to multipart writer.""" - # compression - encoding: Optional[str] = payload.headers.get( - CONTENT_ENCODING, - "", - ).lower() - if encoding and encoding not in ("deflate", "gzip", "identity"): - raise RuntimeError(f"unknown content encoding: {encoding}") - if encoding == "identity": - encoding = None - - # te encoding - te_encoding: Optional[str] = payload.headers.get( - CONTENT_TRANSFER_ENCODING, - "", - ).lower() - if te_encoding not in ("", "base64", "quoted-printable", "binary"): - raise RuntimeError( - "unknown content transfer encoding: {}" "".format(te_encoding) + encoding: Optional[str] = None + te_encoding: Optional[str] = None + if self._is_form_data: + # https://datatracker.ietf.org/doc/html/rfc7578#section-4.7 + # https://datatracker.ietf.org/doc/html/rfc7578#section-4.8 + assert CONTENT_DISPOSITION in payload.headers + assert "name=" in payload.headers[CONTENT_DISPOSITION] + assert ( + not {CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TRANSFER_ENCODING} + & payload.headers.keys() ) - if te_encoding == "binary": - te_encoding = None - - # size - size = payload.size - if size is not None and not (encoding or te_encoding): - payload.headers[CONTENT_LENGTH] = str(size) + else: + # compression + encoding = payload.headers.get(CONTENT_ENCODING, "").lower() + if encoding and encoding not in ("deflate", "gzip", "identity"): + raise RuntimeError(f"unknown content encoding: {encoding}") + if encoding == "identity": + encoding = None + + # te encoding + te_encoding = payload.headers.get(CONTENT_TRANSFER_ENCODING, "").lower() + if te_encoding not in ("", "base64", "quoted-printable", "binary"): + raise RuntimeError(f"unknown content transfer encoding: {te_encoding}") + if te_encoding == "binary": + te_encoding = None + + # size + size = payload.size + if size is not None and not (encoding or te_encoding): + payload.headers[CONTENT_LENGTH] = str(size) self._parts.append((payload, encoding, te_encoding)) # type: ignore[arg-type] return payload diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index 8a9a4e184be..dbb2dff5ac4 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -1317,48 +1317,6 @@ async def handler(request): resp.close() -async def test_POST_DATA_with_context_transfer_encoding(aiohttp_client) -> None: - async def handler(request): - data = await request.post() - assert data["name"] == "text" - return web.Response(text=data["name"]) - - app = web.Application() - app.router.add_post("/", handler) - client = await aiohttp_client(app) - - form = aiohttp.FormData() - form.add_field("name", "text", content_transfer_encoding="base64") - - resp = await client.post("/", data=form) - assert 200 == resp.status - content = await resp.text() - assert content == "text" - resp.close() - - -async def test_POST_DATA_with_content_type_context_transfer_encoding(aiohttp_client): - async def handler(request): - data = await request.post() - assert data["name"] == "text" - return web.Response(body=data["name"]) - - app = web.Application() - app.router.add_post("/", handler) - client = await aiohttp_client(app) - - form = aiohttp.FormData() - form.add_field( - "name", "text", content_type="text/plain", content_transfer_encoding="base64" - ) - - resp = await client.post("/", data=form) - assert 200 == resp.status - content = await resp.text() - assert content == "text" - resp.close() - - async def test_POST_MultiDict(aiohttp_client) -> None: async def handler(request): data = await request.post() @@ -1410,7 +1368,7 @@ async def handler(request): with fname.open("rb") as f: async with client.post( - "/", data={"some": f, "test": b"data"}, chunked=True + "/", data={"some": f, "test": io.BytesIO(b"data")}, chunked=True ) as resp: assert 200 == resp.status diff --git a/tests/test_multipart.py b/tests/test_multipart.py index f9d130e7949..dbfaf74b9b7 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -944,6 +944,58 @@ async def test_reading_skips_prelude(self) -> None: assert first.at_eof() assert not second.at_eof() + async def test_read_form_default_encoding(self) -> None: + with Stream( + b"--:\r\n" + b'Content-Disposition: form-data; name="_charset_"\r\n\r\n' + b"ascii" + b"\r\n" + b"--:\r\n" + b'Content-Disposition: form-data; name="field1"\r\n\r\n' + b"foo" + b"\r\n" + b"--:\r\n" + b"Content-Type: text/plain;charset=UTF-8\r\n" + b'Content-Disposition: form-data; name="field2"\r\n\r\n' + b"foo" + b"\r\n" + b"--:\r\n" + b'Content-Disposition: form-data; name="field3"\r\n\r\n' + b"foo" + b"\r\n" + ) as stream: + reader = aiohttp.MultipartReader( + {CONTENT_TYPE: 'multipart/form-data;boundary=":"'}, + stream, + ) + field1 = await reader.next() + assert field1.name == "field1" + assert field1.get_charset("default") == "ascii" + field2 = await reader.next() + assert field2.name == "field2" + assert field2.get_charset("default") == "UTF-8" + field3 = await reader.next() + assert field3.name == "field3" + assert field3.get_charset("default") == "ascii" + + async def test_read_form_invalid_default_encoding(self) -> None: + with Stream( + b"--:\r\n" + b'Content-Disposition: form-data; name="_charset_"\r\n\r\n' + b"this-value-is-too-long-to-be-a-charset" + b"\r\n" + b"--:\r\n" + b'Content-Disposition: form-data; name="field1"\r\n\r\n' + b"foo" + b"\r\n" + ) as stream: + reader = aiohttp.MultipartReader( + {CONTENT_TYPE: 'multipart/form-data;boundary=":"'}, + stream, + ) + with pytest.raises(RuntimeError, match="Invalid default charset"): + await reader.next() + async def test_writer(writer) -> None: assert writer.size == 7 @@ -1280,7 +1332,6 @@ async def test_preserve_content_disposition_header(self, buf, stream): CONTENT_TYPE: "text/python", }, ) - content_length = part.size await writer.write(stream) assert part.headers[CONTENT_TYPE] == "text/python" @@ -1291,9 +1342,7 @@ async def test_preserve_content_disposition_header(self, buf, stream): assert headers == ( b"--:\r\n" b"Content-Type: text/python\r\n" - b'Content-Disposition: attachments; filename="bug.py"\r\n' - b"Content-Length: %s" - b"" % (str(content_length).encode(),) + b'Content-Disposition: attachments; filename="bug.py"' ) async def test_set_content_disposition_override(self, buf, stream): @@ -1307,7 +1356,6 @@ async def test_set_content_disposition_override(self, buf, stream): CONTENT_TYPE: "text/python", }, ) - content_length = part.size await writer.write(stream) assert part.headers[CONTENT_TYPE] == "text/python" @@ -1318,9 +1366,7 @@ async def test_set_content_disposition_override(self, buf, stream): assert headers == ( b"--:\r\n" b"Content-Type: text/python\r\n" - b'Content-Disposition: attachments; filename="bug.py"\r\n' - b"Content-Length: %s" - b"" % (str(content_length).encode(),) + b'Content-Disposition: attachments; filename="bug.py"' ) async def test_reset_content_disposition_header(self, buf, stream): @@ -1332,8 +1378,6 @@ async def test_reset_content_disposition_header(self, buf, stream): headers={CONTENT_TYPE: "text/plain"}, ) - content_length = part.size - assert CONTENT_DISPOSITION in part.headers part.set_content_disposition("attachments", filename="bug.py") @@ -1346,9 +1390,7 @@ async def test_reset_content_disposition_header(self, buf, stream): b"--:\r\n" b"Content-Type: text/plain\r\n" b"Content-Disposition:" - b' attachments; filename="bug.py"\r\n' - b"Content-Length: %s" - b"" % (str(content_length).encode(),) + b' attachments; filename="bug.py"' ) diff --git a/tests/test_web_functional.py b/tests/test_web_functional.py index 04fc2e35fd1..ee61537068b 100644 --- a/tests/test_web_functional.py +++ b/tests/test_web_functional.py @@ -48,7 +48,8 @@ def fname(here): def new_dummy_form(): form = FormData() - form.add_field("name", b"123", content_transfer_encoding="base64") + with pytest.warns(DeprecationWarning, match="BytesPayload"): + form.add_field("name", b"123", content_transfer_encoding="base64") return form @@ -447,25 +448,6 @@ async def handler(request): await resp.release() -async def test_POST_DATA_with_content_transfer_encoding(aiohttp_client) -> None: - async def handler(request): - data = await request.post() - assert b"123" == data["name"] - return web.Response() - - app = web.Application() - app.router.add_post("/", handler) - client = await aiohttp_client(app) - - form = FormData() - form.add_field("name", b"123", content_transfer_encoding="base64") - - resp = await client.post("/", data=form) - assert 200 == resp.status - - await resp.release() - - async def test_post_form_with_duplicate_keys(aiohttp_client) -> None: async def handler(request): data = await request.post() @@ -523,7 +505,8 @@ async def handler(request): return web.Response() form = FormData() - form.add_field("name", b"123", content_transfer_encoding="base64") + with pytest.warns(DeprecationWarning, match="BytesPayload"): + form.add_field("name", b"123", content_transfer_encoding="base64") app = web.Application() app.router.add_post("/", handler) @@ -727,7 +710,7 @@ async def handler(request): app.router.add_post("/", handler) client = await aiohttp_client(app) - resp = await client.post("/", data={"file": data}) + resp = await client.post("/", data={"file": io.BytesIO(data)}) assert 200 == resp.status await resp.release()