Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 127 additions & 0 deletions docs/advanced_features/structured_outputs.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,44 @@
"print_highlight(response.choices[0].message.content)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Support for XGrammar latest structural tag format\n",
"# https://xgrammar.mlc.ai/docs/tutorials/structural_tag.html\n",
"\n",
"response = client.chat.completions.create(\n",
" model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
" messages=messages,\n",
" response_format={\n",
" \"type\": \"structural_tag\",\n",
" \"format\": {\n",
" \"type\": \"triggered_tags\",\n",
" \"triggers\": [\"<function=\"],\n",
" \"tags\": [\n",
" {\n",
" \"begin\": \"<function=get_current_weather>\",\n",
" \"content\": schema_get_current_weather,\n",
" \"end\": \"</function>\",\n",
" },\n",
" {\n",
" \"begin\": \"<function=get_current_date>\",\n",
" \"content\": schema_get_current_date,\n",
" \"end\": \"</function>\",\n",
" },\n",
" ],\n",
" \"at_least_one\": False,\n",
" \"stop_after_first\": False,\n",
" }\n",
" },\n",
")\n",
"\n",
"print_highlight(response.choices[0].message.content)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -594,6 +632,50 @@
"print_highlight(response.json())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Support for XGrammar latest structural tag format\n",
"# https://xgrammar.mlc.ai/docs/tutorials/structural_tag.html\n",
"\n",
"payload = {\n",
" \"text\": text,\n",
" \"sampling_params\": {\n",
" \"structural_tag\": json.dumps(\n",
" {\n",
" \"type\": \"structural_tag\",\n",
" \"format\": {\n",
" \"type\": \"triggered_tags\",\n",
" \"triggers\": [\"<function=\"],\n",
" \"tags\": [\n",
" {\n",
" \"begin\": \"<function=get_current_weather>\",\n",
" \"content\": schema_get_current_weather,\n",
" \"end\": \"</function>\",\n",
" },\n",
" {\n",
" \"begin\": \"<function=get_current_date>\",\n",
" \"content\": schema_get_current_date,\n",
" \"end\": \"</function>\",\n",
" },\n",
" ],\n",
" \"at_least_one\": False,\n",
" \"stop_after_first\": False,\n",
" }\n",
" }\n",
" )\n",
" },\n",
"}\n",
"\n",
"\n",
"# Send POST request to the API endpoint\n",
"response = requests.post(f\"http://localhost:{port}/generate\", json=payload)\n",
"print_highlight(response.json())"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -825,6 +907,51 @@
" print_highlight(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Support for XGrammar latest structural tag format\n",
"# https://xgrammar.mlc.ai/docs/tutorials/structural_tag.html\n",
"\n",
"sampling_params = {\n",
" \"temperature\": 0.8,\n",
" \"top_p\": 0.95,\n",
" \"structural_tag\": json.dumps(\n",
" {\n",
" \"type\": \"structural_tag\",\n",
" \"format\": {\n",
" \"type\": \"triggered_tags\",\n",
" \"triggers\": [\"<function=\"],\n",
" \"tags\": [\n",
" {\n",
" \"begin\": \"<function=get_current_weather>\",\n",
" \"content\": schema_get_current_weather,\n",
" \"end\": \"</function>\",\n",
" },\n",
" {\n",
" \"begin\": \"<function=get_current_date>\",\n",
" \"content\": schema_get_current_date,\n",
" \"end\": \"</function>\",\n",
" },\n",
" ],\n",
" \"at_least_one\": False,\n",
" \"stop_after_first\": False,\n",
" }\n",
" }\n",
" ),\n",
"}\n",
"\n",
"\n",
"# Send POST request to the API endpoint\n",
"outputs = llm.generate(prompts, sampling_params)\n",
"for prompt, output in zip(prompts, outputs):\n",
" print_highlight(\"===============================\")\n",
" print_highlight(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
25 changes: 15 additions & 10 deletions python/sglang/srt/constrained/xgrammar_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,17 +238,22 @@ def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]:
def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
try:
structural_tag = json.loads(key_string)
tags = [
StructuralTagItem(
begin=structure["begin"],
schema=json.dumps(structure["schema"]),
end=structure["end"],
if "format" in structural_tag:
# V1 format
ctx = self.grammar_compiler.compile_structural_tag(structural_tag)
else:
# Deprecated format
tags = [
StructuralTagItem(
begin=structure["begin"],
schema=json.dumps(structure["schema"]),
end=structure["end"],
)
for structure in structural_tag["structures"]
]
ctx = self.grammar_compiler.compile_structural_tag(
tags, structural_tag["triggers"]
)
for structure in structural_tag["structures"]
]
ctx = self.grammar_compiler.compile_structural_tag(
tags, structural_tag["triggers"]
)
except (RuntimeError, json.decoder.JSONDecodeError) as e:
logging.error(f"Hit invalid structural_tag: {key_string=}, {e=}")
return INVALID_GRAMMAR_OBJ
Expand Down
7 changes: 5 additions & 2 deletions python/sglang/srt/entrypoints/openai/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ class StructuralTagResponseFormat(BaseModel):
structures: List[StructuresResponseFormat]
triggers: List[str]

class StructuralTagResponseFormatV1(BaseModel):
type: Literal["structural_tag"]
format: Dict[str, Any]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can just import StructuralTag from xgrammar, that's also a pydantic BaseModel and all its fields are type checked.


class FileRequest(BaseModel):
# https://platform.openai.com/docs/api-reference/files/create
Expand Down Expand Up @@ -219,7 +222,7 @@ class CompletionRequest(BaseModel):
skip_special_tokens: bool = True
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
session_params: Optional[Dict] = None
response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None
response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat, StructuralTagResponseFormatV1]] = None

# For PD disaggregation
bootstrap_host: Optional[Union[List[str], str]] = None
Expand Down Expand Up @@ -432,7 +435,7 @@ class ChatCompletionRequest(BaseModel):
)
n: int = 1
presence_penalty: float = 0.0
response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None
response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat, StructuralTagResponseFormatV1]] = None
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
stream: bool = False
Expand Down
Loading