From 02735353dfdca0729d6f8c2b756437363e8de309 Mon Sep 17 00:00:00 2001 From: Adam Bannister Date: Mon, 18 Nov 2019 10:07:55 +0100 Subject: [PATCH] Form data processed (#4351) --- CHANGES/4345.bugfix | 1 + aiohttp/formdata.py | 4 ++++ tests/test_formdata.py | 15 ++++++++++++++- tests/test_web_functional.py | 21 ++++++++++----------- 4 files changed, 29 insertions(+), 12 deletions(-) create mode 100644 CHANGES/4345.bugfix diff --git a/CHANGES/4345.bugfix b/CHANGES/4345.bugfix new file mode 100644 index 00000000000..badaf6453eb --- /dev/null +++ b/CHANGES/4345.bugfix @@ -0,0 +1 @@ +Raise ClientPayloadError if FormData re-processed. diff --git a/aiohttp/formdata.py b/aiohttp/formdata.py index b4ffa048f37..811926901dd 100644 --- a/aiohttp/formdata.py +++ b/aiohttp/formdata.py @@ -22,6 +22,7 @@ def __init__(self, fields: self._writer = multipart.MultipartWriter('form-data') self._fields = [] # type: List[Any] self._is_multipart = False + self._is_processed = False self._quote_fields = quote_fields self._charset = charset @@ -115,6 +116,8 @@ def _gen_form_urlencoded(self) -> payload.BytesPayload: def _gen_form_data(self) -> multipart.MultipartWriter: """Encode a list of fields using the multipart/form-data MIME format""" + if self._is_processed: + raise RuntimeError('Form data has been processed already') for dispparams, headers, value in self._fields: try: if hdrs.CONTENT_TYPE in headers: @@ -141,6 +144,7 @@ def _gen_form_data(self) -> multipart.MultipartWriter: self._writer.append_payload(part) + self._is_processed = True return self._writer def __call__(self) -> Payload: diff --git a/tests/test_formdata.py b/tests/test_formdata.py index 55f8653d6d6..88cfc0456be 100644 --- a/tests/test_formdata.py +++ b/tests/test_formdata.py @@ -2,7 +2,7 @@ import pytest -from aiohttp.formdata import FormData +from aiohttp import ClientSession, FormData @pytest.fixture @@ -86,3 +86,16 @@ async def test_formdata_field_name_is_not_quoted(buf, writer) -> None: payload = form() await payload.write(writer) assert b'name="emails[]"' in buf + + +async def test_mark_formdata_as_processed() -> None: + async with ClientSession() as session: + url = "http://httpbin.org/anything" + data = FormData() + data.add_field("test", "test_value", content_type="application/json") + + await session.post(url, data=data) + assert len(data._writer._parts) == 1 + + with pytest.raises(RuntimeError): + await session.post(url, data=data) diff --git a/tests/test_web_functional.py b/tests/test_web_functional.py index b9c1a3f7a34..4be7a962303 100644 --- a/tests/test_web_functional.py +++ b/tests/test_web_functional.py @@ -37,6 +37,13 @@ def fname(here): return here / 'conftest.py' +def new_dummy_form(): + form = FormData() + form.add_field('name', b'123', + content_transfer_encoding='base64') + return form + + async def test_simple_get(aiohttp_client) -> None: async def handler(request): @@ -513,15 +520,11 @@ async def expect_handler(request): if request.version == HttpVersion11: await request.writer.write(b"HTTP/1.1 100 Continue\r\n\r\n") - form = FormData() - form.add_field('name', b'123', - content_transfer_encoding='base64') - app = web.Application() app.router.add_post('/', handler, expect_handler=expect_handler) client = await aiohttp_client(app) - resp = await client.post('/', data=form, expect100=True) + resp = await client.post('/', data=new_dummy_form(), expect100=True) assert 200 == resp.status assert expect_received @@ -540,20 +543,16 @@ async def expect_handler(request): await request.writer.write(b"HTTP/1.1 100 Continue\r\n\r\n") - form = FormData() - form.add_field('name', b'123', - content_transfer_encoding='base64') - app = web.Application() app.router.add_post('/', handler, expect_handler=expect_handler) client = await aiohttp_client(app) auth_err = False - resp = await client.post('/', data=form, expect100=True) + resp = await client.post('/', data=new_dummy_form(), expect100=True) assert 200 == resp.status auth_err = True - resp = await client.post('/', data=form, expect100=True) + resp = await client.post('/', data=new_dummy_form(), expect100=True) assert 403 == resp.status