Skip to content

Commit

Permalink
Merge pull request #3142 from morgan9e/patch-1
Browse files Browse the repository at this point in the history
Python: Fix header conversion to byte-pair on scope building
  • Loading branch information
hoodmane authored Dec 3, 2024
2 parents a866c69 + dfc7963 commit e6768de
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 4 deletions.
18 changes: 14 additions & 4 deletions src/pyodide/internal/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@ def acquire_js_buffer(pybuffer):
def request_to_scope(req, env, ws=False):
from js import URL

headers = [tuple(x) for x in req.headers]
# @app.get("/example")
# async def example(request: Request):
# request.headers.get("content-type")
# - this will error if header is not "bytes" as in ASGI spec.
headers = [(k.lower().encode(), v.encode()) for k, v in req.headers]
url = URL.new(req.url)
assert url.protocol[-1] == ":"
scheme = url.protocol[:-1]
Expand Down Expand Up @@ -112,8 +116,13 @@ async def process_request(app, req, env):
result = Future()

async def response_gen():
async for data in req.body:
yield {"body": data.to_bytes(), "more_body": True, "type": "http.request"}
if req.body:
async for data in req.body:
yield {
"body": data.to_bytes(),
"more_body": True,
"type": "http.request",
}
yield {"body": b"", "more_body": False, "type": "http.request"}

responses = response_gen()
Expand All @@ -126,7 +135,8 @@ async def send(got):
nonlocal headers
if got["type"] == "http.response.start":
status = got["status"]
headers = got["headers"]
# Like above, we need to convert byte-pairs into string explicitly.
headers = [(k.decode(), v.decode()) for k, v in got["headers"]]
if got["type"] == "http.response.body":
# intentionally leak body to avoid a copy
#
Expand Down
11 changes: 11 additions & 0 deletions src/workerd/server/tests/python/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,17 @@ py_wd_test(
),
)

py_wd_test(
src = "asgi/asgi.wd-test",
args = ["--experimental"],
data = glob(
[
"asgi/*",
],
exclude = ["**/*.wd-test"],
),
)

py_wd_test(
src = "random/random.wd-test",
args = ["--experimental"],
Expand Down
19 changes: 19 additions & 0 deletions src/workerd/server/tests/python/asgi/asgi.wd-test
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using Workerd = import "/workerd/workerd.capnp";

const unitTests :Workerd.Config = (
services = [
( name = "python-asgi",
worker = (
modules = [
(name = "worker.py", pythonModule = embed "worker.py"),
(name = "fastapi", pythonRequirement = ""),
],
bindings = [
( name = "SELF", service = "python-asgi" ),
],
compatibilityDate = "2024-10-01",
compatibilityFlags = ["python_workers_development"],
)
)
],
);
71 changes: 71 additions & 0 deletions src/workerd/server/tests/python/asgi/worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from pyodide.ffi import to_js


def check_encoding(byte_str, encoding="utf-8"):
try:
byte_str.decode(encoding)
except UnicodeDecodeError:
return False
return True


class Server:
def __init__(self):
pass

async def __call__(self, scope, receive, send) -> None:
scope["app"] = self

assert scope["type"] in ("http", "websocket", "lifespan")

if scope["type"] == "lifespan":
message = await receive()
if message["type"] == "lifespan.startup":
await send({"type": "lifespan.startup.complete"})
return

elif scope["type"] == "http":
headers = scope["headers"]
for header in headers:
assert isinstance(header[0], bytes) and isinstance(header[1], bytes)
assert check_encoding(header[0]) and check_encoding(header[1])

await receive()
# Send response and return
await send(
{
"type": "http.response.start",
"status": 200,
"headers": headers,
}
)

await send(
{
"type": "http.response.body",
"body": b"Hello, World",
}
)


async def on_fetch(request, env):
import asgi

return await asgi.fetch(app, request, env)


app = Server()


async def header_test(env):
example_hdr = {"Header1": "Value1", "Header2": "Value2"}
response = await env.SELF.fetch("http://example.com/", headers=to_js(example_hdr))
for header in response.headers:
assert isinstance(header[0], str) and isinstance(header[1], str)
expected_hdr = {k.lower(): v.lower() for k, v in example_hdr.items()}
assert header[0] in expected_hdr.keys()
assert expected_hdr[header[0]] == header[1].lower()


async def test(ctrl, env):
await header_test(env)

0 comments on commit e6768de

Please sign in to comment.