Skip to content

Commit

Permalink
-
Browse files Browse the repository at this point in the history
  • Loading branch information
abersheeran committed Jun 26, 2024
1 parent 39c65db commit 5308867
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 24 deletions.
52 changes: 32 additions & 20 deletions a2wsgi/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,26 +152,28 @@ def __init__(
self.wait_time = wait_time

self.sync_event = SyncEvent()
self.async_event = AsyncEvent(loop)
self.async_lock: asyncio.Lock
self.sync_event_set_lock: asyncio.Lock

self.receive_event = AsyncEvent(loop)
self.send_event = AsyncEvent(loop)

def _init_async_lock():
self.async_lock = asyncio.Lock()
self.sync_event_set_lock = asyncio.Lock()

loop.call_soon_threadsafe(_init_async_lock)

self.asgi_done = threading.Event()
self.wsgi_should_stop: bool = False

async def asgi_receive(self) -> ReceiveEvent:
async with self.async_lock:
self.sync_event.set({"type": "receive"})
return await self.async_event.wait()
await self.sync_event_set_lock.acquire()
self.sync_event.set({"type": "receive"})
return await self.receive_event.wait()

async def asgi_send(self, message: SendEvent) -> None:
async with self.async_lock:
self.sync_event.set(message)
await self.async_event.wait()
await self.sync_event_set_lock.acquire()
self.sync_event.set(message)
await self.send_event.wait()

def asgi_done_callback(self, future: asyncio.Future) -> None:
try:
Expand Down Expand Up @@ -209,13 +211,16 @@ def __call__(
read_count: int = 0
body = environ["wsgi.input"] or BytesIO()
content_length = int(environ.get("CONTENT_LENGTH", None) or 0)
receive_eof = False
body_sent = False

asgi_task = self.start_asgi_app(environ)
# activate loop
self.loop.call_soon_threadsafe(lambda: None)

while True:
message = self.sync_event.wait()
self.loop.call_soon_threadsafe(self.sync_event_set_lock.release)
message_type = message["type"]

if message_type == "http.response.start":
Expand All @@ -230,13 +235,21 @@ def __call__(
],
None,
)
self.send_event.set(None)
elif message_type == "http.response.body":
yield message.get("body", b"")
body_sent = True
self.wsgi_should_stop = not message.get("more_body", False)
self.send_event.set(None)
elif message_type == "http.response.disconnect":
self.wsgi_should_stop = True
self.send_event.set(None)
# ASGI application error
elif message_type == "a2wsgi.error":
if body_sent:
raise message["exception"][1].with_traceback(
message["exception"][2]
)
start_response(
"500 Internal Server Error",
[
Expand All @@ -248,28 +261,27 @@ def __call__(
yield b"Server got itself in trouble"
self.wsgi_should_stop = True
elif message_type == "receive":
pass
else:
raise RuntimeError(f"Unknown message type: {message_type}")

if message_type == "receive":
read_size = min(65536, content_length - read_count)
if read_size == 0: # No more body, so don't read anymore
self.async_event.set(
{"type": "http.request", "body": b"", "more_body": False}
)
if not receive_eof:
self.receive_event.set(
{"type": "http.request", "body": b"", "more_body": False}
)
receive_eof = True
else:
pass # let `await receive()` wait
else:
data: bytes = body.read(read_size)
read_count += len(data)
more_body = read_count < content_length
self.async_event.set(
self.receive_event.set(
{"type": "http.request", "body": data, "more_body": more_body}
)
else:
self.async_event.set(None)
raise RuntimeError(f"Unknown message type: {message_type}")

if self.wsgi_should_stop:
self.async_event.set_nowait()
self.receive_event.set({"type": "http.disconnect"})
break

if asgi_task.done():
Expand Down
26 changes: 22 additions & 4 deletions tests/test_asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,11 @@ async def hello_world(scope, receive, send):
"status": 200,
"headers": [
[b"content-type", b"text/plain"],
[b"content-length", b"13"],
],
}
)
await send(
{"type": "http.response.body", "body": b"Hello, world!", "more_body": True}
)
await send({"type": "http.response.disconnect"})
await send({"type": "http.response.body", "body": b"Hello, world!"})


async def echo_body(scope, receive, send):
Expand Down Expand Up @@ -206,3 +204,23 @@ def test_starlette_stream_response():
response = client.get("/")
assert response.status_code == 200
assert response.text == "0123456789"


def test_starlette_base_http_middleware():
from starlette.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware

class Middleware(BaseHTTPMiddleware):
async def dispatch(self, request, call_next):
response = await call_next(request)
response.headers["x-middleware"] = "true"
return response

app = ASGIMiddleware(Middleware(JSONResponse({"hello": "world"})))
with httpx.Client(
transport=httpx.WSGITransport(app=app), base_url="http://testserver:80"
) as client:
response = client.get("/")
assert response.status_code == 200
assert response.text == '{"hello":"world"}'
assert response.headers["x-middleware"] == "true"

0 comments on commit 5308867

Please sign in to comment.