@@ -998,7 +998,6 @@ def get_trtllm_bench_model(self):
998998
999999 def get_trtllm_bench_build_command (self , engine_dir ) -> list :
10001000 model_dir = self .get_trtllm_bench_model ()
1001- dataset_path = os .path .join (engine_dir , "synthetic_data.json" )
10021001 if model_dir == "" :
10031002 pytest .skip ("Model Name is not supported by trtllm-bench" )
10041003 model_name = self ._config .model_name
@@ -1008,13 +1007,19 @@ def get_trtllm_bench_build_command(self, engine_dir) -> list:
10081007 build_cmd = [
10091008 self ._build_script , f"--log_level=info" ,
10101009 f"--workspace={ engine_dir } " , f"--model={ hf_model_name } " ,
1011- f"--model_path={ model_dir } " , "build" , f"--dataset= { dataset_path } " ,
1010+ f"--model_path={ model_dir } " , "build" ,
10121011 f"--tp_size={ self ._config .tp_size } " ,
10131012 f"--pp_size={ self ._config .pp_size } "
10141013 ]
10151014 max_seq_len = max (self ._config .input_lens ) + max (
10161015 self ._config .output_lens )
10171016 build_cmd .append (f"--max_seq_len={ max_seq_len } " )
1017+ # Add max_batch_size and max_num_tokens to ensure build matches runtime configuration
1018+ # Note: trtllm-bench requires both to be specified together (option group constraint)
1019+ assert self ._config .max_batch_size > 0 , f"max_batch_size must be > 0, got { self ._config .max_batch_size } "
1020+ assert self ._config .max_num_tokens > 0 , f"max_num_tokens must be > 0, got { self ._config .max_num_tokens } "
1021+ build_cmd .append (f"--max_batch_size={ self ._config .max_batch_size } " )
1022+ build_cmd .append (f"--max_num_tokens={ self ._config .max_num_tokens } " )
10181023 if self ._config .quantization :
10191024 build_cmd .append (
10201025 f"--quantization={ self ._config .quantization .upper ()} " )
0 commit comments