Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 28 additions & 11 deletions python/sglang/srt/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor
from sglang.srt.entrypoints.openai.utils import (
cached_tokens_details_from_dict,
process_cached_tokens_details_from_ret,
process_hidden_states_from_ret,
process_routed_experts_from_ret,
Expand Down Expand Up @@ -652,6 +653,7 @@ async def _generate_chat_stream(
cached_tokens = {}
hidden_states = {}
routed_experts = {}
cached_tokens_details = {}

stream_started = False
try:
Expand All @@ -667,6 +669,9 @@ async def _generate_chat_stream(
cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
hidden_states[index] = content["meta_info"].get("hidden_states", None)
routed_experts[index] = content["meta_info"].get("routed_experts", None)
cached_tokens_details[index] = content["meta_info"].get(
"cached_tokens_details", None
)

# Handle logprobs
choice_logprobs = None
Expand Down Expand Up @@ -865,20 +870,32 @@ async def _generate_chat_stream(
)
yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"

sglext_routed = None
if request.return_routed_experts and routed_experts:
# Get first non-None routed_experts value
first_routed_experts = next(
sglext_routed = next(
(v for v in routed_experts.values() if v is not None), None
)
if first_routed_experts is not None:
routed_experts_chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=int(time.time()),
choices=[], # sglext is at response level
model=request.model,
sglext=SglExt(routed_experts=first_routed_experts),
)
yield f"data: {routed_experts_chunk.model_dump_json()}\n\n"

sglext_details = None
if request.return_cached_tokens_details and cached_tokens_details:
first_details = next(
(v for v in cached_tokens_details.values() if v is not None), None
)
if first_details is not None:
sglext_details = cached_tokens_details_from_dict(first_details)

if sglext_routed is not None or sglext_details is not None:
sglext_chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=int(time.time()),
choices=[], # sglext is at response level
model=request.model,
sglext=SglExt(
routed_experts=sglext_routed,
cached_tokens_details=sglext_details,
),
)
yield f"data: {sglext_chunk.model_dump_json()}\n\n"

# Additional usage chunk
if request.stream_options and request.stream_options.include_usage:
Expand Down
41 changes: 29 additions & 12 deletions python/sglang/srt/entrypoints/openai/serving_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor
from sglang.srt.entrypoints.openai.utils import (
cached_tokens_details_from_dict,
process_cached_tokens_details_from_ret,
process_hidden_states_from_ret,
process_routed_experts_from_ret,
Expand Down Expand Up @@ -222,6 +223,7 @@ async def _generate_completion_stream(
cached_tokens = {}
hidden_states = {}
routed_experts = {}
cached_tokens_details = {}

stream_started = False
try:
Expand All @@ -238,6 +240,9 @@ async def _generate_completion_stream(
cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
hidden_states[index] = content["meta_info"].get("hidden_states", None)
routed_experts[index] = content["meta_info"].get("routed_experts", None)
cached_tokens_details[index] = content["meta_info"].get(
"cached_tokens_details", None
)

stream_buffer = stream_buffers.get(index, "")
# Handle echo for first chunk
Expand Down Expand Up @@ -354,21 +359,33 @@ async def _generate_completion_stream(
)
yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"

sglext_routed = None
if request.return_routed_experts and routed_experts:
# Get first non-None routed_experts value
first_routed_experts = next(
sglext_routed = next(
(v for v in routed_experts.values() if v is not None), None
)
if first_routed_experts is not None:
routed_experts_chunk = CompletionStreamResponse(
id=content["meta_info"]["id"],
created=created,
object="text_completion",
choices=[], # sglext is at response level
model=request.model,
sglext=SglExt(routed_experts=first_routed_experts),
)
yield f"data: {routed_experts_chunk.model_dump_json()}\n\n"

sglext_details = None
if request.return_cached_tokens_details and cached_tokens_details:
first_details = next(
(v for v in cached_tokens_details.values() if v is not None), None
)
if first_details is not None:
sglext_details = cached_tokens_details_from_dict(first_details)

if sglext_routed is not None or sglext_details is not None:
sglext_chunk = CompletionStreamResponse(
id=content["meta_info"]["id"],
created=created,
object="text_completion",
choices=[], # sglext is at response level
model=request.model,
sglext=SglExt(
routed_experts=sglext_routed,
cached_tokens_details=sglext_details,
),
)
yield f"data: {sglext_chunk.model_dump_json()}\n\n"

# Handle final usage chunk
if request.stream_options and request.stream_options.include_usage:
Expand Down
34 changes: 20 additions & 14 deletions python/sglang/srt/entrypoints/openai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,24 @@ def process_routed_experts_from_ret(
return ret_item["meta_info"].get("routed_experts", None)


def cached_tokens_details_from_dict(
    details: Dict[str, Any],
) -> CachedTokensDetails:
    """Build a CachedTokensDetails object from a raw cached_tokens_details dict.

    Device and host token counts are always populated (defaulting to 0 when
    absent). The L3 storage fields ("storage", "storage_backend") are included
    only when the raw dict carries a "storage" entry, so they stay unset for
    responses that have no storage-tier cache information.
    """
    kwargs: Dict[str, Any] = {
        "device": details.get("device", 0),
        "host": details.get("host", 0),
    }
    # Only attach L3 storage fields when the backend reported them.
    if "storage" in details:
        kwargs["storage"] = details.get("storage", 0)
        kwargs["storage_backend"] = details.get("storage_backend")
    return CachedTokensDetails(**kwargs)


def process_cached_tokens_details_from_ret(
ret_item: Dict[str, Any],
request: Union[
Expand All @@ -94,23 +112,11 @@ def process_cached_tokens_details_from_ret(
],
) -> Optional[CachedTokensDetails]:
"""Process cached tokens details from a ret item in non-streaming response."""
if not getattr(request, "return_cached_tokens_details", False):
if not request.return_cached_tokens_details:
return None

details = ret_item["meta_info"].get("cached_tokens_details", None)
if details is None:
return None

# Check if L3 storage fields are present
if "storage" in details:
return CachedTokensDetails(
device=details.get("device", 0),
host=details.get("host", 0),
storage=details.get("storage", 0),
storage_backend=details.get("storage_backend"),
)
else:
return CachedTokensDetails(
device=details.get("device", 0),
host=details.get("host", 0),
)
return cached_tokens_details_from_dict(details)
Loading