Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 13 additions & 14 deletions python/sglang/srt/entrypoints/http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@


import numpy as np
import orjson
import requests
import uvicorn
import uvloop
Expand Down Expand Up @@ -173,6 +172,11 @@
set_uvicorn_logging_configs,
)
from sglang.srt.utils.auth import AuthLevel, app_has_admin_force_endpoints, auth_level
from sglang.srt.utils.json_response import (
SGLangORJSONResponse,
dumps_json,
orjson_response,
)
from sglang.utils import get_exception_traceback
from sglang.version import __version__

Expand Down Expand Up @@ -666,7 +670,11 @@ async def _dumper_control_handler(method: str, request: Request):


# fastapi implicitly converts json in the request to obj (dataclass)
@app.api_route("/generate", methods=["POST", "PUT"])
@app.api_route(
"/generate",
methods=["POST", "PUT"],
response_class=SGLangORJSONResponse,
)
async def generate_request(obj: GenerateReqInput, request: Request):
"""Handle a generate request."""
if obj.stream:
Expand All @@ -676,15 +684,11 @@ async def stream_results() -> AsyncIterator[bytes]:
async for out in _global_state.tokenizer_manager.generate_request(
obj, request
):
yield b"data: " + orjson.dumps(
out, option=orjson.OPT_NON_STR_KEYS | orjson.OPT_SERIALIZE_NUMPY
) + b"\n\n"
yield b"data: " + dumps_json(out) + b"\n\n"
except ValueError as e:
out = {"error": {"message": str(e)}}
logger.error(f"[http_server] Error: {e}")
yield b"data: " + orjson.dumps(
out, option=orjson.OPT_NON_STR_KEYS | orjson.OPT_SERIALIZE_NUMPY
) + b"\n\n"
yield b"data: " + dumps_json(out) + b"\n\n"
yield b"data: [DONE]\n\n"

return StreamingResponse(
Expand All @@ -697,12 +701,7 @@ async def stream_results() -> AsyncIterator[bytes]:
ret = await _global_state.tokenizer_manager.generate_request(
obj, request
).__anext__()
return Response(
content=orjson.dumps(
ret, option=orjson.OPT_NON_STR_KEYS | orjson.OPT_SERIALIZE_NUMPY
),
media_type="application/json",
)
return orjson_response(ret)
except ValueError as e:
logger.error(f"[http_server] Error: {e}")
return _create_error_response(e)
Expand Down
32 changes: 32 additions & 0 deletions python/sglang/srt/utils/json_response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""Utilities for JSON serialization in HTTP responses."""

from typing import Any

import orjson
from fastapi.responses import ORJSONResponse, Response

# Keep response serialization behavior consistent across endpoints by sharing
# a single orjson option bitmask:
# - OPT_NON_STR_KEYS: support non-string dictionary keys used in some
#   metadata payloads.
# - OPT_SERIALIZE_NUMPY: support numpy scalars/arrays without pre-conversion.
ORJSON_RESPONSE_OPTIONS = orjson.OPT_NON_STR_KEYS | orjson.OPT_SERIALIZE_NUMPY


def dumps_json(content: Any) -> bytes:
    """Encode ``content`` as JSON bytes with orjson.

    Every call passes the module-level ``ORJSON_RESPONSE_OPTIONS`` flags, so
    all callers share one serialization configuration.

    Args:
        content: The object to serialize.

    Returns:
        The JSON document as ``bytes``.
    """
    encoded = orjson.dumps(content, option=ORJSON_RESPONSE_OPTIONS)
    return encoded


class SGLangORJSONResponse(ORJSONResponse):
    """``ORJSONResponse`` subclass whose body is produced by ``dumps_json``.

    Overriding ``render`` routes the response content through the same
    serialization helper used elsewhere in this module.
    """

    def render(self, content: Any) -> bytes:
        """Return the serialized body bytes for ``content``."""
        body = dumps_json(content)
        return body


def orjson_response(content: Any, status_code: int = 200) -> Response:
    """Build an ``application/json`` response from already-serialized content.

    Args:
        content: The object to serialize with ``dumps_json``.
        status_code: HTTP status code for the response (defaults to 200).

    Returns:
        A plain ``Response`` whose body is the serialized JSON bytes.
    """
    # Serialize up front so the Response receives ready-made bytes.
    body = dumps_json(content)
    return Response(
        status_code=status_code,
        media_type="application/json",
        content=body,
    )
55 changes: 55 additions & 0 deletions test/registered/unit/utils/test_json_response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import unittest

import numpy as np
import orjson

from sglang.srt.utils.json_response import (
SGLangORJSONResponse,
dumps_json,
orjson_response,
)
from sglang.test.ci.ci_register import register_cpu_ci

# Register this test module in the "stage-a-cpu-only" CI suite with an
# estimated time of 2. NOTE(review): est_time is presumably in seconds — confirm.
register_cpu_ci(est_time=2, suite="stage-a-cpu-only")


class TestJSONResponseUtils(unittest.TestCase):
    """Checks for the JSON serialization helpers in ``json_response``."""

    def test_dumps_json_maps_non_finite_values_to_null(self):
        """Infinities and NaN must decode back as ``None`` (JSON null)."""
        non_finite = {
            "neg_inf": float("-inf"),
            "pos_inf": float("inf"),
            "nan": float("nan"),
        }
        decoded = orjson.loads(dumps_json(non_finite))

        for key in ("neg_inf", "pos_inf", "nan"):
            self.assertIsNone(decoded[key])

    def test_dumps_json_supports_numpy_and_non_string_keys(self):
        """Numpy values and an integer dict key must serialize cleanly."""
        encoded = dumps_json(
            {
                1: np.array([1, 2, 3], dtype=np.int64),
                "scalar": np.float32(1.5),
            }
        )
        decoded = orjson.loads(encoded)

        # The integer key comes back as its string form.
        self.assertEqual(decoded["1"], [1, 2, 3])
        self.assertAlmostEqual(decoded["scalar"], 1.5)

    def test_orjson_response_uses_expected_media_type(self):
        """The helper must honor the status code and JSON media type."""
        resp = orjson_response({"value": float("-inf")}, status_code=201)
        decoded = orjson.loads(resp.body)

        self.assertEqual(resp.status_code, 201)
        self.assertEqual(resp.media_type, "application/json")
        self.assertIsNone(decoded["value"])

    def test_sglang_orjson_response_serializes_with_shared_options(self):
        """The response class must render non-finite floats as null."""
        resp = SGLangORJSONResponse(content={"value": float("-inf")})
        decoded = orjson.loads(resp.body)

        self.assertIsNone(decoded["value"])


# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
Loading