11 changes: 11 additions & 0 deletions python/sglang/srt/entrypoints/openai/api_server.py
@@ -192,6 +192,17 @@ async def v1_score_request(raw_request: Request):
pass


@app.api_route("/v1/models/{model_id}", methods=["GET"])
async def show_model_detail(model_id: str):
served_model_name = app.state.tokenizer_manager.served_model_name

return ModelCard(
id=served_model_name,
root=served_model_name,
max_model_len=app.state.tokenizer_manager.model_config.context_len,
)


# Additional API endpoints will be implemented in separate serving_*.py modules
# and mounted as APIRouters in future PRs
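
For reference, a minimal client-side sketch of exercising the /v1/models/{model_id} route added above (assumes a server already running locally; the base URL and model id below are placeholders, not part of this diff):

import requests

# GET /v1/models/{model_id}: the path parameter is accepted, but the handler
# above always reports the single served model and its context length.
resp = requests.get("http://localhost:30000/v1/models/any-model-id")
resp.raise_for_status()
card = resp.json()
print(card["id"], card.get("max_model_len"))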

27 changes: 0 additions & 27 deletions python/sglang/srt/entrypoints/openai/serving_base.py
@@ -114,33 +114,6 @@ def _validate_request(self, request: OpenAIServingRequest) -> Optional[str]:
"""Validate request"""
pass

def _calculate_streaming_usage_base(
self,
prompt_tokens: Dict[int, int],
completion_tokens: Dict[int, int],
cached_tokens: Dict[int, int],
n_choices: int,
) -> UsageInfo:
"""Calculate usage information for streaming responses (common logic)"""
total_prompt_tokens = sum(
tokens for i, tokens in prompt_tokens.items() if i % n_choices == 0
)
total_completion_tokens = sum(tokens for tokens in completion_tokens.values())

cache_report = self.tokenizer_manager.server_args.enable_cache_report
prompt_tokens_details = None
if cache_report:
cached_tokens_sum = sum(tokens for tokens in cached_tokens.values())
if cached_tokens_sum > 0:
prompt_tokens_details = {"cached_tokens": cached_tokens_sum}

return UsageInfo(
prompt_tokens=total_prompt_tokens,
completion_tokens=total_completion_tokens,
total_tokens=total_prompt_tokens + total_completion_tokens,
prompt_tokens_details=prompt_tokens_details,
)

def create_error_response(
self,
message: str,
11 changes: 7 additions & 4 deletions python/sglang/srt/entrypoints/openai/serving_chat.py
@@ -26,8 +26,8 @@
TopLogprob,
)
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor
from sglang.srt.entrypoints.openai.utils import (
aggregate_token_usage,
detect_template_content_format,
process_content_for_template_format,
to_openai_style_logprobs,
@@ -546,11 +546,12 @@ async def _generate_chat_stream(

# Additional usage chunk
if request.stream_options and request.stream_options.include_usage:
usage = self._calculate_streaming_usage_base(
usage = UsageProcessor.calculate_streaming_usage(
prompt_tokens,
completion_tokens,
cached_tokens,
request.n,
n_choices=request.n,
enable_cache_report=self.tokenizer_manager.server_args.enable_cache_report,
)
usage_chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
@@ -658,7 +659,9 @@ def _build_chat_response(

# Calculate usage
cache_report = self.tokenizer_manager.server_args.enable_cache_report
usage = aggregate_token_usage(ret, request.n, cache_report)
usage = UsageProcessor.calculate_response_usage(
ret, n_choices=request.n, enable_cache_report=cache_report
)

return ChatCompletionResponse(
id=ret[0]["meta_info"]["id"],
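
The per-index accumulators passed to UsageProcessor.calculate_streaming_usage above are built inside the streaming loop, which is collapsed in this diff. A rough sketch of their expected shape (names mirror the call site; keying by output index is inferred from the n_choices logic in usage_processor.py, and the concrete values are invented):

# Hypothetical accumulation step for one streamed chunk.
prompt_tokens: dict[int, int] = {}
completion_tokens: dict[int, int] = {}
cached_tokens: dict[int, int] = {}

index = 0                      # choice/output index of this chunk
meta = {"prompt_tokens": 10, "completion_tokens": 5, "cached_tokens": 3}
prompt_tokens[index] = meta["prompt_tokens"]
completion_tokens[index] = meta["completion_tokens"]
cached_tokens[index] = meta.get("cached_tokens", 0)
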
15 changes: 8 additions & 7 deletions python/sglang/srt/entrypoints/openai/serving_completions.py
@@ -18,10 +18,8 @@
ErrorResponse,
)
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
from sglang.srt.entrypoints.openai.utils import (
aggregate_token_usage,
to_openai_style_logprobs,
)
from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor
from sglang.srt.entrypoints.openai.utils import to_openai_style_logprobs
from sglang.srt.managers.io_struct import GenerateReqInput

logger = logging.getLogger(__name__)
@@ -214,11 +212,12 @@ async def _generate_completion_stream(

# Handle final usage chunk
if request.stream_options and request.stream_options.include_usage:
usage = self._calculate_streaming_usage_base(
usage = UsageProcessor.calculate_streaming_usage(
prompt_tokens,
completion_tokens,
cached_tokens,
request.n,
n_choices=request.n,
enable_cache_report=self.tokenizer_manager.server_args.enable_cache_report,
)
final_usage_chunk = CompletionStreamResponse(
id=content["meta_info"]["id"],
@@ -322,7 +321,9 @@ def _build_completion_response(

# Calculate usage
cache_report = self.tokenizer_manager.server_args.enable_cache_report
usage = aggregate_token_usage(ret, request.n, cache_report)
usage = UsageProcessor.calculate_response_usage(
ret, n_choices=request.n, enable_cache_report=cache_report
)

return CompletionResponse(
id=ret[0]["meta_info"]["id"],
81 changes: 81 additions & 0 deletions python/sglang/srt/entrypoints/openai/usage_processor.py
@@ -0,0 +1,81 @@
from __future__ import annotations

from typing import Any, Dict, List, Mapping, Optional, final

from sglang.srt.entrypoints.openai.protocol import UsageInfo


@final
class UsageProcessor:
"""Stateless helpers that turn raw token counts into a UsageInfo."""

@staticmethod
def _details_if_cached(count: int) -> Optional[Dict[str, int]]:
"""Return {"cached_tokens": N} only when N > 0 (keeps JSON slim)."""
return {"cached_tokens": count} if count > 0 else None

@staticmethod
def calculate_response_usage(
responses: List[Dict[str, Any]],
n_choices: int = 1,
enable_cache_report: bool = False,
) -> UsageInfo:
completion_tokens = sum(r["meta_info"]["completion_tokens"] for r in responses)

prompt_tokens = sum(
responses[i]["meta_info"]["prompt_tokens"]
for i in range(0, len(responses), n_choices)
)

cached_details = None
if enable_cache_report:
cached_total = sum(
r["meta_info"].get("cached_tokens", 0) for r in responses
)
cached_details = UsageProcessor._details_if_cached(cached_total)

return UsageProcessor.calculate_token_usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
cached_tokens=cached_details,
)

@staticmethod
def calculate_streaming_usage(
prompt_tokens: Mapping[int, int],
completion_tokens: Mapping[int, int],
cached_tokens: Mapping[int, int],
n_choices: int,
enable_cache_report: bool = False,
) -> UsageInfo:
# index % n_choices == 0 marks the first choice of a prompt
total_prompt_tokens = sum(
tok for idx, tok in prompt_tokens.items() if idx % n_choices == 0
)
total_completion_tokens = sum(completion_tokens.values())

cached_details = (
UsageProcessor._details_if_cached(sum(cached_tokens.values()))
if enable_cache_report
else None
)

return UsageProcessor.calculate_token_usage(
prompt_tokens=total_prompt_tokens,
completion_tokens=total_completion_tokens,
cached_tokens=cached_details,
)

@staticmethod
def calculate_token_usage(
prompt_tokens: int,
completion_tokens: int,
cached_tokens: Optional[Dict[str, int]] = None,
) -> UsageInfo:
"""Calculate token usage information"""
return UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
prompt_tokens_details=cached_tokens,
)
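
As a quick sanity check of the two aggregation paths, here is a small sketch using made-up token counts (it assumes the package is importable as sglang, matching the imports elsewhere in this PR):

from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor

# Streaming path: two prompts, n=2 choices each, so four output indices.
# Only indices 0 and 2 (the first choice of each prompt) contribute prompt
# tokens, avoiding double counting; completion tokens come from every choice.
prompt_tokens = {0: 10, 1: 10, 2: 12, 3: 12}
completion_tokens = {0: 5, 1: 7, 2: 4, 3: 6}
cached_tokens = {0: 3, 1: 3, 2: 0, 3: 0}

usage = UsageProcessor.calculate_streaming_usage(
    prompt_tokens,
    completion_tokens,
    cached_tokens,
    n_choices=2,
    enable_cache_report=True,
)
# usage.prompt_tokens == 22, usage.completion_tokens == 22,
# usage.total_tokens == 44, usage.prompt_tokens_details == {"cached_tokens": 6}

# Non-streaming path: one dict per generated choice, in prompt order.
ret = [
    {"meta_info": {"prompt_tokens": 10, "completion_tokens": 5, "cached_tokens": 3}},
    {"meta_info": {"prompt_tokens": 10, "completion_tokens": 7, "cached_tokens": 3}},
]
usage = UsageProcessor.calculate_response_usage(
    ret, n_choices=2, enable_cache_report=True
)
# usage.prompt_tokens == 10, usage.completion_tokens == 12, usage.total_tokens == 22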
59 changes: 1 addition & 58 deletions python/sglang/srt/entrypoints/openai/utils.py
@@ -1,10 +1,9 @@
import logging
from typing import Any, Dict, List, Optional

import jinja2.nodes
import transformers.utils.chat_template_utils as hf_chat_utils

from sglang.srt.entrypoints.openai.protocol import LogProbs, UsageInfo
from sglang.srt.entrypoints.openai.protocol import LogProbs

logger = logging.getLogger(__name__)

@@ -171,62 +170,6 @@ def process_content_for_template_format(
return new_msg


def calculate_token_usage(
prompt_tokens: int,
completion_tokens: int,
cached_tokens: Optional[Dict[str, int]] = None,
) -> UsageInfo:
"""Calculate token usage information"""
return UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
prompt_tokens_details=cached_tokens,
)


def aggregate_token_usage(
responses: List[Dict[str, Any]],
n_choices: int = 1,
enable_cache_report: bool = False,
) -> UsageInfo:
"""Aggregate token usage from multiple responses

Args:
responses: List of response dictionaries with meta_info
n_choices: Number of choices per request (for prompt token counting)
enable_cache_report: Whether to include cached token details

Returns:
Aggregated UsageInfo
"""
# Sum completion tokens from all responses
completion_tokens = sum(
response["meta_info"]["completion_tokens"] for response in responses
)

# For prompt tokens, only count every n_choices-th response to avoid double counting
prompt_tokens = sum(
responses[i]["meta_info"]["prompt_tokens"]
for i in range(0, len(responses), n_choices)
)

# Handle cached tokens if cache reporting is enabled
cached_tokens_details = None
if enable_cache_report:
cached_tokens_sum = sum(
response["meta_info"].get("cached_tokens", 0) for response in responses
)
if cached_tokens_sum > 0:
cached_tokens_details = {"cached_tokens": cached_tokens_sum}

return calculate_token_usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
cached_tokens=cached_tokens_details,
)


def to_openai_style_logprobs(
input_token_logprobs=None,
output_token_logprobs=None,