diff --git a/tensorrt_llm/serve/openai_protocol.py b/tensorrt_llm/serve/openai_protocol.py
index 8ddda27cd7f..0d91c4c01b2 100644
--- a/tensorrt_llm/serve/openai_protocol.py
+++ b/tensorrt_llm/serve/openai_protocol.py
@@ -37,6 +37,7 @@
 from tensorrt_llm.executor.request import LoRARequest
 from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
 from tensorrt_llm.llmapi import GuidedDecodingParams, SamplingParams
+from tensorrt_llm.llmapi.reasoning_parser import ReasoningParserFactory
 
 
 def _logit_bias_to_embedding_bias(logit_bias: Optional[Dict[str, float]],
@@ -191,40 +192,113 @@ class CompletionStreamResponse(OpenAIBaseModel):
 
 def _response_format_to_guided_decoding_params(
-    response_format: Optional[ResponseFormat]
+    response_format: Optional[ResponseFormat],
+    reasoning_parser: Optional[str] = None,
 ) -> Optional[GuidedDecodingParams]:
     if response_format is None:
-        return None
+        guided_decoding_params = None
     elif response_format.type == "text":
-        return None
+        guided_decoding_params = None
     elif response_format.type == "json":
         if response_format.schema is None:
             raise ValueError(
-                "The 'schema' field is required when response_format.type is 'json'."
+                f"response_format.schema is required for response_format.type == {response_format.type!r}, but got None."
             )
-        return GuidedDecodingParams(json=response_format.schema)
+        guided_decoding_params = GuidedDecodingParams(
+            json=response_format.schema)
     elif response_format.type == "json_schema":
         if response_format.json_schema is None:
             raise ValueError(
-                "The 'json_schema' field is required when response_format.type is 'json_schema'."
+                f"response_format.json_schema is required for response_format.type == {response_format.type!r}, but got None."
            )
-        return GuidedDecodingParams(json=response_format.json_schema)
+        guided_decoding_params = GuidedDecodingParams(
+            json=response_format.json_schema)
     elif response_format.type == "json_object":
-        return GuidedDecodingParams(json_object=True)
+        guided_decoding_params = GuidedDecodingParams(json_object=True)
     elif response_format.type == "regex":
-        return GuidedDecodingParams(regex=response_format.regex)
+        if response_format.regex is None:
+            raise ValueError(
+                f"response_format.regex is required for response_format.type == {response_format.type!r}, but got None."
+            )
+        guided_decoding_params = GuidedDecodingParams(
+            regex=response_format.regex)
     elif response_format.type == "ebnf":
-        return GuidedDecodingParams(grammar=response_format.ebnf)
+        if response_format.ebnf is None:
+            raise ValueError(
+                f"response_format.ebnf is required for response_format.type == {response_format.type!r}, but got None."
+            )
+        guided_decoding_params = GuidedDecodingParams(
+            grammar=response_format.ebnf)
     elif response_format.type == "structural_tag":
-        return GuidedDecodingParams(
+        guided_decoding_params = GuidedDecodingParams(
             structural_tag=response_format.model_dump_json(by_alias=True,
                                                            exclude_none=True))
     else:
         raise ValueError(f"Unsupported response format: {response_format.type}")
 
+    if guided_decoding_params is None or reasoning_parser is None:
+        return guided_decoding_params
+
+    if guided_decoding_params.structural_tag is not None:
+        return guided_decoding_params
+
+    # Adapt guided_decoding_params for reasoning parser
+    if guided_decoding_params.json is not None:
+        content = {
+            "type": "json_schema",
+            "json_schema": guided_decoding_params.json
+        }
+    elif guided_decoding_params.json_object:
+        content = {"type": "json_schema", "json_schema": {"type": "object"}}
+    elif guided_decoding_params.regex is not None:
+        content = {"type": "regex", "pattern": guided_decoding_params.regex}
+    elif guided_decoding_params.grammar is not None:
+        content = {"type": "grammar", "grammar": guided_decoding_params.grammar}
+
+    if reasoning_parser == "gpt_oss":
+        # Trigger user constraint by final channel
+        stag_format = {
+            "type":
+            "triggered_tags",
+            "triggers": ["<|start|>assistant<|channel|>final<|message|>"],
+            "tags": [
+                {
+                    "begin": "<|start|>assistant<|channel|>final<|message|>",
+                    "content": content,
+                    "end": "",
+                },
+            ],
+            "stop_after_first":
+            True,
+        }
+    else:
+        # Force thinking and then trigger user constraint
+        parser = ReasoningParserFactory.create_reasoning_parser(
+            reasoning_parser)
+        stag_format = {
+            "type":
+            "sequence",
+            "elements": [
+                {
+                    "type": "tag",
+                    "begin": parser.reasoning_start,
+                    "content": {
+                        "type": "any_text"
+                    },
+                    "end": parser.reasoning_end,
+                },
+                content,
+            ],
+        }
+
+    stag_format = ResponseFormat(type="structural_tag", format=stag_format)
+    return GuidedDecodingParams(structural_tag=stag_format.model_dump_json(
+        by_alias=True, exclude_none=True))
+
 
 def _response_format_text_config_to_guided_decoding_params(
-    text_format: Optional[ResponseFormatTextConfig]
+    text_format: Optional[ResponseFormatTextConfig],
+    reasoning_parser: Optional[str] = None,
 ) -> Optional[GuidedDecodingParams]:
     if text_format is None:
         return None
@@ -232,7 +306,8 @@ def _response_format_text_config_to_guided_decoding_params(
     resp_format = ResponseFormat(type=text_format.type,
                                  json_schema=getattr(text_format, "schema_",
                                                      None))
-    return _response_format_to_guided_decoding_params(resp_format)
+    return _response_format_to_guided_decoding_params(
+        resp_format, reasoning_parser=reasoning_parser)
 
 
 class CompletionRequest(OpenAIBaseModel):
@@ -649,6 +724,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     def to_sampling_params(self,
                            vocab_size: int = 32000,
                            gather_generation_logits: bool = False,
+                           reasoning_parser: Optional[str] = None,
                            backend: Optional[str] = None) -> SamplingParams:
         sampling_params = SamplingParams(
             frequency_penalty=self.frequency_penalty,
@@ -679,7 +755,7 @@ def to_sampling_params(self,
             spaces_between_special_tokens=self.spaces_between_special_tokens,
             truncate_prompt_tokens=self.truncate_prompt_tokens,
             guided_decoding=_response_format_to_guided_decoding_params(
-                self.response_format),
+                self.response_format, reasoning_parser=reasoning_parser),
 
             # logits_bias
             embedding_bias=_logit_bias_to_embedding_bias(
@@ -809,6 +885,7 @@ class ResponsesRequest(OpenAIBaseModel):
     def to_sampling_params(
         self,
         default_sampling_params: Optional[dict] = None,
+        reasoning_parser: Optional[str] = None,
     ) -> SamplingParams:
         max_tokens = None
         if self.max_output_tokens is not None:
@@ -827,7 +904,7 @@ def to_sampling_params(
         guided_decoding = None
         if self.text is not None and self.text.format is not None:
             guided_decoding = _response_format_text_config_to_guided_decoding_params(
-                self.text.format)
+                self.text.format, reasoning_parser=reasoning_parser)
 
         return SamplingParams(
             temperature=temperature,
diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py
index afb97aa6f0c..4c441237a67 100644
--- a/tensorrt_llm/serve/openai_server.py
+++ b/tensorrt_llm/serve/openai_server.py
@@ -540,6 +540,7 @@ async def create_chat_response(
         sampling_params = request.to_sampling_params(
             vocab_size=self.tokenizer.tokenizer.vocab_size,
             gather_generation_logits=self.llm.args.gather_generation_logits,
+            reasoning_parser=self.llm.args.reasoning_parser,
             backend=self.llm.args.backend)
         postproc_args = ChatPostprocArgs.from_request(request)
         disaggregated_params = to_llm_disaggregated_params(request.disaggregated_params)
@@ -916,7 +917,8 @@ async def create_streaming_generator(promise: RequestOutput, postproc_params: Po
             request.stop_token_ids = harmony_stop_tokens
 
         sampling_params = request.to_sampling_params(
-            vocab_size=self.tokenizer.tokenizer.vocab_size)
+            vocab_size=self.tokenizer.tokenizer.vocab_size,
+            reasoning_parser="gpt_oss")
         sampling_params.detokenize = False  # Harmony adapter handles detokenization
         postproc_args = ChatCompletionPostprocArgs.from_request(request)
@@ -1018,6 +1020,7 @@ async def create_streaming_generator(promise: RequestOutput, postproc_params: Po
             tokenizer=self.tokenizer if not self.use_harmony else None,
             model_config=self.model_config if not self.use_harmony else None,
             processor=self.processor if not self.use_harmony else None,
+            reasoning_parser=self.llm.args.reasoning_parser if not self.use_harmony else "gpt_oss",
         )
 
         streaming_processor = None
diff --git a/tensorrt_llm/serve/responses_utils.py b/tensorrt_llm/serve/responses_utils.py
index b760dd60369..35a999ac9a5 100644
--- a/tensorrt_llm/serve/responses_utils.py
+++ b/tensorrt_llm/serve/responses_utils.py
@@ -867,13 +867,16 @@ async def request_preprocess(
     tokenizer: Optional[Union[TransformersTokenizer, TokenizerBase]] = None,
     model_config: Optional[PretrainedConfig] = None,
     processor: Optional[AutoProcessor] = None,
+    reasoning_parser: Optional[str] = None,
 ) -> tuple[list[int], SamplingParams]:
     sampling_params = request.to_sampling_params(
         default_sampling_params={
             "stop_token_ids":
             get_harmony_adapter().get_stop_tokens() if use_harmony else []
-        })
+        },
+        reasoning_parser=reasoning_parser,
+    )
 
     prev_response_id = request.previous_response_id
diff --git a/tests/integration/defs/.test_durations b/tests/integration/defs/.test_durations
index bbb44110638..d0abdb6eaee 100644
--- a/tests/integration/defs/.test_durations
+++ b/tests/integration/defs/.test_durations
@@ -831,7 +831,7 @@
   "test_e2e.py::test_mistral_e2e[use_py_session-remove_input_padding--]": 157.39577213302255,
   "test_e2e.py::test_mistral_large_hidden_vocab_size": 81.36711680702865,
   "test_e2e.py::test_openai_chat_example": 876.1966922096908,
-  "test_e2e.py::test_openai_chat_guided_decoding": 55.12449237401597,
+  "test_e2e.py::test_openai_chat_guided_decoding[meta-llama/Llama-3.1-8B-Instruct]": 55.12449237401597,
   "test_e2e.py::test_openai_chat_harmony": 1162.7252594940364,
   "test_e2e.py::test_openai_chat_multimodal_example": 215.8254322744906,
   "test_e2e.py::test_openai_consistent_chat": 0.0001894170418381691,
diff --git a/tests/integration/defs/test_e2e.py b/tests/integration/defs/test_e2e.py
index b292d49f70d..d37df0bed2c 100644
--- a/tests/integration/defs/test_e2e.py
+++ b/tests/integration/defs/test_e2e.py
@@ -1736,11 +1736,14 @@ def test_openai_mmencoder_example(llm_root, llm_venv):
                       str(test_root / "_test_openai_mmencoder.py")])
 
 
-def test_openai_chat_guided_decoding(llm_root, llm_venv):
+@pytest.mark.parametrize(
+    "model_name", ["meta-llama/Llama-3.1-8B-Instruct", "openai/gpt-oss-120b"])
+def test_openai_chat_guided_decoding(llm_root, llm_venv, model_name: str):
     test_root = unittest_path() / "llmapi" / "apps"
     llm_venv.run_cmd([
         "-m", "pytest",
-        str(test_root / "_test_openai_chat_guided_decoding.py")
+        str(test_root / "_test_openai_chat_guided_decoding.py"), "-k",
+        model_name
     ])
diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt
index cf7157dc693..0a50758ac2e 100644
--- a/tests/integration/test_lists/qa/llm_function_core.txt
+++ b/tests/integration/test_lists/qa/llm_function_core.txt
@@ -342,7 +342,8 @@ test_e2e.py::test_mistral_e2e[use_py_session---]
 test_e2e.py::test_qwen_e2e_cpprunner_large_new_tokens[DeepSeek-R1-Distill-Qwen-1.5B-DeepSeek-R1-Distill-Qwen-1.5B]
 test_e2e.py::test_openai_multi_chat_example
 test_e2e.py::test_openai_consistent_chat
-test_e2e.py::test_openai_chat_guided_decoding
+test_e2e.py::test_openai_chat_guided_decoding[meta-llama/Llama-3.1-8B-Instruct]
+test_e2e.py::test_openai_chat_guided_decoding[openai/gpt-oss-120b]
 test_e2e.py::test_openai_chat_harmony
 test_e2e.py::test_trtllm_benchmark_serving[llama-3.1-model/Meta-Llama-3.1-8B]
 test_e2e.py::test_trtllm_benchmark_serving[gpt_oss/gpt-oss-20b]
diff --git a/tests/integration/test_lists/test-db/l0_b200.yml b/tests/integration/test_lists/test-db/l0_b200.yml
index 86b504a580d..99da33194fb 100644
--- a/tests/integration/test_lists/test-db/l0_b200.yml
+++ b/tests/integration/test_lists/test-db/l0_b200.yml
@@ -67,6 +67,7 @@ l0_b200:
   - test_e2e.py::test_ptp_quickstart_advanced_eagle3[Llama-3.1-8b-Instruct-llama-3.1-model/Llama-3.1-8B-Instruct-EAGLE3-LLaMA3.1-Instruct-8B]
   - test_e2e.py::test_ptp_quickstart_advanced_ngram[Llama-3.1-8B-Instruct-llama-3.1-model/Llama-3.1-8B-Instruct]
   - test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-False-False]
+  - test_e2e.py::test_openai_chat_guided_decoding[openai/gpt-oss-120b]
   - unittest/_torch/attention
   - unittest/_torch/compilation
   - unittest/_torch/debugger
diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml
index ee36c87a0d1..1a4617fd987 100644
--- a/tests/integration/test_lists/test-db/l0_h100.yml
+++ b/tests/integration/test_lists/test-db/l0_h100.yml
@@ -110,7 +110,7 @@ l0_h100:
   - test_e2e.py::test_trtllm_bench_help_sanity[meta-llama/Llama-3.1-8B]
   - test_e2e.py::test_openai_chat_harmony
   - test_e2e.py::test_openai_responses
-  - test_e2e.py::test_openai_chat_guided_decoding
+  - test_e2e.py::test_openai_chat_guided_decoding[meta-llama/Llama-3.1-8B-Instruct]
   - test_e2e.py::test_trtllm_benchmark_serving[llama-3.1-model/Meta-Llama-3.1-8B]
 - condition:
     ranges:
diff --git a/tests/unittest/llmapi/apps/_test_openai_chat_guided_decoding.py b/tests/unittest/llmapi/apps/_test_openai_chat_guided_decoding.py
index ccfe1d02e22..86a5cc8a3a5 100644
--- a/tests/unittest/llmapi/apps/_test_openai_chat_guided_decoding.py
+++ b/tests/unittest/llmapi/apps/_test_openai_chat_guided_decoding.py
@@ -16,17 +16,28 @@
 pytestmark = pytest.mark.threadleak(enabled=False)
 
 
-@pytest.fixture(scope="module")
-def model_name():
-    return "llama-3.1-model/Llama-3.1-8B-Instruct"
+@pytest.fixture(
+    scope="module",
+    params=["meta-llama/Llama-3.1-8B-Instruct", "openai/gpt-oss-120b"])
+def model_name(request):
+    return request.param
 
 
 @pytest.fixture(scope="module")
-def temp_extra_llm_api_options_file():
+def temp_extra_llm_api_options_file(model_name: str):
     temp_dir = tempfile.gettempdir()
     temp_file_path = os.path.join(temp_dir, "extra_llm_api_options.yaml")
     try:
         extra_llm_api_options_dict = {"guided_decoding_backend": "xgrammar"}
+        if model_name == "openai/gpt-oss-120b":
+            extra_llm_api_options_dict["speculative_config"] = {
+                "decoding_type":
+                "Eagle",
+                "max_draft_len":
+                3,
+                "speculative_model_dir":
+                get_model_path("gpt_oss/gpt-oss-120b-Eagle3"),
+            }
 
         with open(temp_file_path, 'w') as f:
             yaml.dump(extra_llm_api_options_dict, f)
@@ -39,11 +50,13 @@ def temp_extra_llm_api_options_file():
 
 @pytest.fixture(scope="module")
 def server(model_name: str, temp_extra_llm_api_options_file: str):
-    model_path = get_model_path(model_name)
+    if model_name == "meta-llama/Llama-3.1-8B-Instruct":
+        model_path = get_model_path("llama-3.1-model/Llama-3.1-8B-Instruct")
+    elif model_name == "openai/gpt-oss-120b":
+        model_path = get_model_path("gpt_oss/gpt-oss-120b")
 
-    # Use small max_batch_size/max_seq_len/max_num_tokens to avoid OOM on A10/A30 GPUs.
     args = [
-        "--max_batch_size=8", "--max_seq_len=1024", "--max_num_tokens=1024",
+        "--max_batch_size=8", "--max_seq_len=4096", "--max_num_tokens=4096",
        f"--extra_llm_api_options={temp_extra_llm_api_options_file}"
     ]
     with RemoteOpenAIServer(model_path, args) as remote_server: