Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 145 additions & 0 deletions docs/advanced_features/structured_outputs.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,50 @@
"print_highlight(response.choices[0].message.content)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Support for XGrammar latest structural tag format\n",
"# https://xgrammar.mlc.ai/docs/tutorials/structural_tag.html\n",
"\n",
"response = client.chat.completions.create(\n",
" model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
" messages=messages,\n",
" response_format={\n",
" \"type\": \"structural_tag\",\n",
" \"format\": {\n",
" \"type\": \"triggered_tags\",\n",
" \"triggers\": [\"<function=\"],\n",
" \"tags\": [\n",
" {\n",
" \"begin\": \"<function=get_current_weather>\",\n",
" \"content\": {\n",
" \"type\": \"json_schema\",\n",
" \"json_schema\": schema_get_current_weather,\n",
" },\n",
" \"end\": \"</function>\",\n",
" },\n",
" {\n",
" \"begin\": \"<function=get_current_date>\",\n",
" \"content\": {\n",
" \"type\": \"json_schema\",\n",
" \"json_schema\": schema_get_current_date,\n",
" },\n",
" \"end\": \"</function>\",\n",
" },\n",
" ],\n",
" \"at_least_one\": False,\n",
" \"stop_after_first\": False,\n",
" },\n",
" },\n",
")\n",
"\n",
"print_highlight(response.choices[0].message.content)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -594,6 +638,56 @@
"print_highlight(response.json())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Support for XGrammar latest structural tag format\n",
"# https://xgrammar.mlc.ai/docs/tutorials/structural_tag.html\n",
"\n",
"payload = {\n",
" \"text\": text,\n",
" \"sampling_params\": {\n",
" \"structural_tag\": json.dumps(\n",
" {\n",
" \"type\": \"structural_tag\",\n",
" \"format\": {\n",
" \"type\": \"triggered_tags\",\n",
" \"triggers\": [\"<function=\"],\n",
" \"tags\": [\n",
" {\n",
" \"begin\": \"<function=get_current_weather>\",\n",
" \"content\": {\n",
" \"type\": \"json_schema\",\n",
" \"json_schema\": schema_get_current_weather,\n",
" },\n",
" \"end\": \"</function>\",\n",
" },\n",
" {\n",
" \"begin\": \"<function=get_current_date>\",\n",
" \"content\": {\n",
" \"type\": \"json_schema\",\n",
" \"json_schema\": schema_get_current_date,\n",
" },\n",
" \"end\": \"</function>\",\n",
" },\n",
" ],\n",
" \"at_least_one\": False,\n",
" \"stop_after_first\": False,\n",
" },\n",
" }\n",
" )\n",
" },\n",
"}\n",
"\n",
"\n",
"# Send POST request to the API endpoint\n",
"response = requests.post(f\"http://localhost:{port}/generate\", json=payload)\n",
"print_highlight(response.json())"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -825,6 +919,57 @@
" print_highlight(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Support for XGrammar latest structural tag format\n",
"# https://xgrammar.mlc.ai/docs/tutorials/structural_tag.html\n",
"\n",
"sampling_params = {\n",
" \"temperature\": 0.8,\n",
" \"top_p\": 0.95,\n",
" \"structural_tag\": json.dumps(\n",
" {\n",
" \"type\": \"structural_tag\",\n",
" \"format\": {\n",
" \"type\": \"triggered_tags\",\n",
" \"triggers\": [\"<function=\"],\n",
" \"tags\": [\n",
" {\n",
" \"begin\": \"<function=get_current_weather>\",\n",
" \"content\": {\n",
" \"type\": \"json_schema\",\n",
" \"json_schema\": schema_get_current_weather,\n",
" },\n",
" \"end\": \"</function>\",\n",
" },\n",
" {\n",
" \"begin\": \"<function=get_current_date>\",\n",
" \"content\": {\n",
" \"type\": \"json_schema\",\n",
" \"json_schema\": schema_get_current_date,\n",
" },\n",
" \"end\": \"</function>\",\n",
" },\n",
" ],\n",
" \"at_least_one\": False,\n",
" \"stop_after_first\": False,\n",
" },\n",
" }\n",
" ),\n",
"}\n",
"\n",
"\n",
"# Send POST request to the API endpoint\n",
"outputs = llm.generate(prompts, sampling_params)\n",
"for prompt, output in zip(prompts, outputs):\n",
" print_highlight(\"===============================\")\n",
" print_highlight(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
2 changes: 2 additions & 0 deletions python/sglang/srt/constrained/llguidance_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
BaseGrammarBackend,
BaseGrammarObject,
)
from sglang.srt.constrained.utils import is_legacy_structural_tag

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -160,6 +161,7 @@ def dispatch_ebnf(self, key_string: str) -> Optional[GuidanceGrammar]:
def dispatch_structural_tag(self, key_string: str) -> Optional[GuidanceGrammar]:
try:
structural_tag = json.loads(key_string)
assert is_legacy_structural_tag(structural_tag)
tags = [
StructTag(
begin=structure["begin"],
Expand Down
12 changes: 12 additions & 0 deletions python/sglang/srt/constrained/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import Dict


def is_legacy_structural_tag(obj: Dict) -> bool:
    """Tell whether a decoded structural tag uses the legacy layout.

    The legacy layout carries ``structures`` and ``triggers`` keys (see
    ``StructuralTagResponseFormat`` in
    ``sglang.srt.entrypoints.openai.protocol``); any other object is expected
    to carry a ``format`` key instead.

    Args:
        obj: The structural tag decoded into a dict.

    Returns:
        True when ``structures`` is present, False when ``format`` is present
        and ``structures`` is not.

    Raises:
        AssertionError: If ``structures`` is present without ``triggers``, or
            if neither ``structures`` nor ``format`` is present.
    """
    # Raise explicitly rather than via `assert` statements: those are removed
    # under `python -O`, which would let malformed input through unchecked.
    # AssertionError is kept so callers observe the same exception type.
    if obj.get("structures") is not None:
        if obj.get("triggers") is None:
            raise AssertionError("legacy structural tag requires 'triggers'")
        return True
    if obj.get("format") is None:
        raise AssertionError("structural tag requires 'structures' or 'format'")
    return False
Comment thread
DarkSharpness marked this conversation as resolved.
25 changes: 15 additions & 10 deletions python/sglang/srt/constrained/xgrammar_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
BaseGrammarObject,
GrammarStats,
)
from sglang.srt.constrained.utils import is_legacy_structural_tag
from sglang.srt.utils import is_hip

_is_hip = is_hip()
Expand Down Expand Up @@ -241,18 +242,22 @@ def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]:

