diff --git a/tensorrt_llm/evaluate/longbench_v2.py b/tensorrt_llm/evaluate/longbench_v2.py
index 2e11bccff07..503e1bac7d0 100644
--- a/tensorrt_llm/evaluate/longbench_v2.py
+++ b/tensorrt_llm/evaluate/longbench_v2.py
@@ -66,7 +66,8 @@ def __init__(self,
                  output_dir: Optional[str] = None,
                  random_seed: int = 0,
                  apply_chat_template: bool = False,
-                 system_prompt: Optional[str] = None):
+                 system_prompt: Optional[str] = None,
+                 chat_template_kwargs: Optional[dict[str, Any]] = None):
         """Initialize LongBench v2 evaluator.
 
         Args:
@@ -85,10 +86,12 @@ def __init__(self,
             random_seed: Random seed for reproducibility
             apply_chat_template: Whether to apply model's chat template
             system_prompt: System prompt to prepend
+            chat_template_kwargs: Additional kwargs passed to the chat template
         """
         super().__init__(random_seed=random_seed,
                          apply_chat_template=apply_chat_template,
-                         system_prompt=system_prompt)
+                         system_prompt=system_prompt,
+                         chat_template_kwargs=chat_template_kwargs)
 
         self.dataset_path = dataset_path
         self.num_samples = num_samples
@@ -813,6 +816,15 @@ def _save_results(self, results: List[Dict], metrics: Dict[str, float]):
         type=int,
         default=32000,
         help="Maximum generation length in sampling parameters.")
+    @click.option(
+        "--chat_template_kwargs",
+        type=str,
+        default=None,
+        callback=lambda ctx, param, value: json.loads(value) if value else None,
+        help=
+        'A JSON string specifying chat template arguments, used to enable features like thinking mode. Examples: '
+        '\'{"enable_thinking": true}\' for Qwen3, or \'{"thinking": true}\' for DeepSeek-V3.2.'
+    )
     @click.pass_context
     @staticmethod
     def command(ctx, dataset_path: str, prompts_dir: Optional[str],
@@ -821,7 +833,8 @@ def command(ctx, dataset_path: str, prompts_dir: Optional[str],
                 cot: bool, no_context: bool, rag: int, max_len: int,
                 output_dir: Optional[str], random_seed: int,
                 apply_chat_template: bool, system_prompt: Optional[str],
-                max_input_length: int, max_output_length: int) -> None:
+                max_input_length: int, max_output_length: int,
+                chat_template_kwargs: Optional[dict[str, Any]]) -> None:
         llm: Union[LLM, PyTorchLLM] = ctx.obj
 
         sampling_params = SamplingParams(
@@ -844,7 +857,8 @@ def command(ctx, dataset_path: str, prompts_dir: Optional[str],
             output_dir=output_dir,
             random_seed=random_seed,
             apply_chat_template=apply_chat_template,
-            system_prompt=system_prompt)
+            system_prompt=system_prompt,
+            chat_template_kwargs=chat_template_kwargs)
 
         evaluator.evaluate(llm, sampling_params)
         llm.shutdown()
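
For reference, a minimal, self-contained sketch of how the new `--chat_template_kwargs` option is parsed: the click `callback` converts the raw JSON string into a dict before the command body runs, so `command` receives an already-parsed `Optional[dict[str, Any]]` (or `None` when the flag is omitted). The `demo` command below is hypothetical and only illustrates the callback's behavior; it is not part of the patch.

```python
import json
from typing import Any, Optional

import click


def _parse_json_kwargs(ctx, param, value):
    # Mirrors the lambda callback in the patch: parse the JSON string
    # into a dict, or pass None through when the option is not given.
    return json.loads(value) if value else None


@click.command()  # hypothetical standalone command, for illustration only
@click.option("--chat_template_kwargs",
              type=str,
              default=None,
              callback=_parse_json_kwargs,
              help="A JSON string of chat template arguments.")
def demo(chat_template_kwargs: Optional[dict[str, Any]]) -> None:
    # By the time the function body runs, the value is already a dict:
    #   --chat_template_kwargs '{"enable_thinking": true}'
    #   -> {'enable_thinking': True}
    click.echo(repr(chat_template_kwargs))


if __name__ == "__main__":
    demo()
```

Note that `json.loads` raises `json.JSONDecodeError` on malformed input; wrapping the callback body and re-raising `click.BadParameter` would surface a friendlier CLI error, though the patch keeps the lambda minimal.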