diff --git a/python/sglang/srt/entrypoints/openai/api_server.py b/python/sglang/srt/entrypoints/openai/api_server.py
index b575275aec2..a3164339563 100644
--- a/python/sglang/srt/entrypoints/openai/api_server.py
+++ b/python/sglang/srt/entrypoints/openai/api_server.py
@@ -192,6 +192,17 @@ async def v1_score_request(raw_request: Request):
     pass
 
 
+@app.api_route("/v1/models/{model_id}", methods=["GET"])
+async def show_model_detail(model_id: str):
+    served_model_name = app.state.tokenizer_manager.served_model_name
+
+    return ModelCard(
+        id=served_model_name,
+        root=served_model_name,
+        max_model_len=app.state.tokenizer_manager.model_config.context_len,
+    )
+
+
 # Additional API endpoints will be implemented in separate serving_*.py modules
 # and mounted as APIRouters in future PRs
 
diff --git a/python/sglang/srt/entrypoints/openai/serving_base.py b/python/sglang/srt/entrypoints/openai/serving_base.py
index 7d26d1707a2..8e22c26c485 100644
--- a/python/sglang/srt/entrypoints/openai/serving_base.py
+++ b/python/sglang/srt/entrypoints/openai/serving_base.py
@@ -114,33 +114,6 @@ def _validate_request(self, request: OpenAIServingRequest) -> Optional[str]:
         """Validate request"""
         pass
 
-    def _calculate_streaming_usage_base(
-        self,
-        prompt_tokens: Dict[int, int],
-        completion_tokens: Dict[int, int],
-        cached_tokens: Dict[int, int],
-        n_choices: int,
-    ) -> UsageInfo:
-        """Calculate usage information for streaming responses (common logic)"""
-        total_prompt_tokens = sum(
-            tokens for i, tokens in prompt_tokens.items() if i % n_choices == 0
-        )
-        total_completion_tokens = sum(tokens for tokens in completion_tokens.values())
-
-        cache_report = self.tokenizer_manager.server_args.enable_cache_report
-        prompt_tokens_details = None
-        if cache_report:
-            cached_tokens_sum = sum(tokens for tokens in cached_tokens.values())
-            if cached_tokens_sum > 0:
-                prompt_tokens_details = {"cached_tokens": cached_tokens_sum}
-
-        return UsageInfo(
-            prompt_tokens=total_prompt_tokens,
-            completion_tokens=total_completion_tokens,
-            total_tokens=total_prompt_tokens + total_completion_tokens,
-            prompt_tokens_details=prompt_tokens_details,
-        )
-
     def create_error_response(
         self,
         message: str,
diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py
index 0465b59e9ce..98e622819e3 100644
--- a/python/sglang/srt/entrypoints/openai/serving_chat.py
+++ b/python/sglang/srt/entrypoints/openai/serving_chat.py
@@ -26,8 +26,8 @@
     TopLogprob,
 )
 from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
+from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor
 from sglang.srt.entrypoints.openai.utils import (
-    aggregate_token_usage,
     detect_template_content_format,
     process_content_for_template_format,
     to_openai_style_logprobs,
@@ -546,11 +546,12 @@ async def _generate_chat_stream(
 
         # Additional usage chunk
         if request.stream_options and request.stream_options.include_usage:
-            usage = self._calculate_streaming_usage_base(
+            usage = UsageProcessor.calculate_streaming_usage(
                 prompt_tokens,
                 completion_tokens,
                 cached_tokens,
-                request.n,
+                n_choices=request.n,
+                enable_cache_report=self.tokenizer_manager.server_args.enable_cache_report,
             )
             usage_chunk = ChatCompletionStreamResponse(
                 id=content["meta_info"]["id"],
@@ -658,7 +659,9 @@ def _build_chat_response(
 
         # Calculate usage
         cache_report = self.tokenizer_manager.server_args.enable_cache_report
-        usage = aggregate_token_usage(ret, request.n, cache_report)
+        usage = UsageProcessor.calculate_response_usage(
+            ret, n_choices=request.n, enable_cache_report=cache_report
+        )
 
         return ChatCompletionResponse(
             id=ret[0]["meta_info"]["id"],
diff --git a/python/sglang/srt/entrypoints/openai/serving_completions.py b/python/sglang/srt/entrypoints/openai/serving_completions.py
index 20725987bc2..eea6dbccc1b 100644
--- a/python/sglang/srt/entrypoints/openai/serving_completions.py
+++ b/python/sglang/srt/entrypoints/openai/serving_completions.py
@@ -18,10 +18,8 @@
     ErrorResponse,
 )
 from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
-from sglang.srt.entrypoints.openai.utils import (
-    aggregate_token_usage,
-    to_openai_style_logprobs,
-)
+from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor
+from sglang.srt.entrypoints.openai.utils import to_openai_style_logprobs
 from sglang.srt.managers.io_struct import GenerateReqInput
 
 logger = logging.getLogger(__name__)
@@ -214,11 +212,12 @@ async def _generate_completion_stream(
 
         # Handle final usage chunk
         if request.stream_options and request.stream_options.include_usage:
-            usage = self._calculate_streaming_usage_base(
+            usage = UsageProcessor.calculate_streaming_usage(
                 prompt_tokens,
                 completion_tokens,
                 cached_tokens,
-                request.n,
+                n_choices=request.n,
+                enable_cache_report=self.tokenizer_manager.server_args.enable_cache_report,
             )
             final_usage_chunk = CompletionStreamResponse(
                 id=content["meta_info"]["id"],
@@ -322,7 +321,9 @@ def _build_completion_response(
 
         # Calculate usage
         cache_report = self.tokenizer_manager.server_args.enable_cache_report
-        usage = aggregate_token_usage(ret, request.n, cache_report)
+        usage = UsageProcessor.calculate_response_usage(
+            ret, n_choices=request.n, enable_cache_report=cache_report
+        )
 
         return CompletionResponse(
             id=ret[0]["meta_info"]["id"],
diff --git a/python/sglang/srt/entrypoints/openai/usage_processor.py b/python/sglang/srt/entrypoints/openai/usage_processor.py
new file mode 100644
index 00000000000..c8136829416
--- /dev/null
+++ b/python/sglang/srt/entrypoints/openai/usage_processor.py
@@ -0,0 +1,81 @@
+from __future__ import annotations
+
+from typing import Any, Dict, List, Mapping, Optional, final
+
+from sglang.srt.entrypoints.openai.protocol import UsageInfo
+
+
+@final
+class UsageProcessor:
+    """Stateless helpers that turn raw token counts into a UsageInfo."""
+
+    @staticmethod
+    def _details_if_cached(count: int) -> Optional[Dict[str, int]]:
+        """Return {"cached_tokens": N} only when N > 0 (keeps JSON slim)."""
+        return {"cached_tokens": count} if count > 0 else None
+
+    @staticmethod
+    def calculate_response_usage(
+        responses: List[Dict[str, Any]],
+        n_choices: int = 1,
+        enable_cache_report: bool = False,
+    ) -> UsageInfo:
+        completion_tokens = sum(r["meta_info"]["completion_tokens"] for r in responses)
+
+        prompt_tokens = sum(
+            responses[i]["meta_info"]["prompt_tokens"]
+            for i in range(0, len(responses), n_choices)
+        )
+
+        cached_details = None
+        if enable_cache_report:
+            cached_total = sum(
+                r["meta_info"].get("cached_tokens", 0) for r in responses
+            )
+            cached_details = UsageProcessor._details_if_cached(cached_total)
+
+        return UsageProcessor.calculate_token_usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            cached_tokens=cached_details,
+        )
+
+    @staticmethod
+    def calculate_streaming_usage(
+        prompt_tokens: Mapping[int, int],
+        completion_tokens: Mapping[int, int],
+        cached_tokens: Mapping[int, int],
+        n_choices: int,
+        enable_cache_report: bool = False,
+    ) -> UsageInfo:
+        # index % n_choices == 0 marks the first choice of a prompt
+        total_prompt_tokens = sum(
+            tok for idx, tok in prompt_tokens.items() if idx % n_choices == 0
+        )
+        total_completion_tokens = sum(completion_tokens.values())
+
+        cached_details = (
+            UsageProcessor._details_if_cached(sum(cached_tokens.values()))
+            if enable_cache_report
+            else None
+        )
+
+        return UsageProcessor.calculate_token_usage(
+            prompt_tokens=total_prompt_tokens,
+            completion_tokens=total_completion_tokens,
+            cached_tokens=cached_details,
+        )
+
+    @staticmethod
+    def calculate_token_usage(
+        prompt_tokens: int,
+        completion_tokens: int,
+        cached_tokens: Optional[Dict[str, int]] = None,
+    ) -> UsageInfo:
+        """Calculate token usage information"""
+        return UsageInfo(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+            prompt_tokens_details=cached_tokens,
+        )
diff --git a/python/sglang/srt/entrypoints/openai/utils.py b/python/sglang/srt/entrypoints/openai/utils.py
index 53c67831cdb..06e5e4dee10 100644
--- a/python/sglang/srt/entrypoints/openai/utils.py
+++ b/python/sglang/srt/entrypoints/openai/utils.py
@@ -1,10 +1,9 @@
 import logging
-from typing import Any, Dict, List, Optional
 
 import jinja2.nodes
 import transformers.utils.chat_template_utils as hf_chat_utils
 
-from sglang.srt.entrypoints.openai.protocol import LogProbs, UsageInfo
+from sglang.srt.entrypoints.openai.protocol import LogProbs
 
 logger = logging.getLogger(__name__)
 
@@ -171,62 +170,6 @@ def process_content_for_template_format(
     return new_msg
 
 
-def calculate_token_usage(
-    prompt_tokens: int,
-    completion_tokens: int,
-    cached_tokens: Optional[Dict[str, int]] = None,
-) -> UsageInfo:
-    """Calculate token usage information"""
-    return UsageInfo(
-        prompt_tokens=prompt_tokens,
-        completion_tokens=completion_tokens,
-        total_tokens=prompt_tokens + completion_tokens,
-        prompt_tokens_details=cached_tokens,
-    )
-
-
-def aggregate_token_usage(
-    responses: List[Dict[str, Any]],
-    n_choices: int = 1,
-    enable_cache_report: bool = False,
-) -> UsageInfo:
-    """Aggregate token usage from multiple responses
-
-    Args:
-        responses: List of response dictionaries with meta_info
-        n_choices: Number of choices per request (for prompt token counting)
-        enable_cache_report: Whether to include cached token details
-
-    Returns:
-        Aggregated UsageInfo
-    """
-    # Sum completion tokens from all responses
-    completion_tokens = sum(
-        response["meta_info"]["completion_tokens"] for response in responses
-    )
-
-    # For prompt tokens, only count every n_choices-th response to avoid double counting
-    prompt_tokens = sum(
-        responses[i]["meta_info"]["prompt_tokens"]
-        for i in range(0, len(responses), n_choices)
-    )
-
-    # Handle cached tokens if cache reporting is enabled
-    cached_tokens_details = None
-    if enable_cache_report:
-        cached_tokens_sum = sum(
-            response["meta_info"].get("cached_tokens", 0) for response in responses
-        )
-        if cached_tokens_sum > 0:
-            cached_tokens_details = {"cached_tokens": cached_tokens_sum}
-
-    return calculate_token_usage(
-        prompt_tokens=prompt_tokens,
-        completion_tokens=completion_tokens,
-        cached_tokens=cached_tokens_details,
-    )
-
-
 def to_openai_style_logprobs(
     input_token_logprobs=None,
     output_token_logprobs=None,
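Reviewer note: the sketch below is a minimal parity check for the helpers relocated into UsageProcessor, not part of the patch. The module and method names come from the diff above; the `responses` payload and the per-choice dicts are illustrative assumptions, since real entries come from the tokenizer manager's `meta_info` and carry additional fields.

# Parity sketch (illustrative data, not a real engine response).
from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor

# Two choices generated for one prompt (n_choices=2): prompt tokens are counted
# once per prompt, completion tokens are summed across both choices.
responses = [
    {"meta_info": {"prompt_tokens": 10, "completion_tokens": 4, "cached_tokens": 6}},
    {"meta_info": {"prompt_tokens": 10, "completion_tokens": 7, "cached_tokens": 6}},
]
usage = UsageProcessor.calculate_response_usage(
    responses, n_choices=2, enable_cache_report=True
)
assert (usage.prompt_tokens, usage.completion_tokens, usage.total_tokens) == (10, 11, 21)
# With cache reporting enabled, usage.prompt_tokens_details should carry the
# 12 cached tokens; with it disabled, the details stay None.

# Streaming variant: counts arrive as dicts keyed by choice index, and
# index % n_choices == 0 selects the first choice of each prompt so the
# prompt is not double counted.
stream_usage = UsageProcessor.calculate_streaming_usage(
    prompt_tokens={0: 10, 1: 10},
    completion_tokens={0: 4, 1: 7},
    cached_tokens={0: 6, 1: 6},
    n_choices=2,
    enable_cache_report=True,
)
assert stream_usage.total_tokens == 21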