diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index cf0a3784fe8c..5eb908d62bcb 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -122,6 +122,7 @@ from sglang.srt.utils import ( add_api_key_middleware, add_prometheus_middleware, + add_prometheus_track_response_middleware, delete_directory, get_bool_env_var, kill_process_tree, @@ -1397,6 +1398,9 @@ def launch_server( ) ) + if server_args.enable_metrics: + add_prometheus_track_response_middleware(app) + # Pass additional arguments to the lifespan function. # They will be used for additional initialization setups. if server_args.tokenizer_worker_num == 1: diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index cf7264225dc6..4723c555d753 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -1504,6 +1504,31 @@ def add_prometheus_middleware(app): app.routes.append(metrics_route) +def add_prometheus_track_response_middleware(app): + from prometheus_client import Counter + + http_response_status_counter = Counter( + name="sglang:http_responses_total", + documentation="Total number of HTTP responses by endpoint and status code", + labelnames=["endpoint", "status_code", "method"], + ) + + @app.middleware("http") + async def track_http_status_code(request, call_next): + response = await call_next(request) + + route = request.scope.get("route") + endpoint = route.path if route else "Unknown" + + http_response_status_counter.labels( + endpoint=endpoint, + status_code=str(response.status_code), + method=request.method, + ).inc() + + return response + + def bind_port(port): """Bind to a specific port, assuming it's available.""" sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)