Skip to content

Commit

Permalink
WSGI request/response hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
owais committed Apr 8, 2021
1 parent ebfd098 commit cc6260b
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 32 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- `opentelemetry-instrumentation-urllib3` Add urllib3 instrumentation
([#299](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/299))
- `opentelemetry-instrumentation-wsgi` Replaced `name_callback` with `request_hook`
and `response_hook` callbacks.
([#424](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/424))

## [0.19b0](https://github.com/open-telemetry/opentelemetry-python-contrib/releases/tag/v0.19b0) - 2021-03-26

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,6 @@ def add_response_attributes(
span.set_status(Status(http_status_to_status_code(status_code)))


def get_default_span_name(environ):
"""Default implementation for name_callback, returns HTTP {METHOD_NAME}."""
return "HTTP {}".format(environ.get("REQUEST_METHOD", "")).strip()


class OpenTelemetryMiddleware:
"""The WSGI application middleware.
Expand All @@ -187,21 +182,26 @@ class OpenTelemetryMiddleware:
Args:
wsgi: The WSGI application callable to forward requests to.
name_callback: Callback which calculates a generic span name for an
incoming HTTP request based on the PEP3333 WSGI environ.
Optional: Defaults to get_default_span_name.
request_hook: Optional callback which is called with the server span and WSGI
environ object for every incoming request.
response_hook: Optional callback which is called with the server span,
WSGI environ, status_code and response_headers for every
incoming request.
"""

def __init__(self, wsgi, name_callback=get_default_span_name):
def __init__(self, wsgi, request_hook=None, response_hook=None):
self.wsgi = wsgi
self.tracer = trace.get_tracer(__name__, __version__)
self.name_callback = name_callback
self.request_hook = request_hook
self.response_hook = response_hook

@staticmethod
def _create_start_response(span, start_response):
def _create_start_response(span, start_response, response_hook):
@functools.wraps(start_response)
def _start_response(status, response_headers, *args, **kwargs):
add_response_attributes(span, status, response_headers)
if response_hook:
response_hook(status, response_headers)
return start_response(status, response_headers, *args, **kwargs)

return _start_response
Expand All @@ -215,18 +215,24 @@ def __call__(self, environ, start_response):
"""

token = context.attach(extract(environ, getter=wsgi_getter))
span_name = self.name_callback(environ)

span = self.tracer.start_span(
span_name,
"HTTP {}".format(environ.get("REQUEST_METHOD", "")).strip(),
kind=trace.SpanKind.SERVER,
attributes=collect_request_attributes(environ),
)

if self.request_hook:
self.request_hook(span, environ)

response_hook = self.response_hook
if response_hook:
response_hook = functools.partial(response_hook, span, environ)

try:
with trace.use_span(span):
start_response = self._create_start_response(
span, start_response
span, start_response, response_hook
)
iterable = self.wsgi(environ, start_response)
return _end_span_after_iterating(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,13 @@ def error_wsgi_unhandled(environ, start_response):

class TestWsgiApplication(WsgiTestBase):
def validate_response(
self, response, error=None, span_name="HTTP GET", http_method="GET"
self,
response,
error=None,
span_name="HTTP GET",
http_method="GET",
span_attributes=None,
response_headers=None,
):
while True:
try:
Expand All @@ -90,10 +96,12 @@ def validate_response(
except StopIteration:
break

expected_headers = [("Content-Type", "text/plain")]
if response_headers:
expected_headers.extend(response_headers)

self.assertEqual(self.status, "200 OK")
self.assertEqual(
self.response_headers, [("Content-Type", "text/plain")]
)
self.assertEqual(self.response_headers, expected_headers)
if error:
self.assertIs(self.exc_info[0], error)
self.assertIsInstance(self.exc_info[1], error)
Expand All @@ -115,6 +123,7 @@ def validate_response(
"http.status_text": "OK",
"http.status_code": 200,
}
expected_attributes.update(span_attributes or {})
if http_method is not None:
expected_attributes["http.method"] = http_method
self.assertEqual(span_list[0].attributes, expected_attributes)
Expand All @@ -124,6 +133,30 @@ def test_basic_wsgi_call(self):
response = app(self.environ, self.start_response)
self.validate_response(response)

def test_hooks(self):
hook_headers = (
"hook_attr",
"hello otel",
)

def request_hook(span, environ):
span.update_name("name from hook")

def response_hook(span, environ, status_code, response_headers):
span.set_attribute("hook_attr", "hello world")
response_headers.append(hook_headers)

app = otel_wsgi.OpenTelemetryMiddleware(
simple_wsgi, request_hook, response_hook
)
response = app(self.environ, self.start_response)
self.validate_response(
response,
span_name="name from hook",
span_attributes={"hook_attr": "hello world"},
response_headers=(hook_headers,),
)

def test_wsgi_not_recording(self):
mock_tracer = mock.Mock()
mock_span = mock.Mock()
Expand Down Expand Up @@ -177,20 +210,6 @@ def test_wsgi_internal_error(self):
span_list[0].status.status_code, StatusCode.ERROR,
)

def test_override_span_name(self):
"""Test that span_names can be overwritten by our callback function."""
span_name = "Dymaxion"

def get_predefined_span_name(scope):
# pylint: disable=unused-argument
return span_name

app = otel_wsgi.OpenTelemetryMiddleware(
simple_wsgi, name_callback=get_predefined_span_name
)
response = app(self.environ, self.start_response)
self.validate_response(response, span_name=span_name)

def test_default_span_name_missing_request_method(self):
"""Test that default span_names with missing request method."""
self.environ.pop("REQUEST_METHOD")
Expand Down

0 comments on commit cc6260b

Please sign in to comment.