Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions newrelic/api/asgi_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,20 @@ async def send_inject_browser_agent(self, message):

message_type = message["type"]
if message_type == "http.response.start" and not self.initial_message:
headers = list(message.get("headers", ()))
# message["headers"] may be a generator, and consuming it via process_response will leave the original
# application with no headers. Fix this by preserving them in a list before consuming them.
if "headers" in message:
message["headers"] = headers = list(message["headers"])
else:
headers = []

# Check if we should insert the HTML snippet based on the headers.
# Currently if there are no headers this will always be False, but call the function
# anyway in case this logic changes in the future.
if not self.should_insert_html(headers):
await self.abort()
return

message["headers"] = headers
self.initial_message = message
elif message_type == "http.response.body" and self.initial_message:
Expand Down Expand Up @@ -232,7 +242,13 @@ async def send(self, event):
finally:
self.__exit__(*sys.exc_info())
elif event["type"] == "http.response.start":
self.process_response(event["status"], event.get("headers", ()))
# event["headers"] may be a generator, and consuming it via process_response will leave the original
# ASGI application with no headers. Fix this by preserving them in a list before consuming them.
if "headers" in event:
event["headers"] = headers = list(event["headers"])
else:
headers = []
self.process_response(event["status"], headers)
return await self._send(event)


Expand Down
2 changes: 1 addition & 1 deletion newrelic/hooks/framework_sanic.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ async def _nr_sanic_response_send(wrapped, instance, args, kwargs):
transaction = current_transaction()
result = wrapped(*args, **kwargs)
if isawaitable(result):
await result
result = await result

if transaction is None:
return result
Expand Down
24 changes: 24 additions & 0 deletions tests/agent_features/test_asgi_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from testing_support.fixtures import override_application_settings
from testing_support.sample_asgi_applications import (
AppWithDescriptor,
asgi_application_generator_headers,
simple_app_v2,
simple_app_v2_init_exc,
simple_app_v2_raw,
Expand All @@ -37,6 +38,7 @@
simple_app_v3_wrapped = AsgiTest(simple_app_v3)
simple_app_v2_wrapped = AsgiTest(simple_app_v2)
simple_app_v2_init_exc = AsgiTest(simple_app_v2_init_exc)
asgi_application_generator_headers = AsgiTest(asgi_application_generator_headers)


# Test naming scheme logic and ASGIApplicationWrapper for a single callable
Expand Down Expand Up @@ -85,6 +87,28 @@ def test_double_callable_raw():
assert response.body == b""


# Ensure headers object is preserved
@pytest.mark.parametrize("browser_monitoring", [True, False])
@validate_transaction_metrics(name="", group="Uri")
def test_generator_headers(browser_monitoring):
"""
Both ASGIApplicationWrapper and ASGIBrowserMiddleware can cause headers to be lost if generators are
not handled properly.

Ensure neither destroys headers by testing with and without the ASGIBrowserMiddleware, to make sure whichever
receives headers first properly preserves them in a list.
"""

@override_application_settings({"browser_monitoring.enabled": browser_monitoring})
def _test():
response = asgi_application_generator_headers.make_request("GET", "/")
assert response.status == 200
assert response.headers == {"x-my-header": "myvalue"}
assert response.body == b""

_test()


# Test asgi_application decorator with parameters passed in on a single callable
@pytest.mark.parametrize("name, group", ((None, "group"), ("name", "group"), ("", "group")))
def test_asgi_application_decorator_single_callable(name, group):
Expand Down
2 changes: 1 addition & 1 deletion tests/testing_support/asgi_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def process_output(self):
if self.response_state is ResponseState.NOT_STARTED:
assert message["type"] == "http.response.start"
response_status = message["status"]
response_headers = message.get("headers", response_headers)
response_headers = list(message.get("headers", response_headers))
self.response_state = ResponseState.BODY
elif self.response_state is ResponseState.BODY:
assert message["type"] == "http.response.body"
Expand Down
17 changes: 17 additions & 0 deletions tests/testing_support/sample_asgi_applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,23 @@ async def normal_asgi_application(scope, receive, send):
await send({"type": "http.response.body", "body": output})


@ASGIApplicationWrapper
async def asgi_application_generator_headers(scope, receive, send):
if scope["type"] == "lifespan":
return await handle_lifespan(scope, receive, send)

if scope["type"] != "http":
raise ValueError("unsupported")

def headers():
yield (b"x-my-header", b"myvalue")

await send({"type": "http.response.start", "status": 200, "headers": headers()})
await send({"type": "http.response.body"})

assert current_transaction() is None


async def handle_lifespan(scope, receive, send):
"""Handle lifespan protocol with no-ops to allow more compatibility."""
while True:
Expand Down
Loading