Skip to content

Commit 285ae6e

Browse files
venkywonka authored and lancelly committed
[fix] Update get_trtllm_bench_build_command to handle batch size and tokens (NVIDIA#6313)
Signed-off-by: Venky Ganesh <[email protected]>
Signed-off-by: Lanyu Liao <[email protected]>
1 parent e66060d commit 285ae6e

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

tests/integration/defs/perf/test_perf.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -998,7 +998,6 @@ def get_trtllm_bench_model(self):
998998

999999
def get_trtllm_bench_build_command(self, engine_dir) -> list:
10001000
model_dir = self.get_trtllm_bench_model()
1001-
dataset_path = os.path.join(engine_dir, "synthetic_data.json")
10021001
if model_dir == "":
10031002
pytest.skip("Model Name is not supported by trtllm-bench")
10041003
model_name = self._config.model_name
@@ -1008,13 +1007,19 @@ def get_trtllm_bench_build_command(self, engine_dir) -> list:
10081007
build_cmd = [
10091008
self._build_script, f"--log_level=info",
10101009
f"--workspace={engine_dir}", f"--model={hf_model_name}",
1011-
f"--model_path={model_dir}", "build", f"--dataset={dataset_path}",
1010+
f"--model_path={model_dir}", "build",
10121011
f"--tp_size={self._config.tp_size}",
10131012
f"--pp_size={self._config.pp_size}"
10141013
]
10151014
max_seq_len = max(self._config.input_lens) + max(
10161015
self._config.output_lens)
10171016
build_cmd.append(f"--max_seq_len={max_seq_len}")
1017+
# Add max_batch_size and max_num_tokens to ensure build matches runtime configuration
1018+
# Note: trtllm-bench requires both to be specified together (option group constraint)
1019+
assert self._config.max_batch_size > 0, f"max_batch_size must be > 0, got {self._config.max_batch_size}"
1020+
assert self._config.max_num_tokens > 0, f"max_num_tokens must be > 0, got {self._config.max_num_tokens}"
1021+
build_cmd.append(f"--max_batch_size={self._config.max_batch_size}")
1022+
build_cmd.append(f"--max_num_tokens={self._config.max_num_tokens}")
10181023
if self._config.quantization:
10191024
build_cmd.append(
10201025
f"--quantization={self._config.quantization.upper()}")

0 commit comments

Comments
 (0)