def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
try:
# TODO(dark): avoid the wasteful round trip of serializing the object to a string and decoding it again
structural_tag = json.loads(key_string)
tags = [
StructuralTagItem(
begin=structure["begin"],
schema=json.dumps(structure["schema"]),
end=structure["end"],
if is_legacy_structural_tag(structural_tag):
tags = [
StructuralTagItem(
begin=structure["begin"],
schema=json.dumps(structure["schema"]),
end=structure["end"],
)
for structure in structural_tag["structures"]
]
ctx = self.grammar_compiler.compile_structural_tag(
tags, structural_tag["triggers"]
)
for structure in structural_tag["structures"]
]
ctx = self.grammar_compiler.compile_structural_tag(
tags, structural_tag["triggers"]
)
else:
ctx = self.grammar_compiler.compile_structural_tag(key_string)
except (RuntimeError, json.decoder.JSONDecodeError) as e:
Comment thread
DarkSharpness marked this conversation as resolved.
logging.error(f"Hit invalid structural_tag: {key_string=}, {e=}")
return INVALID_GRAMMAR_OBJ
Expand Down
22 changes: 17 additions & 5 deletions python/sglang/srt/entrypoints/openai/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import time
import uuid
from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, TypeAlias, Union
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, TypeAlias, Union

from openai.types.responses import (
ResponseFunctionToolCall,
Expand All @@ -37,6 +37,7 @@
model_validator,
)
from typing_extensions import Literal
from xgrammar import StructuralTag

from sglang.utils import convert_json_schema_to_str

Expand Down Expand Up @@ -128,12 +129,23 @@ class StructuresResponseFormat(BaseModel):
end: str


class StructuralTagResponseFormat(BaseModel):
# NOTE(dark): keep this for backward compatibility
class LegacyStructuralTagResponseFormat(BaseModel):
type: Literal["structural_tag"]
structures: List[StructuresResponseFormat]
triggers: List[str]


StructuralTagResponseFormat: TypeAlias = Union[
LegacyStructuralTagResponseFormat, StructuralTag
]

ToolCallConstraint: TypeAlias = Union[
Tuple[Literal["structural_tag"], StructuralTagResponseFormat],
Tuple[Literal["json_schema"], Any], # json_schema can be dict/str/None
]

Comment thread
DarkSharpness marked this conversation as resolved.

class FileRequest(BaseModel):
# https://platform.openai.com/docs/api-reference/files/create
file: bytes # The File object (not file name) to be uploaded
Expand Down Expand Up @@ -583,7 +595,7 @@ def to_sampling_params(
self,
stop: List[str],
model_generation_config: Dict[str, Any],
tool_call_constraint: Optional[Any] = None,
tool_call_constraint: Optional[ToolCallConstraint] = None,
) -> Dict[str, Any]:
"""
Convert request to sampling parameters.
Expand Down Expand Up @@ -649,7 +661,7 @@ def get_param(param_name: str):
)
elif constraint_type == "json_schema":
sampling_params[constraint_type] = convert_json_schema_to_str(
constraint_value
constraint_value # type: ignore
)
else:
sampling_params[constraint_type] = constraint_value
Expand Down Expand Up @@ -1145,7 +1157,7 @@ class MessageProcessingResult:
video_data: Optional[Any]
modalities: List[str]
stop: List[str]
tool_call_constraint: Optional[Any] = None
tool_call_constraint: Optional[ToolCallConstraint] = None


class ToolCallProcessingResult(NamedTuple):
Expand Down
16 changes: 9 additions & 7 deletions python/sglang/srt/function_call/function_call_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type, Union

from sglang.srt.entrypoints.openai.protocol import (
StructuralTagResponseFormat,
LegacyStructuralTagResponseFormat,
StructuresResponseFormat,
Tool,
ToolCallConstraint,
ToolChoice,
)
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
Expand Down Expand Up @@ -51,7 +52,6 @@ class FunctionCallParser:
}

def __init__(self, tools: List[Tool], tool_call_parser: str):
detector: Type[BaseFormatDetector] = None
detector_class = self.ToolCallParserEnum.get(tool_call_parser)
if detector_class:
detector = detector_class()
Expand Down Expand Up @@ -123,7 +123,7 @@ def parse_stream_chunk(self, chunk_text: str) -> Tuple[str, list[ToolCallItem]]:

return final_normal_text, final_calls

def get_structure_tag(self) -> StructuralTagResponseFormat:
def get_structure_tag(self) -> LegacyStructuralTagResponseFormat:
"""
Generate a structural tag response format for all available tools.

Expand Down Expand Up @@ -151,15 +151,17 @@ def get_structure_tag(self) -> StructuralTagResponseFormat:
)
tool_trigger_set.add(info.trigger)

return StructuralTagResponseFormat(
# TODO(dark): migrate this to the new structural tag format.
# This requires all grammar backends to support the new format.
return LegacyStructuralTagResponseFormat(
type="structural_tag",
structures=tool_structures,
triggers=list(tool_trigger_set),
)

def get_structure_constraint(
self, tool_choice: Union[ToolChoice, Literal["auto", "required"]]
) -> Optional[Tuple[str, Any]]:
) -> Optional[ToolCallConstraint]:
"""
Returns the appropriate structure constraint for tool calls based on the tool_choice.
The constraint is used to guide the model's output format.
Expand All @@ -178,8 +180,8 @@ def get_structure_constraint(
and tool_choice == "auto"
and any(tool.function.strict for tool in self.tools)
):
strict_tag = self.get_structure_tag()
return ("structural_tag", strict_tag)
tag = self.get_structure_tag()
return ("structural_tag", tag)
elif tool_choice == "required" or isinstance(tool_choice, ToolChoice):
json_schema = get_json_schema_constraint(self.tools, tool_choice)
return ("json_schema", json_schema)
Expand Down
Loading
Loading