
Commit e3fe37e

[TRTC-1921][feat] Add trtllm-configure CLI tool and scenario/profile schemas
Signed-off-by: Anish Shanbhag <[email protected]>
1 parent f1d637e commit e3fe37e

File tree

7 files changed (+249, −1 lines changed)

setup.py

Lines changed: 2 additions & 1 deletion
@@ -283,7 +283,8 @@ def extract_from_precompiled(precompiled_location: str, package_data: List[str],
             'trtllm-refit=tensorrt_llm.commands.refit:main',
             'trtllm-bench=tensorrt_llm.commands.bench:main',
             'trtllm-serve=tensorrt_llm.commands.serve:main',
-            'trtllm-eval=tensorrt_llm.commands.eval:main'
+            'trtllm-eval=tensorrt_llm.commands.eval:main',
+            'trtllm-configure=tensorrt_llm.commands.configure:main'
         ],
     },
     scripts=['tensorrt_llm/llmapi/trtllm-llmapi-launch'],
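This console-script entry makes the new tool available as `trtllm-configure` after installation, alongside the existing `trtllm-serve`, `trtllm-bench`, and `trtllm-eval` commands. As a minimal sketch (assuming a standard wheel or editable install), running the installed command is equivalent to calling the entry-point target directly:

# Minimal sketch: the console script registered above resolves to this function,
# so invoking `trtllm-configure ...` from a shell is equivalent to the call below.
# Argument parsing happens inside TRTLLMConfigure via pydantic-settings (see cli.py).
from tensorrt_llm.commands.configure import main

main()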

tensorrt_llm/commands/configure.py

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
from tensorrt_llm.configure.cli import TRTLLMConfigure


def main():
    TRTLLMConfigure().run()


if __name__ == "__main__":
    main()

tensorrt_llm/configure/__init__.py

Whitespace-only changes.

tensorrt_llm/configure/cli.py

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
from pathlib import Path
from typing import Optional, Union

import yaml
from pydantic import AliasChoices, BaseModel, Field, model_validator
from pydantic_settings import BaseSettings, CliSubCommand, SettingsConfigDict, get_subcommand

from tensorrt_llm.configure.profile import InferenceMaxProfile, ThroughputLatencySLAProfile
from tensorrt_llm.logger import logger


class CommonOptions(BaseModel):
    """Common options for all subcommands of the trtllm-configure CLI tool."""

    output: Optional[Path] = Field(
        default=Path("config.yaml"),
        description="YAML file path where the optimized config will be written.",
        validation_alias=AliasChoices("output", "o"),
    )

    @model_validator(mode="after")
    def validate_output(self) -> "CommonOptions":
        """Verify that the output file is a valid YAML file path and does not already exist."""
        if self.output is not None:
            if self.output.suffix != ".yaml":
                raise ValueError(f"Output file must be a YAML file. Got '{self.output}'.")
            if self.output.exists():
                raise ValueError(f"Output file '{self.output}' already exists.")
        return self


class InferenceMaxSubCommand(InferenceMaxProfile, CommonOptions):
    """Optimize TensorRT LLM for an InferenceMax benchmark workload with a specific number of concurrent requests."""


class ThroughputLatencySLASubCommand(ThroughputLatencySLAProfile, CommonOptions):
    """Optimize TensorRT LLM to meet a throughput and latency SLA."""


TRTLLMConfigureSubCommand = Union[InferenceMaxSubCommand, ThroughputLatencySLASubCommand]


class TRTLLMConfigure(BaseSettings):
    # NOTE: the docstring below is used to generate the CLI help message
    """The trtllm-configure CLI tool allows you to optimize the configuration of TensorRT LLM for your specific
    inference scenario.
    """  # noqa: D205

    model_config = SettingsConfigDict(
        cli_parse_args=True,
        cli_prog_name="trtllm-configure",
        cli_enforce_required=True,  # Make required fields enforced at CLI level
        cli_implicit_flags=True,  # Boolean fields will be exposed as e.g. --flag and --no-flag
        cli_avoid_json=True,  # Do not expose JSON string options for nested models
    )

    inferencemax: CliSubCommand[InferenceMaxSubCommand] = Field(
        description=InferenceMaxSubCommand.__doc__
    )
    # TODO: add support for throughput/latency SLA subcommand
    # throughput_latency: CliSubCommand[ThroughputLatencySLASubCommand] = Field(
    #     description=ThroughputLatencySLASubCommand.__doc__
    # )

    def run(self) -> None:
        """Main entrypoint for the trtllm-configure CLI tool."""
        subcommand: TRTLLMConfigureSubCommand = get_subcommand(self)

        config = subcommand.get_config()

        # exclude_unset and exclude_defaults are used explicitly to avoid including default values
        config_dict = config.model_dump(exclude_unset=True, exclude_defaults=True)
        logger.info(f"Optimized config: \n\n{yaml.safe_dump(config_dict)}")

        if subcommand.output is None:
            logger.info(
                "No output file specified. To write the optimized config to a file, use the --output / -o flag."
            )
        else:
            with open(subcommand.output, "w") as f:
                f.write(yaml.safe_dump(config_dict))
            logger.info(f"Optimized config written to {subcommand.output}")
            logger.info("To serve the model with optimized settings, run the following command:")
            logger.info(f"  trtllm-serve {subcommand.model} --config {subcommand.output}")


def main():
    TRTLLMConfigure().run()


if __name__ == "__main__":
    main()
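Because `TRTLLMConfigure` is a pydantic-settings model with `cli_parse_args=True`, every field of the selected subcommand is exposed as a command-line option. The sketch below drives the parser in-process rather than through `sys.argv`; the exact flag spellings and the `_cli_parse_args` override follow pydantic-settings' defaults and are assumptions here, not part of this commit. Only the field names come from the schemas above.

# Hedged sketch: drive the CLI in-process instead of via sys.argv.
# Flag spellings are assumed from pydantic-settings' default name mapping.
from tensorrt_llm.configure.cli import TRTLLMConfigure

args = [
    "inferencemax",
    "--model", "meta-llama/Llama-3.1-8B",
    "--gpu", "H200_SXM",
    "--num_gpus", "1",
    "--isl", "1000",
    "--osl", "2000",
    "--concurrency", "64",
    "--tensor_parallel_size", "1",
    "--output", "optimized.yaml",
]

# _cli_parse_args overrides cli_parse_args, so sys.argv is not consulted.
TRTLLMConfigure(_cli_parse_args=args).run()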

tensorrt_llm/configure/profile.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
from abc import ABC, abstractmethod

from pydantic import BaseModel, validate_call

from tensorrt_llm.configure.scenario import BenchmarkScenario, ThroughputLatencySLAScenario
from tensorrt_llm.llmapi.llm_args import LlmArgs


class BaseProfile(BaseModel, ABC):
    """Base class for all profiles.

    A profile defines a particular strategy used to find an optimized config for a given scenario
    (e.g. database lookup, heuristics, etc.)

    Each profile is compatible with a specific scenario type.
    """

    @abstractmethod
    def get_config(self) -> LlmArgs: ...


class InferenceMaxProfile(BaseProfile, BenchmarkScenario):
    @validate_call
    def get_config(self) -> LlmArgs:
        # TODO: add logic to retrieve optimal recipe from database
        return LlmArgs()


class ThroughputLatencySLAProfile(BaseProfile, ThroughputLatencySLAScenario):
    @validate_call
    def get_config(self) -> LlmArgs:
        # TODO: add logic to retrieve optimal recipe
        return LlmArgs()
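Each concrete profile multiply-inherits from `BaseProfile` and a scenario class, so a single pydantic object both describes the workload and implements the strategy that turns it into `LlmArgs`. A minimal, illustrative usage follows; the field values are placeholders, and in this commit `get_config()` still returns a default `LlmArgs` because the recipe lookup is a TODO.

# Illustrative only: instantiate a profile (which is also its scenario) and
# produce an LlmArgs config. Values are placeholders, not recommendations.
from tensorrt_llm.configure.profile import InferenceMaxProfile

profile = InferenceMaxProfile(
    model="meta-llama/Llama-3.1-8B",
    gpu="H200_SXM",
    num_gpus=1,
    isl=1000,
    osl=2000,
    concurrency=64,
    tensor_parallel_size=1,
)
llm_args = profile.get_config()  # currently a default LlmArgs; lookup logic is a TODO
print(llm_args.model_dump(exclude_unset=True, exclude_defaults=True))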

tensorrt_llm/configure/scenario.py

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
from enum import StrEnum
from typing import Optional

from pydantic import (
    AliasChoices,
    BaseModel,
    Field,
    NonNegativeInt,
    PositiveFloat,
    PositiveInt,
    model_validator,
)


class GPU(StrEnum):
    GB200 = "GB200"
    H200_SXM = "H200_SXM"


class BaseScenario(BaseModel):
    """Base class for all scenarios containing common fields.

    A scenario fully defines a specific inference workload and the goals for optimization
    (e.g. SLA targets).
    """

    model: str = Field(description="HuggingFace ID of the model being deployed")
    gpu: GPU = Field(description="GPU SKU used in the deployment")
    num_gpus: int = Field(description="Number of GPUs available in the deployment")


class BenchmarkScenario(BaseScenario):
    isl: NonNegativeInt = Field(
        description="Target input sequence length",
        validation_alias=AliasChoices("isl", "target_isl", "input_sequence_length"),
    )
    osl: NonNegativeInt = Field(
        description="Target output sequence length",
        validation_alias=AliasChoices("osl", "target_osl", "output_sequence_length"),
    )
    concurrency: PositiveInt = Field(
        description="Target number of concurrent requests",
        validation_alias=AliasChoices("concurrency", "target_concurrency"),
    )
    # TODO: make this optional and add logic to choose best parallelization mapping automatically
    tensor_parallel_size: PositiveInt = Field(
        description="Specific tensor parallel size that should be used",
        validation_alias=AliasChoices("tensor_parallel_size", "tp"),
    )


class ThroughputLatencySLAScenario(BaseScenario):
    tps_per_gpu: PositiveFloat = Field(
        description="Target throughput per GPU in tokens per second",
        validation_alias=AliasChoices("tps_per_gpu", "target_tps_per_gpu", "min_tps_per_gpu"),
    )
    tps_per_user: Optional[PositiveFloat] = Field(
        default=None,
        description="Target throughput per user in tokens per second. Mutually exclusive with target time to first "
        "token.",
        validation_alias=AliasChoices("tps_per_user", "target_tps_per_user", "min_tps_per_user"),
    )
    ttft: Optional[PositiveFloat] = Field(
        default=None,
        description="Target time to first token in seconds. Mutually exclusive with target throughput per user.",
        validation_alias=AliasChoices("ttft", "target_ttft", "max_ttft"),
    )

    @model_validator(mode="after")
    def validate_mutually_exclusive_latency_sla(self) -> "ThroughputLatencySLAScenario":
        if self.tps_per_user is not None and self.ttft is not None:
            raise ValueError(
                "Target throughput per user and target time to first token cannot be specified together."
            )
        return self
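The SLA scenario accepts either a per-user throughput target or a time-to-first-token target, but not both, and the alias choices let each field be spelled several ways (e.g. `min_tps_per_gpu` for `tps_per_gpu`). A short illustrative check of that behavior, with placeholder model and GPU values:

# Hedged sketch of the validation behavior of ThroughputLatencySLAScenario.
from pydantic import ValidationError

from tensorrt_llm.configure.scenario import ThroughputLatencySLAScenario

# Aliases let the same field be provided under several names.
ok = ThroughputLatencySLAScenario(
    model="meta-llama/Llama-3.1-8B",
    gpu="GB200",
    num_gpus=8,
    min_tps_per_gpu=500.0,   # alias for tps_per_gpu
    target_ttft=0.5,         # alias for ttft
)

# Specifying both a per-user throughput target and a TTFT target is rejected.
try:
    ThroughputLatencySLAScenario(
        model="meta-llama/Llama-3.1-8B",
        gpu="GB200",
        num_gpus=8,
        tps_per_gpu=500.0,
        tps_per_user=40.0,
        ttft=0.5,
    )
except ValidationError as e:
    print(e)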
Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
from pathlib import Path
from unittest.mock import patch

import yaml

from tensorrt_llm.configure.cli import InferenceMaxSubCommand, TRTLLMConfigure
from tensorrt_llm.llmapi.llm_args import LlmArgs


def test_trtllm_configure_subcommand_basic(tmp_path: Path):
    output_path = tmp_path / "test_config.yaml"

    mock_config = LlmArgs()
    mock_config.kv_cache_free_gpu_memory_fraction = 0.9

    cmd = InferenceMaxSubCommand(
        model="meta-llama/Llama-3.1-8B",
        gpu="H200_SXM",
        num_gpus=1,
        isl=1000,
        osl=2000,
        concurrency=64,
        tensor_parallel_size=1,
        output=output_path,
    )

    trtllm_configure = TRTLLMConfigure(inferencemax=cmd)

    # Mock get_config to return our mock config
    with patch.object(cmd, "get_config", return_value=mock_config):
        trtllm_configure.run()

    assert output_path.exists()
    with open(output_path, "r") as f:
        loaded_config = yaml.safe_load(f)

    assert "kv_cache_free_gpu_memory_fraction" in loaded_config
    assert loaded_config["kv_cache_free_gpu_memory_fraction"] == 0.9
