3 changes: 3 additions & 0 deletions cpp/include/tensorrt_llm/executor/executor.h
@@ -483,6 +483,9 @@ class GuidedDecodingParams
/// @brief The generated text is amenable to the user-specified extended Backus-Naur form (EBNF) grammar.
/// EBNF grammar is widely-used to express context-free grammars.
kEBNF_GRAMMAR = 3,

/// @brief The generated text is amenable to the XGrammar structural tag.
kSTRUCTURAL_TAG = 4,
};

explicit GuidedDecodingParams(GuideType guideType, std::optional<std::string> guide = std::nullopt);
3 changes: 2 additions & 1 deletion cpp/tensorrt_llm/pybind/executor/request.cpp
@@ -462,7 +462,8 @@ void initRequestBindings(pybind11::module_& m)
.value("JSON", tle::GuidedDecodingParams::GuideType::kJSON)
.value("JSON_SCHEMA", tle::GuidedDecodingParams::GuideType::kJSON_SCHEMA)
.value("REGEX", tle::GuidedDecodingParams::GuideType::kREGEX)
.value("EBNF_GRAMMAR", tle::GuidedDecodingParams::GuideType::kEBNF_GRAMMAR);
.value("EBNF_GRAMMAR", tle::GuidedDecodingParams::GuideType::kEBNF_GRAMMAR)
.value("STRUCTURAL_TAG", tle::GuidedDecodingParams::GuideType::kSTRUCTURAL_TAG);

auto guidedDecodingParamsGetstate
= [](tle::GuidedDecodingParams const& self) { return py::make_tuple(self.getGuideType(), self.getGuide()); };
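A minimal sketch of exercising the new enum value through the Python bindings. The tensorrt_llm.bindings.executor import path and the guide contents are assumptions for illustration, not part of this PR:

import json
from tensorrt_llm.bindings import executor as tle  # assumed binding module path

# Hypothetical structural-tag guide; see the guided_decoder.py hunk below
# for the keys this JSON is expected to carry.
guide = json.dumps({
    "structures": [{"begin": "<tag>", "schema": {"type": "object"}, "end": "</tag>"}],
    "triggers": ["<tag"],
})
params = tle.GuidedDecodingParams(
    tle.GuidedDecodingParams.GuideType.STRUCTURAL_TAG, guide)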
3 changes: 3 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -369,6 +369,9 @@ def create_py_executor_instance(dist,
if spec_config is not None:
raise ValueError(
"Guided decoding is not supported with speculative decoding.")
if pytorch_backend_config.enable_overlap_scheduler:
raise ValueError(
"Guided decoding is not supported with overlap scheduler.")

logger.info(
f"max_seq_len={executor_config.max_seq_len}, max_num_requests={executor_config.max_batch_size}, max_num_tokens={executor_config.max_num_tokens}"
13 changes: 13 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/guided_decoder.py
@@ -1,4 +1,5 @@
import itertools
import json
import math
from typing import List, Optional

@@ -82,6 +83,18 @@ def build(self, scheduled_requests: ScheduledRequests,
grammar = xgrammar.Grammar.from_ebnf(guide)
compiled_grammar = self.xgrammar_compiler.compile_grammar(
grammar)
case GuidedDecodingParams.GuideType.STRUCTURAL_TAG:
structural_tag_parameters = json.loads(guide)
structures = structural_tag_parameters["structures"]
structures = [
xgrammar.StructuralTagItem(
begin=s["begin"],
schema=json.dumps(s["schema"]),
end=s["end"]) for s in structures
]
triggers = structural_tag_parameters["triggers"]
compiled_grammar = self.xgrammar_compiler.compile_structural_tag(
structures, triggers)
case _:
raise ValueError(
f"Unrecognized guide type: {guide_type}.")
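The new STRUCTURAL_TAG branch expects the guide to be a JSON string with two keys, "structures" and "triggers", matching the fields read above. A minimal sketch of such a payload (the tag names and schema are illustrative, not taken from the PR):

import json

# Text between begin/end markers is constrained to a JSON object matching
# "schema"; "triggers" lists the prefixes that switch decoding into
# constrained mode.
guide = json.dumps({
    "structures": [{
        "begin": "<function=get_weather>",
        "schema": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
        "end": "</function>",
    }],
    "triggers": ["<function="],
})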
7 changes: 6 additions & 1 deletion tensorrt_llm/sampling_params.py
@@ -25,6 +25,7 @@ class GuidedDecodingParams:
regex: Optional[str] = None
grammar: Optional[str] = None
json_object: bool = False
structural_tag: Optional[str] = None

def _validate(self):
num_guides = 0
@@ -451,7 +452,7 @@ def _get_guided_decoding_params(self) -> tllme.GuidedDecodingParams:
tllme.GuidedDecodingParams.GuideType.JSON)
elif self.guided_decoding.json is not None:
json_schema = self.guided_decoding.json
if isinstance(json, BaseModel):
if isinstance(json_schema, BaseModel):
json_schema = json_schema.model_json_schema()
if isinstance(json_schema, dict):
json_schema = json.dumps(json_schema)
@@ -465,5 +466,9 @@ def _get_guided_decoding_params(self) -> tllme.GuidedDecodingParams:
return tllme.GuidedDecodingParams(
tllme.GuidedDecodingParams.GuideType.EBNF_GRAMMAR,
self.guided_decoding.grammar)
elif self.guided_decoding.structural_tag is not None:
return tllme.GuidedDecodingParams(
tllme.GuidedDecodingParams.GuideType.STRUCTURAL_TAG,
self.guided_decoding.structural_tag)
else:
return None
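Putting it together at the LLM API level, a sketch of requesting structural-tag guided decoding (the model path, the guided_decoding_backend flag value, and the tag contents are placeholders, not prescribed by this PR):

import json
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.sampling_params import GuidedDecodingParams

tag = json.dumps({
    "structures": [{"begin": "<function=get_weather>",
                    "schema": {"type": "object",
                               "properties": {"city": {"type": "string"}}},
                    "end": "</function>"}],
    "triggers": ["<function="],
})
llm = LLM(model="/path/to/model", guided_decoding_backend="xgrammar")
outputs = llm.generate(
    "What is the weather in Paris?",
    SamplingParams(max_tokens=64,
                   guided_decoding=GuidedDecodingParams(structural_tag=tag)))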
57 changes: 36 additions & 21 deletions tensorrt_llm/serve/openai_protocol.py
@@ -13,7 +13,7 @@
from typing_extensions import Annotated, Required, TypedDict

from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
from tensorrt_llm.llmapi import SamplingParams
from tensorrt_llm.llmapi import GuidedDecodingParams, SamplingParams


class OpenAIBaseModel(BaseModel):
@@ -44,9 +44,17 @@ class ModelList(OpenAIBaseModel):
data: List[ModelCard] = Field(default_factory=list)


class StructuralTag(OpenAIBaseModel):
begin: str
schema_: Optional[dict[str, Any]] = Field(alias="schema")
end: str


class ResponseFormat(OpenAIBaseModel):
# type must be "json_object" or "text"
type: Literal["text", "json_object"]
# type must be "json_object" or "text" or "structural_tag"
type: Literal["text", "json_object", "structural_tag"]
structures: Optional[List[StructuralTag]] = None
triggers: Optional[List[str]] = None


class DisaggregatedParams(OpenAIBaseModel):
Expand Down Expand Up @@ -121,6 +129,23 @@ class CompletionStreamResponse(OpenAIBaseModel):
usage: Optional[UsageInfo] = Field(default=None)


def _response_format_to_guided_decoding_params(
response_format: Optional[ResponseFormat]
) -> Optional[GuidedDecodingParams]:
if response_format is None:
return None
elif response_format.type == "text":
return None
elif response_format.type == "json_object":
return GuidedDecodingParams(json_object=True)
elif response_format.type == "structural_tag":
return GuidedDecodingParams(
structural_tag=response_format.model_dump_json(by_alias=True,
exclude_none=True))
else:
raise ValueError(f"Unsupported response format: {response_format.type}")


class CompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/completions/create
@@ -170,10 +195,10 @@ class CompletionRequest(OpenAIBaseModel):
)
response_format: Optional[ResponseFormat] = Field(
default=None,
description=(
"Similar to chat completion, this parameter specifies the format of "
"output. Only {'type': 'json_object'} or {'type': 'text' } is "
"supported."),
description=
("Similar to chat completion, this parameter specifies the format of "
"output. {'type': 'json_object'}, {'type': 'text' }, {'type': 'structural_tag'} are "
"supported."),
)

disaggregated_params: Optional[DisaggregatedParams] = Field(
Expand Down Expand Up @@ -211,6 +236,8 @@ def to_sampling_params(self) -> SamplingParams:
spaces_between_special_tokens=self.spaces_between_special_tokens,
truncate_prompt_tokens=self.truncate_prompt_tokens,
return_context_logits=self.return_context_logits,
guided_decoding=_response_format_to_guided_decoding_params(
self.response_format),

# completion-extra-params
add_special_tokens=self.add_special_tokens,
@@ -255,13 +282,6 @@ def verify_multi_responses(cls, data):
raise ValueError("best_of should not be smaller than n")
return data

@model_validator(mode="before")
@classmethod
def check_response_format(cls, data):
if data.get("response_format"):
raise ValueError("response_format is not supported")
return data

@model_validator(mode="before")
@classmethod
def check_suffix(cls, data):
@@ -520,6 +540,8 @@ def to_sampling_params(self) -> SamplingParams:
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
truncate_prompt_tokens=self.truncate_prompt_tokens,
guided_decoding=_response_format_to_guided_decoding_params(
self.response_format),

# chat-completion-extra-params
add_special_tokens=self.add_special_tokens,
@@ -582,13 +604,6 @@ def verify_logit_processor(cls, data):
raise ValueError("logit bias is not supported")
return data

@model_validator(mode="before")
@classmethod
def check_response_format(cls, data):
if data.get("response_format"):
raise ValueError("response_format is not supported")
return data

@model_validator(mode="before")
@classmethod
def check_suffix(cls, data):
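Over HTTP, the extended ResponseFormat rides on a standard chat-completions request; its "structures" and "triggers" fields sit directly beside "type", as in this sketch against a locally served model (host, port, model name, and tag contents are illustrative):

import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "TinyLlama-1.1B-Chat-v1.0",
        "messages": [{"role": "user",
                      "content": "What is the weather in Paris?"}],
        "max_tokens": 64,
        # Maps to GuidedDecodingParams(structural_tag=...) via the helper above.
        "response_format": {
            "type": "structural_tag",
            "structures": [{"begin": "<function=get_weather>",
                            "schema": {"type": "object",
                                       "properties": {"city": {"type": "string"}}},
                            "end": "</function>"}],
            "triggers": ["<function="],
        },
    })
print(resp.json()["choices"][0]["message"]["content"])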
9 changes: 9 additions & 0 deletions tests/integration/defs/test_e2e.py
@@ -1121,6 +1121,15 @@ def test_openai_chat_multimodal_example(llm_root, llm_venv):
str(test_root / "_test_openai_chat_multimodal.py")])


def test_openai_chat_structural_tag_example(llm_venv):
test_root = unittest_path() / "llmapi" / "apps"

llm_venv.run_cmd([
"-m", "pytest",
str(test_root / "_test_openai_chat_structural_tag.py")
])


@pytest.mark.skip_less_device(2)
@pytest.mark.skip_less_device_memory(40000)
def test_openai_multi_chat_example(llm_root, llm_venv):
1 change: 1 addition & 0 deletions tests/integration/test_lists/test-db/l0_a10.yml
@@ -20,6 +20,7 @@ l0_a10:
- disaggregated/test_disaggregated.py::test_disaggregated_overlap[TinyLlama-1.1B-Chat-v1.0]
- stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-stress_time_300s_timeout_450s-MAX_UTILIZATION-pytorch-stress-test]
- stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-stress_time_300s_timeout_450s-GUARANTEED_NO_EVICT-pytorch-stress-test]
- test_e2e.py::test_openai_chat_structural_tag_example
- condition:
ranges:
system_gpu_count: