3 changes: 3 additions & 0 deletions cpp/include/tensorrt_llm/executor/executor.h
@@ -485,6 +485,9 @@ class GuidedDecodingParams
         /// @brief The generated text is amenable to the user-specified extended Backus-Naur form (EBNF) grammar.
         /// EBNF grammar is widely-used to express context-free grammars.
         kEBNF_GRAMMAR = 3,
+
+        /// @brief The generated text is amenable to the XGrammar structural tag.
+        kSTRUCTURAL_TAG = 4,
     };

     explicit GuidedDecodingParams(GuideType guideType, std::optional<std::string> guide = std::nullopt);
3 changes: 2 additions & 1 deletion cpp/tensorrt_llm/pybind/executor/request.cpp
@@ -475,7 +475,8 @@ void initRequestBindings(pybind11::module_& m)
         .value("JSON", tle::GuidedDecodingParams::GuideType::kJSON)
         .value("JSON_SCHEMA", tle::GuidedDecodingParams::GuideType::kJSON_SCHEMA)
         .value("REGEX", tle::GuidedDecodingParams::GuideType::kREGEX)
-        .value("EBNF_GRAMMAR", tle::GuidedDecodingParams::GuideType::kEBNF_GRAMMAR);
+        .value("EBNF_GRAMMAR", tle::GuidedDecodingParams::GuideType::kEBNF_GRAMMAR)
+        .value("STRUCTURAL_TAG", tle::GuidedDecodingParams::GuideType::kSTRUCTURAL_TAG);

     auto guidedDecodingParamsGetstate
         = [](tle::GuidedDecodingParams const& self) { return py::make_tuple(self.getGuideType(), self.getGuide()); };
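With the enum value now exposed through pybind11, the structural-tag guide type can be selected from Python. A minimal sketch, assuming the executor bindings are importable as `tensorrt_llm.bindings.executor`; the guide string's expected shape is sketched after the guided_decoder.py diff below:

```python
from tensorrt_llm.bindings import executor as tle

# The guide carries the structural-tag specification as a JSON string;
# this placeholder value is illustrative only.
guide = '{"structures": [], "triggers": []}'
params = tle.GuidedDecodingParams(
    tle.GuidedDecodingParams.GuideType.STRUCTURAL_TAG, guide)
```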
7 changes: 7 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/guided_decoder.py
@@ -1,4 +1,5 @@
 import itertools
+import json
 import math
 from typing import List, Optional

@@ -82,6 +83,12 @@ def build(self, scheduled_requests: ScheduledRequests,
                         grammar = xgrammar.Grammar.from_ebnf(guide)
                         compiled_grammar = self.xgrammar_compiler.compile_grammar(
                             grammar)
+                    case GuidedDecodingParams.GuideType.STRUCTURAL_TAG:
+                        structural_tag_parameters = json.loads(guide)
+                        structures = structural_tag_parameters["structures"]
+                        triggers = structural_tag_parameters["triggers"]
+                        compiled_grammar = self.xgrammar_compiler.compile_structural_tag(
+                            structures, triggers)
                     case _:
                         raise ValueError(
                             f"Unrecognized guide type: {guide_type}.")
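The decoder expects the guide to be a JSON string with top-level "structures" and "triggers" keys, which it forwards to xgrammar's `compile_structural_tag`. A sketch of that payload, with hypothetical tag strings and schema:

```python
import json

# Hypothetical structural-tag payload: each structure pairs an opening
# tag, a JSON schema constraining the body, and a closing tag, while
# triggers list the substrings that switch decoding into constrained mode.
guide = json.dumps({
    "structures": [{
        "begin": "<function=get_weather>",
        "schema": {"type": "object",
                   "properties": {"city": {"type": "string"}}},
        "end": "</function>",
    }],
    "triggers": ["<function="],
})
```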
7 changes: 6 additions & 1 deletion tensorrt_llm/sampling_params.py
@@ -26,6 +26,7 @@ class GuidedDecodingParams:
     regex: Optional[str] = None
     grammar: Optional[str] = None
     json_object: bool = False
+    structural_tag: Optional[str] = None

     def _validate(self):
         num_guides = 0
@@ -441,7 +442,7 @@ def _get_guided_decoding_params(self) -> tllme.GuidedDecodingParams:
                 tllme.GuidedDecodingParams.GuideType.JSON)
         elif self.guided_decoding.json is not None:
             json_schema = self.guided_decoding.json
-            if isinstance(json, BaseModel):
+            if isinstance(json_schema, BaseModel):
                 json_schema = json_schema.model_json_schema()
             if isinstance(json_schema, dict):
                 json_schema = json.dumps(json_schema)
@@ -455,5 +456,9 @@ def _get_guided_decoding_params(self) -> tllme.GuidedDecodingParams:
             return tllme.GuidedDecodingParams(
                 tllme.GuidedDecodingParams.GuideType.EBNF_GRAMMAR,
                 self.guided_decoding.grammar)
+        elif self.guided_decoding.structural_tag is not None:
+            return tllme.GuidedDecodingParams(
+                tllme.GuidedDecodingParams.GuideType.STRUCTURAL_TAG,
+                self.guided_decoding.structural_tag)
         else:
             return None
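At the LLM API level, the new field rides on the existing one-guide-at-a-time validation and is mapped onto the executor enum by `_get_guided_decoding_params`. A minimal usage sketch, reusing the hypothetical payload from the previous example:

```python
import json

from tensorrt_llm.llmapi import GuidedDecodingParams, SamplingParams

structural_tag = json.dumps({
    "structures": [{"begin": "<function=get_weather>",
                    "schema": {"type": "object"},
                    "end": "</function>"}],
    "triggers": ["<function="],
})
# Exactly one guide (json, regex, grammar, json_object, or
# structural_tag) may be set per GuidedDecodingParams.
sampling_params = SamplingParams(
    guided_decoding=GuidedDecodingParams(structural_tag=structural_tag))
```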
45 changes: 29 additions & 16 deletions tensorrt_llm/serve/openai_protocol.py
@@ -13,7 +13,7 @@
 from typing_extensions import Annotated, Required, TypedDict

 from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
-from tensorrt_llm.llmapi import SamplingParams
+from tensorrt_llm.llmapi import GuidedDecodingParams, SamplingParams


 class OpenAIBaseModel(BaseModel):
@@ -44,9 +44,17 @@ class ModelList(OpenAIBaseModel):
     data: List[ModelCard] = Field(default_factory=list)


+class StructuralTag(OpenAIBaseModel):
+    begin: str
+    schema_: Optional[dict[str, Any]] = Field(alias="schema")
+    end: str
+
+
 class ResponseFormat(OpenAIBaseModel):
     # type must be "json_object" or "text"
-    type: Literal["text", "json_object"]
+    type: Literal["text", "json_object", "structural_tag"]
+    structures: Optional[List[StructuralTag]] = None
+    triggers: Optional[List[str]] = None


 class DisaggregatedParams(OpenAIBaseModel):
@@ -121,6 +129,21 @@ class CompletionStreamResponse(OpenAIBaseModel):
     usage: Optional[UsageInfo] = Field(default=None)


+def _response_format_to_guided_decoding_params(
+        response_format: Optional[ResponseFormat]
+) -> Optional[GuidedDecodingParams]:
+    if response_format is None:
+        return None
+    if response_format.type == "json_object":
+        return GuidedDecodingParams(json_object=True)
+    elif response_format.type == "structural_tag":
+        return GuidedDecodingParams(
+            structural_tag=response_format.model_dump_json(by_alias=True,
+                                                           exclude_none=True))
+    else:
+        raise ValueError(f"Unsupported response format: {response_format.type}")
+
+
 class CompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/completions/create
@@ -212,6 +235,8 @@ def to_sampling_params(self) -> SamplingParams:
             spaces_between_special_tokens=self.spaces_between_special_tokens,
             truncate_prompt_tokens=self.truncate_prompt_tokens,
             return_context_logits=self.return_context_logits,
+            guided_decoding_params=_response_format_to_guided_decoding_params(
+                self.response_format),

             # completion-extra-params
             add_special_tokens=self.add_special_tokens,
@@ -254,13 +279,6 @@ def verify_multi_responses(cls, data):
             raise ValueError("best_of should not be smaller than n")
         return data

-    @model_validator(mode="before")
-    @classmethod
-    def check_response_format(cls, data):
-        if data.get("response_format"):
-            raise ValueError("response_format is not supported")
-        return data
-
     @model_validator(mode="before")
     @classmethod
     def check_suffix(cls, data):
@@ -527,6 +545,8 @@ def to_sampling_params(self) -> SamplingParams:
             skip_special_tokens=self.skip_special_tokens,
             spaces_between_special_tokens=self.spaces_between_special_tokens,
             truncate_prompt_tokens=self.truncate_prompt_tokens,
+            guided_decoding_params=_response_format_to_guided_decoding_params(
+                self.response_format),

             # chat-completion-extra-params
             add_special_tokens=self.add_special_tokens,
@@ -586,13 +606,6 @@ def verify_logit_processor(cls, data):
             raise ValueError("logit bias is not supported")
         return data

-    @model_validator(mode="before")
-    @classmethod
-    def check_response_format(cls, data):
-        if data.get("response_format"):
-            raise ValueError("response_format is not supported")
-        return data
-
     @model_validator(mode="before")
     @classmethod
     def check_suffix(cls, data):
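End to end, a client can now send the new response_format type to the OpenAI-compatible server instead of tripping the removed check_response_format validator; the server serializes it back to JSON (model_dump_json with by_alias=True) and hands it to the guided decoder as the structural-tag guide. A hedged sketch of a chat-completions request, with illustrative endpoint, model name, and tag contents:

```python
import requests

payload = {
    "model": "my-model",  # illustrative model name
    "messages": [{"role": "user", "content": "What is the weather in Paris?"}],
    "response_format": {
        "type": "structural_tag",
        "structures": [{
            "begin": "<function=get_weather>",
            # The StructuralTag field schema_ travels under its "schema" alias.
            "schema": {"type": "object",
                       "properties": {"city": {"type": "string"}}},
            "end": "</function>",
        }],
        "triggers": ["<function="],
    },
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])
```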