107 changes: 92 additions & 15 deletions tensorrt_llm/serve/openai_protocol.py
@@ -37,6 +37,7 @@
from tensorrt_llm.executor.request import LoRARequest
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
from tensorrt_llm.llmapi import GuidedDecodingParams, SamplingParams
from tensorrt_llm.llmapi.reasoning_parser import ReasoningParserFactory


def _logit_bias_to_embedding_bias(logit_bias: Optional[Dict[str, float]],
@@ -191,48 +192,122 @@ class CompletionStreamResponse(OpenAIBaseModel):


def _response_format_to_guided_decoding_params(
response_format: Optional[ResponseFormat]
response_format: Optional[ResponseFormat],
reasoning_parser: Optional[str] = None,
) -> Optional[GuidedDecodingParams]:
if response_format is None:
return None
guided_decoding_params = None
elif response_format.type == "text":
return None
guided_decoding_params = None
elif response_format.type == "json":
if response_format.schema is None:
raise ValueError(
"The 'schema' field is required when response_format.type is 'json'."
f"response_format.schema is required for response_format.type == {response_format.type!r}, but got None."
)
return GuidedDecodingParams(json=response_format.schema)
guided_decoding_params = GuidedDecodingParams(
json=response_format.schema)
elif response_format.type == "json_schema":
if response_format.json_schema is None:
raise ValueError(
"The 'json_schema' field is required when response_format.type is 'json_schema'."
f"response_format.json_schema is required for response_format.type == {response_format.type!r}, but got None."
)
return GuidedDecodingParams(json=response_format.json_schema)
guided_decoding_params = GuidedDecodingParams(
json=response_format.json_schema)
elif response_format.type == "json_object":
return GuidedDecodingParams(json_object=True)
guided_decoding_params = GuidedDecodingParams(json_object=True)
elif response_format.type == "regex":
return GuidedDecodingParams(regex=response_format.regex)
if response_format.regex is None:
raise ValueError(
f"response_format.regex is required for response_format.type == {response_format.type!r}, but got None."
)
guided_decoding_params = GuidedDecodingParams(
regex=response_format.regex)
elif response_format.type == "ebnf":
return GuidedDecodingParams(grammar=response_format.ebnf)
if response_format.ebnf is None:
raise ValueError(
f"response_format.ebnf is required for response_format.type == {response_format.type!r}, but got None."
)
guided_decoding_params = GuidedDecodingParams(
grammar=response_format.ebnf)
elif response_format.type == "structural_tag":
return GuidedDecodingParams(
guided_decoding_params = GuidedDecodingParams(
structural_tag=response_format.model_dump_json(by_alias=True,
exclude_none=True))
else:
raise ValueError(f"Unsupported response format: {response_format.type}")

if guided_decoding_params is None or reasoning_parser is None:
return guided_decoding_params

if guided_decoding_params.structural_tag is not None:
return guided_decoding_params

# Adapt guided_decoding_params for reasoning parser
if guided_decoding_params.json is not None:
content = {
"type": "json_schema",
"json_schema": guided_decoding_params.json
}
elif guided_decoding_params.json_object:
content = {"type": "json_schema", "json_schema": {"type": "object"}}
elif guided_decoding_params.regex is not None:
content = {"type": "regex", "pattern": guided_decoding_params.regex}
elif guided_decoding_params.grammar is not None:
content = {"type": "grammar", "grammar": guided_decoding_params.grammar}

if reasoning_parser == "gpt_oss":
# Trigger user constraint by final channel
stag_format = {
"type":
"triggered_tags",
"triggers": ["<|start|>assistant<|channel|>final<|message|>"],
"tags": [
{
"begin": "<|start|>assistant<|channel|>final<|message|>",
"content": content,
"end": "",
},
],
"stop_after_first":
True,
}
else:
# Force thinking and then trigger user constraint
parser = ReasoningParserFactory.create_reasoning_parser(
reasoning_parser)
stag_format = {
"type":
"sequence",
"elements": [
{
"type": "tag",
"begin": parser.reasoning_start,
"content": {
"type": "any_text"
},
"end": parser.reasoning_end,
},
content,
],
}

stag_format = ResponseFormat(type="structural_tag", format=stag_format)
return GuidedDecodingParams(structural_tag=stag_format.model_dump_json(
by_alias=True, exclude_none=True))


def _response_format_text_config_to_guided_decoding_params(
text_format: Optional[ResponseFormatTextConfig]
text_format: Optional[ResponseFormatTextConfig],
reasoning_parser: Optional[str] = None,
) -> Optional[GuidedDecodingParams]:
if text_format is None:
return None

resp_format = ResponseFormat(type=text_format.type,
json_schema=getattr(text_format, "schema_",
None))
return _response_format_to_guided_decoding_params(resp_format)
return _response_format_to_guided_decoding_params(
resp_format, reasoning_parser=reasoning_parser)


class CompletionRequest(OpenAIBaseModel):
@@ -649,6 +724,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
def to_sampling_params(self,
vocab_size: int = 32000,
gather_generation_logits: bool = False,
reasoning_parser: Optional[str] = None,
backend: Optional[str] = None) -> SamplingParams:
sampling_params = SamplingParams(
frequency_penalty=self.frequency_penalty,
@@ -679,7 +755,7 @@ def to_sampling_params(self,
spaces_between_special_tokens=self.spaces_between_special_tokens,
truncate_prompt_tokens=self.truncate_prompt_tokens,
guided_decoding=_response_format_to_guided_decoding_params(
self.response_format),
self.response_format, reasoning_parser=reasoning_parser),

# logits_bias
embedding_bias=_logit_bias_to_embedding_bias(
@@ -809,6 +885,7 @@ class ResponsesRequest(OpenAIBaseModel):
def to_sampling_params(
self,
default_sampling_params: Optional[dict] = None,
reasoning_parser: Optional[str] = None,
) -> SamplingParams:
max_tokens = None
if self.max_output_tokens is not None:
@@ -827,7 +904,7 @@ def to_sampling_params(
guided_decoding = None
if self.text is not None and self.text.format is not None:
guided_decoding = _response_format_text_config_to_guided_decoding_params(
self.text.format)
self.text.format, reasoning_parser=reasoning_parser)

return SamplingParams(
temperature=temperature,
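For reference, the structural-tag payload built by the new reasoning-parser branch can be reproduced without importing tensorrt_llm. The sketch below mirrors the logic added above; the "<think>"/"</think>" tokens used for the non-gpt_oss case are only an assumed example of a parser's reasoning_start/reasoning_end, and the final wrapping into ResponseFormat/GuidedDecodingParams is omitted.

import json

def build_structural_tag(content: dict, reasoning_parser: str) -> dict:
    """Wrap a user constraint so it only applies after the reasoning phase."""
    if reasoning_parser == "gpt_oss":
        # The constraint is triggered once the Harmony 'final' channel starts.
        return {
            "type": "triggered_tags",
            "triggers": ["<|start|>assistant<|channel|>final<|message|>"],
            "tags": [{
                "begin": "<|start|>assistant<|channel|>final<|message|>",
                "content": content,
                "end": "",
            }],
            "stop_after_first": True,
        }
    # Other parsers: force a reasoning block first, then apply the constraint.
    reasoning_start, reasoning_end = "<think>", "</think>"  # assumed example tokens
    return {
        "type": "sequence",
        "elements": [
            {
                "type": "tag",
                "begin": reasoning_start,
                "content": {"type": "any_text"},
                "end": reasoning_end,
            },
            content,
        ],
    }

if __name__ == "__main__":
    # A response_format of type "json_object" maps to this content constraint.
    content = {"type": "json_schema", "json_schema": {"type": "object"}}
    print(json.dumps(build_structural_tag(content, "gpt_oss"), indent=2))
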
5 changes: 4 additions & 1 deletion tensorrt_llm/serve/openai_server.py
@@ -540,6 +540,7 @@ async def create_chat_response(
sampling_params = request.to_sampling_params(
vocab_size=self.tokenizer.tokenizer.vocab_size,
gather_generation_logits=self.llm.args.gather_generation_logits,
reasoning_parser=self.llm.args.reasoning_parser,
backend=self.llm.args.backend)
postproc_args = ChatPostprocArgs.from_request(request)
disaggregated_params = to_llm_disaggregated_params(request.disaggregated_params)
@@ -916,7 +917,8 @@ async def create_streaming_generator(promise: RequestOutput, postproc_params: Po
request.stop_token_ids = harmony_stop_tokens

sampling_params = request.to_sampling_params(
vocab_size=self.tokenizer.tokenizer.vocab_size)
vocab_size=self.tokenizer.tokenizer.vocab_size,
reasoning_parser="gpt_oss")
sampling_params.detokenize = False # Harmony adapter handles detokenization

postproc_args = ChatCompletionPostprocArgs.from_request(request)
@@ -1018,6 +1020,7 @@ async def create_streaming_generator(promise: RequestOutput, postproc_params: Po
tokenizer=self.tokenizer if not self.use_harmony else None,
model_config=self.model_config if not self.use_harmony else None,
processor=self.processor if not self.use_harmony else None,
reasoning_parser=self.llm.args.reasoning_parser if not self.use_harmony else "gpt_oss",
)

streaming_processor = None
5 changes: 4 additions & 1 deletion tensorrt_llm/serve/responses_utils.py
@@ -867,13 +867,16 @@ async def request_preprocess(
tokenizer: Optional[Union[TransformersTokenizer, TokenizerBase]] = None,
model_config: Optional[PretrainedConfig] = None,
processor: Optional[AutoProcessor] = None,
reasoning_parser: Optional[str] = None,
) -> tuple[list[int], SamplingParams]:

sampling_params = request.to_sampling_params(
default_sampling_params={
"stop_token_ids":
get_harmony_adapter().get_stop_tokens() if use_harmony else []
})
},
reasoning_parser=reasoning_parser,
)

prev_response_id = request.previous_response_id

2 changes: 1 addition & 1 deletion tests/integration/defs/.test_durations
@@ -831,7 +831,7 @@
"test_e2e.py::test_mistral_e2e[use_py_session-remove_input_padding--]": 157.39577213302255,
"test_e2e.py::test_mistral_large_hidden_vocab_size": 81.36711680702865,
"test_e2e.py::test_openai_chat_example": 876.1966922096908,
"test_e2e.py::test_openai_chat_guided_decoding": 55.12449237401597,
"test_e2e.py::test_openai_chat_guided_decoding[meta-llama/Llama-3.1-8B-Instruct]": 55.12449237401597,
"test_e2e.py::test_openai_chat_harmony": 1162.7252594940364,
"test_e2e.py::test_openai_chat_multimodal_example": 215.8254322744906,
"test_e2e.py::test_openai_consistent_chat": 0.0001894170418381691,
7 changes: 5 additions & 2 deletions tests/integration/defs/test_e2e.py
@@ -1736,11 +1736,14 @@ def test_openai_mmencoder_example(llm_root, llm_venv):
str(test_root / "_test_openai_mmencoder.py")])


def test_openai_chat_guided_decoding(llm_root, llm_venv):
@pytest.mark.parametrize(
"model_name", ["meta-llama/Llama-3.1-8B-Instruct", "openai/gpt-oss-120b"])
def test_openai_chat_guided_decoding(llm_root, llm_venv, model_name: str):
test_root = unittest_path() / "llmapi" / "apps"
llm_venv.run_cmd([
"-m", "pytest",
str(test_root / "_test_openai_chat_guided_decoding.py")
str(test_root / "_test_openai_chat_guided_decoding.py"), "-k",
model_name
])


3 changes: 2 additions & 1 deletion tests/integration/test_lists/qa/llm_function_core.txt
@@ -342,7 +342,8 @@ test_e2e.py::test_mistral_e2e[use_py_session---]
test_e2e.py::test_qwen_e2e_cpprunner_large_new_tokens[DeepSeek-R1-Distill-Qwen-1.5B-DeepSeek-R1-Distill-Qwen-1.5B]
test_e2e.py::test_openai_multi_chat_example
test_e2e.py::test_openai_consistent_chat
test_e2e.py::test_openai_chat_guided_decoding
test_e2e.py::test_openai_chat_guided_decoding[meta-llama/Llama-3.1-8B-Instruct]
test_e2e.py::test_openai_chat_guided_decoding[openai/gpt-oss-120b]
test_e2e.py::test_openai_chat_harmony
test_e2e.py::test_trtllm_benchmark_serving[llama-3.1-model/Meta-Llama-3.1-8B]
test_e2e.py::test_trtllm_benchmark_serving[gpt_oss/gpt-oss-20b]
1 change: 1 addition & 0 deletions tests/integration/test_lists/test-db/l0_b200.yml
@@ -67,6 +67,7 @@ l0_b200:
- test_e2e.py::test_ptp_quickstart_advanced_eagle3[Llama-3.1-8b-Instruct-llama-3.1-model/Llama-3.1-8B-Instruct-EAGLE3-LLaMA3.1-Instruct-8B]
- test_e2e.py::test_ptp_quickstart_advanced_ngram[Llama-3.1-8B-Instruct-llama-3.1-model/Llama-3.1-8B-Instruct]
- test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-False-False]
- test_e2e.py::test_openai_chat_guided_decoding[openai/gpt-oss-120b]
- unittest/_torch/attention
- unittest/_torch/compilation
- unittest/_torch/debugger
2 changes: 1 addition & 1 deletion tests/integration/test_lists/test-db/l0_h100.yml
@@ -110,7 +110,7 @@ l0_h100:
- test_e2e.py::test_trtllm_bench_help_sanity[meta-llama/Llama-3.1-8B]
- test_e2e.py::test_openai_chat_harmony
- test_e2e.py::test_openai_responses
- test_e2e.py::test_openai_chat_guided_decoding
- test_e2e.py::test_openai_chat_guided_decoding[meta-llama/Llama-3.1-8B-Instruct]
- test_e2e.py::test_trtllm_benchmark_serving[llama-3.1-model/Meta-Llama-3.1-8B]
- condition:
ranges:
27 changes: 20 additions & 7 deletions tests/unittest/llmapi/apps/_test_openai_chat_guided_decoding.py
@@ -16,17 +16,28 @@
pytestmark = pytest.mark.threadleak(enabled=False)


@pytest.fixture(scope="module")
def model_name():
return "llama-3.1-model/Llama-3.1-8B-Instruct"
@pytest.fixture(
scope="module",
params=["meta-llama/Llama-3.1-8B-Instruct", "openai/gpt-oss-120b"])
def model_name(request):
return request.param


@pytest.fixture(scope="module")
def temp_extra_llm_api_options_file():
def temp_extra_llm_api_options_file(model_name: str):
temp_dir = tempfile.gettempdir()
temp_file_path = os.path.join(temp_dir, "extra_llm_api_options.yaml")
try:
extra_llm_api_options_dict = {"guided_decoding_backend": "xgrammar"}
if model_name == "openai/gpt-oss-120b":
extra_llm_api_options_dict["speculative_config"] = {
"decoding_type":
"Eagle",
"max_draft_len":
3,
"speculative_model_dir":
get_model_path("gpt_oss/gpt-oss-120b-Eagle3"),
}

with open(temp_file_path, 'w') as f:
yaml.dump(extra_llm_api_options_dict, f)
@@ -39,11 +50,13 @@ def temp_extra_llm_api_options_file():

@pytest.fixture(scope="module")
def server(model_name: str, temp_extra_llm_api_options_file: str):
model_path = get_model_path(model_name)
if model_name == "meta-llama/Llama-3.1-8B-Instruct":
model_path = get_model_path("llama-3.1-model/Llama-3.1-8B-Instruct")
elif model_name == "openai/gpt-oss-120b":
model_path = get_model_path("gpt_oss/gpt-oss-120b")

# Use small max_batch_size/max_seq_len/max_num_tokens to avoid OOM on A10/A30 GPUs.
args = [
"--max_batch_size=8", "--max_seq_len=1024", "--max_num_tokens=1024",
"--max_batch_size=8", "--max_seq_len=4096", "--max_num_tokens=4096",
f"--extra_llm_api_options={temp_extra_llm_api_options_file}"
]
with RemoteOpenAIServer(model_path, args) as remote_server:
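For context, here is a minimal sketch of the kind of guided-decoding request the parametrized test exercises once the server is up. The base URL, API key, model name, and prompt are illustrative assumptions, not taken from this diff; with a reasoning parser configured (e.g. "gpt_oss" on the Harmony path), the JSON constraint is only applied after the reasoning phase rather than to the whole output.

import json

from openai import OpenAI

# Assumed endpoint for a locally launched OpenAI-compatible server.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

completion = client.chat.completions.create(
    model="openai/gpt-oss-120b",  # or "meta-llama/Llama-3.1-8B-Instruct"
    messages=[{
        "role": "user",
        "content": "List two prime numbers as a JSON object with key 'primes'.",
    }],
    response_format={"type": "json_object"},
)

# The final-channel content is constrained to valid JSON.
print(json.loads(completion.choices[0].message.content))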