diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index ae81acb7f4f4..acf60ccfe7b6 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -42,7 +42,6 @@ import numpy as np -import orjson import requests import uvicorn import uvloop @@ -173,6 +172,11 @@ set_uvicorn_logging_configs, ) from sglang.srt.utils.auth import AuthLevel, app_has_admin_force_endpoints, auth_level +from sglang.srt.utils.json_response import ( + SGLangORJSONResponse, + dumps_json, + orjson_response, +) from sglang.utils import get_exception_traceback from sglang.version import __version__ @@ -666,7 +670,11 @@ async def _dumper_control_handler(method: str, request: Request): # fastapi implicitly converts json in the request to obj (dataclass) -@app.api_route("/generate", methods=["POST", "PUT"]) +@app.api_route( + "/generate", + methods=["POST", "PUT"], + response_class=SGLangORJSONResponse, +) async def generate_request(obj: GenerateReqInput, request: Request): """Handle a generate request.""" if obj.stream: @@ -676,15 +684,11 @@ async def stream_results() -> AsyncIterator[bytes]: async for out in _global_state.tokenizer_manager.generate_request( obj, request ): - yield b"data: " + orjson.dumps( - out, option=orjson.OPT_NON_STR_KEYS | orjson.OPT_SERIALIZE_NUMPY - ) + b"\n\n" + yield b"data: " + dumps_json(out) + b"\n\n" except ValueError as e: out = {"error": {"message": str(e)}} logger.error(f"[http_server] Error: {e}") - yield b"data: " + orjson.dumps( - out, option=orjson.OPT_NON_STR_KEYS | orjson.OPT_SERIALIZE_NUMPY - ) + b"\n\n" + yield b"data: " + dumps_json(out) + b"\n\n" yield b"data: [DONE]\n\n" return StreamingResponse( @@ -697,12 +701,7 @@ async def stream_results() -> AsyncIterator[bytes]: ret = await _global_state.tokenizer_manager.generate_request( obj, request ).__anext__() - return Response( - content=orjson.dumps( - ret, option=orjson.OPT_NON_STR_KEYS | orjson.OPT_SERIALIZE_NUMPY - ), - media_type="application/json", - ) + return orjson_response(ret) except ValueError as e: logger.error(f"[http_server] Error: {e}") return _create_error_response(e) diff --git a/python/sglang/srt/utils/json_response.py b/python/sglang/srt/utils/json_response.py new file mode 100644 index 000000000000..03130c61a4e8 --- /dev/null +++ b/python/sglang/srt/utils/json_response.py @@ -0,0 +1,32 @@ +"""Utilities for JSON serialization in HTTP responses.""" + +from typing import Any + +import orjson +from fastapi.responses import ORJSONResponse, Response + +# Keep response serialization behavior consistent across endpoints: +# - Support non-string dictionary keys used in some metadata payloads. +# - Support numpy scalars/arrays without pre-conversion. +ORJSON_RESPONSE_OPTIONS = orjson.OPT_NON_STR_KEYS | orjson.OPT_SERIALIZE_NUMPY + + +def dumps_json(content: Any) -> bytes: + """Serialize content to JSON bytes using SGLang's ORJSON options.""" + return orjson.dumps(content, option=ORJSON_RESPONSE_OPTIONS) + + +class SGLangORJSONResponse(ORJSONResponse): + """ORJSON response with SGLang-specific serialization options.""" + + def render(self, content: Any) -> bytes: + return dumps_json(content) + + +def orjson_response(content: Any, status_code: int = 200) -> Response: + """Create a JSON response with stable ORJSON serialization options.""" + return Response( + content=dumps_json(content), + media_type="application/json", + status_code=status_code, + ) diff --git a/test/registered/unit/utils/test_json_response.py b/test/registered/unit/utils/test_json_response.py new file mode 100644 index 000000000000..20fbc8108fca --- /dev/null +++ b/test/registered/unit/utils/test_json_response.py @@ -0,0 +1,55 @@ +import unittest + +import numpy as np +import orjson + +from sglang.srt.utils.json_response import ( + SGLangORJSONResponse, + dumps_json, + orjson_response, +) +from sglang.test.ci.ci_register import register_cpu_ci + +register_cpu_ci(est_time=2, suite="stage-a-cpu-only") + + +class TestJSONResponseUtils(unittest.TestCase): + def test_dumps_json_maps_non_finite_values_to_null(self): + payload = { + "neg_inf": float("-inf"), + "pos_inf": float("inf"), + "nan": float("nan"), + } + parsed = orjson.loads(dumps_json(payload)) + + self.assertIsNone(parsed["neg_inf"]) + self.assertIsNone(parsed["pos_inf"]) + self.assertIsNone(parsed["nan"]) + + def test_dumps_json_supports_numpy_and_non_string_keys(self): + payload = { + 1: np.array([1, 2, 3], dtype=np.int64), + "scalar": np.float32(1.5), + } + parsed = orjson.loads(dumps_json(payload)) + + self.assertEqual(parsed["1"], [1, 2, 3]) + self.assertAlmostEqual(parsed["scalar"], 1.5) + + def test_orjson_response_uses_expected_media_type(self): + response = orjson_response({"value": float("-inf")}, status_code=201) + parsed = orjson.loads(response.body) + + self.assertEqual(response.status_code, 201) + self.assertEqual(response.media_type, "application/json") + self.assertIsNone(parsed["value"]) + + def test_sglang_orjson_response_serializes_with_shared_options(self): + response = SGLangORJSONResponse(content={"value": float("-inf")}) + parsed = orjson.loads(response.body) + + self.assertIsNone(parsed["value"]) + + +if __name__ == "__main__": + unittest.main()