22 changes: 18 additions & 4 deletions tensorrt_llm/evaluate/longbench_v2.py
@@ -66,7 +66,8 @@ def __init__(self,
                  output_dir: Optional[str] = None,
                  random_seed: int = 0,
                  apply_chat_template: bool = False,
-                 system_prompt: Optional[str] = None):
+                 system_prompt: Optional[str] = None,
+                 chat_template_kwargs: Optional[dict[str, Any]] = None):
         """Initialize LongBench v2 evaluator.
 
         Args:
@@ -85,10 +86,12 @@ def __init__(self,
             random_seed: Random seed for reproducibility
             apply_chat_template: Whether to apply model's chat template
             system_prompt: System prompt to prepend
+            chat_template_kwargs: Extra keyword arguments forwarded to the chat template
         """
         super().__init__(random_seed=random_seed,
                          apply_chat_template=apply_chat_template,
-                         system_prompt=system_prompt)
+                         system_prompt=system_prompt,
+                         chat_template_kwargs=chat_template_kwargs)
 
         self.dataset_path = dataset_path
         self.num_samples = num_samples
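
For readers tracing where the new argument lands: the base evaluator typically forwards it into the tokenizer's chat template when apply_chat_template is enabled. Below is a minimal sketch of that flow, assuming a HuggingFace-style tokenizer; the helper name render_prompt is hypothetical and not part of this PR.

from typing import Any, Optional

def render_prompt(tokenizer, messages: list[dict],
                  chat_template_kwargs: Optional[dict[str, Any]] = None) -> str:
    # Hypothetical helper: extra kwargs such as {"enable_thinking": True} are
    # passed through to the tokenizer's Jinja chat template, which is how
    # per-model switches like Qwen3's thinking mode are usually toggled.
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        **(chat_template_kwargs or {}),
    )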
@@ -813,6 +816,15 @@ def _save_results(self, results: List[Dict], metrics: Dict[str, float]):
                   type=int,
                   default=32000,
                   help="Maximum generation length in sampling parameters.")
+    @click.option(
+        "--chat_template_kwargs",
+        type=str,
+        default=None,
+        callback=lambda ctx, param, value: json.loads(value) if value else None,
+        help=
+        'A JSON string specifying chat template arguments, used to enable features like thinking mode. Examples: '
+        '\'{"enable_thinking": true}\' for Qwen3, or \'{"thinking": true}\' for DeepSeek-V3.2.'
+    )
     @click.pass_context
     @staticmethod
     def command(ctx, dataset_path: str, prompts_dir: Optional[str],
@@ -821,7 +833,8 @@ def command(ctx, dataset_path: str, prompts_dir: Optional[str],
                 cot: bool, no_context: bool, rag: int, max_len: int,
                 output_dir: Optional[str], random_seed: int,
                 apply_chat_template: bool, system_prompt: Optional[str],
-                max_input_length: int, max_output_length: int) -> None:
+                max_input_length: int, max_output_length: int,
+                chat_template_kwargs: Optional[dict[str, Any]]) -> None:
         llm: Union[LLM, PyTorchLLM] = ctx.obj
 
         sampling_params = SamplingParams(
@@ -844,7 +857,8 @@ def command(ctx, dataset_path: str, prompts_dir: Optional[str],
             output_dir=output_dir,
             random_seed=random_seed,
             apply_chat_template=apply_chat_template,
-            system_prompt=system_prompt)
+            system_prompt=system_prompt,
+            chat_template_kwargs=chat_template_kwargs)
 
         evaluator.evaluate(llm, sampling_params)
         llm.shutdown()
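
Net effect of the CLI wiring above: the flag accepts a JSON object on the command line and the evaluator receives a plain dict (an unset flag stays None). Here is a self-contained sketch of just that parsing behavior, using a standalone click command rather than the project's actual entry point; against the real CLI, the flag would presumably be passed as --chat_template_kwargs '{"enable_thinking": true}' on the longbench_v2 subcommand.

import json
from typing import Any, Optional

import click

def _parse_json_kwargs(ctx, param, value) -> Optional[dict[str, Any]]:
    # Mirrors the callback added in this PR: leave an unset flag as None,
    # otherwise decode the JSON string into a dict.
    return json.loads(value) if value else None

@click.command()
@click.option("--chat_template_kwargs",
              type=str,
              default=None,
              callback=_parse_json_kwargs)
def main(chat_template_kwargs: Optional[dict[str, Any]]) -> None:
    # $ python demo.py --chat_template_kwargs '{"enable_thinking": true}'
    # {'enable_thinking': True}
    click.echo(repr(chat_template_kwargs))

if __name__ == "__main__":
    main()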