From b5dc70d1f5279e759f81004755342f493764ea31 Mon Sep 17 00:00:00 2001 From: Martynov Maxim Date: Sat, 23 Dec 2023 20:37:34 +0300 Subject: [PATCH] Fix handling root_path --- starlette_exporter/middleware.py | 144 ++++++++++---------- tests/test_middleware.py | 217 +++++++++++++++++++++++++------ 2 files changed, 249 insertions(+), 112 deletions(-) diff --git a/starlette_exporter/middleware.py b/starlette_exporter/middleware.py index 61cc109..aa7afeb 100644 --- a/starlette_exporter/middleware.py +++ b/starlette_exporter/middleware.py @@ -1,26 +1,26 @@ """ Middleware for exporting Prometheus metrics using Starlette """ -from collections import OrderedDict -import time import logging +import time import warnings +from collections import OrderedDict from inspect import iscoroutine from typing import ( Any, Callable, + ClassVar, + Dict, List, Mapping, Optional, - ClassVar, - Dict, - Union, Sequence, + Union, ) -from prometheus_client import Counter, Histogram, Gauge +from prometheus_client import Counter, Gauge, Histogram from prometheus_client.metrics import MetricWrapperBase from starlette.requests import Request from starlette.routing import BaseRoute, Match -from starlette.types import ASGIApp, Message, Receive, Send, Scope +from starlette.types import ASGIApp, Message, Receive, Scope, Send from . import optional_metrics @@ -324,73 +324,78 @@ async def wrapped_send(message: Message) -> None: await send(message) + exception: Optional[Exception] = None try: await self.app(scope, receive, wrapped_send) - except Exception: + except Exception as e: status_code = 500 - raise - finally: - # Decrement 'requests_in_progress' gauge after response sent - self.requests_in_progress.labels( - method, self.app_name, *default_labels - ).dec() - - if self.filter_unhandled_paths or self.group_paths: - grouped_path = self._get_router_path(scope) - - # filter_unhandled_paths removes any requests without mapped endpoint from the metrics. - if self.filter_unhandled_paths and grouped_path is None: - return - - # group_paths enables returning the original router path (with url param names) - # for example, when using this option, requests to /api/product/1 and /api/product/3 - # will both be grouped under /api/product/{product_id}. See the README for more info. - if self.group_paths and grouped_path is not None: - path = grouped_path - - if status_code is None: - request = Request(scope, receive) - if await request.is_disconnected(): - # In case no response was returned and the client is disconnected, 499 is reported as status code. - status_code = 499 - else: - status_code = 500 - - labels = [method, path, status_code, self.app_name, *default_labels] - - # optional extra arguments to be passed as kwargs to observations - # note: only used for histogram observations and counters to support exemplars - extra = {} - if self.exemplars: - extra["exemplar"] = self.exemplars() - - # optional response body size metric - if ( - self.optional_metrics_list is not None - and optional_metrics.response_body_size in self.optional_metrics_list - and self.response_body_size_count is not None - ): - self.response_body_size_count.labels(*labels).inc( - response_body_size, **extra - ) + exception = e + + # Decrement 'requests_in_progress' gauge after response sent + self.requests_in_progress.labels( + method, self.app_name, *default_labels + ).dec() + + if self.filter_unhandled_paths or self.group_paths: + grouped_path = self._get_router_path(scope) + + # filter_unhandled_paths removes any requests without mapped endpoint from the metrics. + if self.filter_unhandled_paths and grouped_path is None: + if exception: + raise exception + return + + # group_paths enables returning the original router path (with url param names) + # for example, when using this option, requests to /api/product/1 and /api/product/3 + # will both be grouped under /api/product/{product_id}. See the README for more info. + if self.group_paths and grouped_path is not None: + path = grouped_path + + if status_code is None: + if await request.is_disconnected(): + # In case no response was returned and the client is disconnected, 499 is reported as status code. + status_code = 499 + else: + status_code = 500 + + labels = [method, path, status_code, self.app_name, *default_labels] + + # optional extra arguments to be passed as kwargs to observations + # note: only used for histogram observations and counters to support exemplars + extra = {} + if self.exemplars: + extra["exemplar"] = self.exemplars() + + # optional response body size metric + if ( + self.optional_metrics_list is not None + and optional_metrics.response_body_size in self.optional_metrics_list + and self.response_body_size_count is not None + ): + self.response_body_size_count.labels(*labels).inc( + response_body_size, **extra + ) - # optional request body size metric - if ( - self.optional_metrics_list is not None - and optional_metrics.request_body_size in self.optional_metrics_list - and self.request_body_size_count is not None - ): - self.request_body_size_count.labels(*labels).inc( - request_body_size, **extra - ) + # optional request body size metric + if ( + self.optional_metrics_list is not None + and optional_metrics.request_body_size in self.optional_metrics_list + and self.request_body_size_count is not None + ): + self.request_body_size_count.labels(*labels).inc( + request_body_size, **extra + ) - # if we were not able to set end when the response body was written, - # set it now. - if end is None: - end = time.perf_counter() + # if we were not able to set end when the response body was written, + # set it now. + if end is None: + end = time.perf_counter() + + self.request_count.labels(*labels).inc(**extra) + self.request_time.labels(*labels).observe(end - begin, **extra) - self.request_count.labels(*labels).inc(**extra) - self.request_time.labels(*labels).observe(end - begin, **extra) + if exception: + raise exception @staticmethod def _get_router_path(scope: Scope) -> Optional[str]: @@ -403,10 +408,11 @@ def _get_router_path(scope: Scope) -> Optional[str]: if hasattr(app, "root_path"): app_root_path = getattr(app, "root_path") - if root_path.startswith(app_root_path): + if app_root_path and root_path.startswith(app_root_path): root_path = root_path[len(app_root_path) :] base_scope = { + "root_path": root_path, "type": scope.get("type"), "path": root_path + scope.get("path", ""), "path_params": scope.get("path_params", {}), diff --git a/tests/test_middleware.py b/tests/test_middleware.py index 52b78f3..539c480 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -3,6 +3,7 @@ import pytest from prometheus_client import REGISTRY +from starlette import __version__ as starlette_version from starlette.applications import Starlette from starlette.background import BackgroundTask from starlette.exceptions import HTTPException @@ -14,11 +15,11 @@ import starlette_exporter from starlette_exporter import ( PrometheusMiddleware, - handle_metrics, from_header, + handle_metrics, handle_openmetrics, ) -from starlette_exporter.optional_metrics import response_body_size, request_body_size +from starlette_exporter.optional_metrics import request_body_size, response_body_size @pytest.fixture @@ -138,19 +139,101 @@ def test_500(self, client): in metrics ) - def test_unhandled(self, client): - """test that an unhandled exception still gets logged in the requests counter""" + def test_404(self, client): + """test that an unknown path is handled properly""" + client.get("/404") + metrics = client.get("/metrics").content.decode() + + assert ( + """starlette_requests_total{app_name="starlette",method="GET",path="/404",status_code="404"} 1.0""" + in metrics + ) + + def test_404_filter(self, testapp): + """test that a unknown path can be excluded from metrics""" + client = TestClient(testapp(filter_unhandled_paths=True)) + try: - client.get("/unhandled") + client.get("/404") except: pass metrics = client.get("/metrics").content.decode() + assert "404" not in metrics + + def test_unhandled(self, client): + """test that an unhandled exception still gets logged in the requests counter""" + + with pytest.raises(KeyError, match="value_error"): + client.get("/unhandled") + + metrics = client.get("/metrics").content.decode() + assert ( """starlette_requests_total{app_name="starlette",method="GET",path="/unhandled",status_code="500"} 1.0""" in metrics ) + def test_ungrouped_paths(self, client): + """test that an endpoints parameters with group_paths=False are added to metrics""" + + client.get("/200/111") + client.get("/500/1111") + client.get("/404/11111") + + with pytest.raises(KeyError, match="value_error"): + client.get("/unhandled/123") + + metrics = client.get("/metrics").content.decode() + + assert ( + """starlette_requests_total{app_name="starlette",method="GET",path="/200/111",status_code="200"} 1.0""" + in metrics + ) + assert ( + """starlette_requests_total{app_name="starlette",method="GET",path="/500/1111",status_code="500"} 1.0""" + in metrics + ) + assert ( + """starlette_requests_total{app_name="starlette",method="GET",path="/404/11111",status_code="404"} 1.0""" + in metrics + ) + assert ( + """starlette_requests_total{app_name="starlette",method="GET",path="/unhandled/123",status_code="500"} 1.0""" + in metrics + ) + + def test_custom_root_path(self, testapp): + """test that an unhandled exception still gets logged in the requests counter""" + + client = TestClient(testapp(), root_path="/api") + + client.get("/200") + client.get("/500") + client.get("/404") + + with pytest.raises(KeyError, match="value_error"): + client.get("/unhandled") + + metrics = client.get("/metrics").content.decode() + + assert ( + """starlette_requests_total{app_name="starlette",method="GET",path="/api/200",status_code="200"} 1.0""" + in metrics + ) + assert ( + """starlette_requests_total{app_name="starlette",method="GET",path="/api/500",status_code="500"} 1.0""" + in metrics + ) + assert ( + """starlette_requests_total{app_name="starlette",method="GET",path="/api/404",status_code="404"} 1.0""" + in metrics + ) + assert ( + """starlette_requests_total{app_name="starlette",method="GET",path="/api/unhandled",status_code="500"} 1.0""" + in metrics + ) + def test_histogram(self, client): """test that histogram buckets appear after making requests""" @@ -214,14 +297,6 @@ def test_app_name(self, testapp): in metrics ) - def test_filter_unhandled_paths(self, testapp): - """test that app_name label is populated correctly""" - client = TestClient(testapp(filter_unhandled_paths=True)) - - client.get("/this_path_does_not_exist") - metrics = client.get("/metrics").content.decode() - assert "this_path_does_not_exist" not in metrics - def test_mounted_path(self, testapp): """test that mounted paths appear even when filter_unhandled_paths is True""" client = TestClient(testapp(filter_unhandled_paths=True)) @@ -244,23 +319,23 @@ def test_mounted_path_with_param(self, testapp): in metrics ) - def test_mounted_path_unhandled(self, testapp): + def test_mounted_path_404(self, client): """test an unhandled path that will be partially matched at the mounted base path""" - client = TestClient(testapp(filter_unhandled_paths=True)) - client.get("/mounted/unhandled/123") + client.get("/mounted/404") metrics = client.get("/metrics").content.decode() - assert """path="/mounted/unhandled""" not in metrics - assert """path="/mounted""" not in metrics + assert ( + """starlette_requests_total{app_name="starlette",method="GET",path="/mounted/404",status_code="404"} 1.0""" + in metrics + ) - def test_mounted_path_unhandled_grouped(self, testapp): - """test an unhandled path that will be partially matched at the mounted base path (grouped paths)""" - client = TestClient(testapp(filter_unhandled_paths=True, group_paths=True)) - client.get("/mounted/unhandled/123") + def test_mounted_path_404_filter(self, testapp): + """test an unhandled path from mounted base path can be excluded from metrics""" + client = TestClient(testapp(filter_unhandled_paths=True)) + client.get("/mounted/404") metrics = client.get("/metrics").content.decode() - assert """path="/mounted/unhandled""" not in metrics - assert """path="/mounted""" not in metrics + assert "/mounted" not in metrics def test_staticfiles_path(self, testapp): """test a static file path""" @@ -394,48 +469,104 @@ def test_500(self, client): in metrics ) - def test_unhandled(self, client): - """test that an unhandled exception still gets logged in the requests counter""" + def test_404(self, client): + """test that a 404 is handled properly, even though the path won't be grouped""" try: - client.get("/unhandled/11111") + client.get("/404/11111") except: pass metrics = client.get("/metrics").content.decode() + assert ( + """starlette_requests_total{app_name="starlette",method="GET",path="/404/11111",status_code="404"} 1.0""" + in metrics + ) + + def test_unhandled(self, client): + """test that an unhandled exception still gets logged in the requests counter (grouped paths)""" + + with pytest.raises(KeyError, match="value_error"): + client.get("/unhandled/123") + + metrics = client.get("/metrics").content.decode() + + assert ( + """starlette_requests_total{app_name="starlette",method="GET",path="/unhandled/{test_param}",status_code="500"} 1.0""" + in metrics + ) + + def test_custom_root_path(self, testapp): + """test that custom root_path does not affect the path grouping""" + + client = TestClient(testapp(group_paths=True, filter_unhandled_paths=True), root_path="/api") + + client.get("/200/111") + client.get("/500/1111") + client.get("/404/123") + + with pytest.raises(KeyError, match="value_error"): + client.get("/unhandled/123") + + metrics = client.get("/metrics").content.decode() + + starlette_version_tuple = tuple(map(int, starlette_version.split("."))) + if starlette_version_tuple < (0, 33): + # These asserts are valid only on Starlette 0.33+ + # See https://github.com/encode/starlette/pull/2352" + return + + assert ( + """starlette_requests_total{app_name="starlette",method="GET",path="/200/{test_param}",status_code="200"} 1.0""" + in metrics + ) + assert ( + """starlette_requests_total{app_name="starlette",method="GET",path="/500/{test_param}",status_code="500"} 1.0""" + in metrics + ) assert ( """starlette_requests_total{app_name="starlette",method="GET",path="/unhandled/{test_param}",status_code="500"} 1.0""" in metrics ) + assert "404" not in metrics + + def test_mounted_path_404(self, testapp): + """test an unhandled path that will be partially matched at the mounted base path (grouped paths)""" + client = TestClient(testapp(group_paths=True)) + client.get("/mounted/404") + metrics = client.get("/metrics").content.decode() + + assert ( + """starlette_requests_total{app_name="starlette",method="GET",path="/mounted/404",status_code="404"} 1.0""" + in metrics + ) + + def test_mounted_path_404_filter(self, testapp): + """test an unhandled path from mounted base path can be excluded from metrics (grouped paths)""" + client = TestClient(testapp(group_paths=True, filter_unhandled_paths=True)) + client.get("/mounted/404") + metrics = client.get("/metrics").content.decode() + + assert "/mounted" not in metrics def test_staticfiles_path(self, testapp): """test a static file path, with group_paths=True""" client = TestClient(testapp(filter_unhandled_paths=True, group_paths=True)) client.get("/static/test.txt") metrics = client.get("/metrics").content.decode() - assert 'path="/static"' in metrics - - def test_404(self, client): - """test that a 404 is handled properly, even though the path won't be matched""" - try: - client.get("/not_found/11111") - except: - pass - metrics = client.get("/metrics").content.decode() assert ( - """starlette_requests_total{app_name="starlette",method="GET",path="/not_found/11111",status_code="404"} 1.0""" + """starlette_requests_total{app_name="starlette",method="GET",path="/static",status_code="200"} 1.0""" in metrics ) def test_histogram(self, client): """test that histogram buckets appear after making requests""" - client.get("/200/1") - client.get("/500/12") - try: - client.get("/unhandled/111") - except: - pass + client.get("/200/111") + client.get("/500/1111") + + with pytest.raises(KeyError, match="value_error"): + client.get("/unhandled/123") metrics = client.get("/metrics").content.decode()