3 changes: 2 additions & 1 deletion setup.py
@@ -283,7 +283,8 @@ def extract_from_precompiled(precompiled_location: str, package_data: List[str],
             'trtllm-refit=tensorrt_llm.commands.refit:main',
             'trtllm-bench=tensorrt_llm.commands.bench:main',
             'trtllm-serve=tensorrt_llm.commands.serve:main',
-            'trtllm-eval=tensorrt_llm.commands.eval:main'
+            'trtllm-eval=tensorrt_llm.commands.eval:main',
+            'trtllm-configure=tensorrt_llm.commands.configure:main'
         ],
     },
     scripts=['tensorrt_llm/llmapi/trtllm-llmapi-launch'],
9 changes: 9 additions & 0 deletions tensorrt_llm/commands/configure.py
@@ -0,0 +1,9 @@
from tensorrt_llm.configure.cli import TRTLLMConfigure


def main():
    TRTLLMConfigure().run()


if __name__ == "__main__":
    main()
Empty file.
92 changes: 92 additions & 0 deletions tensorrt_llm/configure/cli.py
@@ -0,0 +1,92 @@
from pathlib import Path
from typing import Optional, Union

import yaml
from pydantic import AliasChoices, BaseModel, Field, model_validator
from pydantic_settings import BaseSettings, CliSubCommand, SettingsConfigDict, get_subcommand

from tensorrt_llm.configure.profile import InferenceMaxProfile, ThroughputLatencySLAProfile
from tensorrt_llm.logger import logger


class CommonOptions(BaseModel):
    """Common options for all subcommands of the trtllm-configure CLI tool."""

    output: Optional[Path] = Field(
        default=Path("config.yaml"),
        description="YAML file path where the optimized config will be written.",
        validation_alias=AliasChoices("output", "o"),
    )

    @model_validator(mode="after")
    def validate_output(self) -> "CommonOptions":
        """Verify that the output file is a valid YAML file path and does not already exist."""
        if self.output is not None:
            if self.output.suffix != ".yaml":
                raise ValueError(f"Output file must be a YAML file. Got '{self.output}'.")
            if self.output.exists():
                raise ValueError(f"Output file '{self.output}' already exists.")
venkywonka (Collaborator) commented on Nov 14, 2025:
If there's an easy way to parametrize and expose a --force-overwrite or something, it would help avoid a potential user gripe of having to explicitly delete the file every time - but P1, feel free to ignore.

Author (Collaborator) replied:
I opted to change to overwrite by default, since it seems like most users would probably have this gripe. LMK what you think.

        return self


class InferenceMaxSubCommand(InferenceMaxProfile, CommonOptions):
    """Optimize TensorRT LLM for an InferenceMax benchmark workload with a specific number of concurrent requests."""


class ThroughputLatencySLASubCommand(ThroughputLatencySLAProfile, CommonOptions):
    """Optimize TensorRT LLM to meet a throughput and latency SLA."""


TRTLLMConfigureSubCommand = Union[InferenceMaxSubCommand, ThroughputLatencySLASubCommand]


class TRTLLMConfigure(BaseSettings):
    # NOTE: the docstring below is used to generate the CLI help message
    """The trtllm-configure CLI tool allows you to optimize the configuration of TensorRT LLM for your specific
    inference scenario.
    """ # noqa: D205

    model_config = SettingsConfigDict(
        cli_parse_args=True,
        cli_prog_name="trtllm-configure",
        cli_enforce_required=True,  # Make required fields enforced at CLI level
        cli_implicit_flags=True,  # Boolean fields will be exposed as e.g. --flag and --no-flag
        cli_avoid_json=True,  # Do not expose JSON string options for nested models
    )

    inferencemax: CliSubCommand[InferenceMaxSubCommand] = Field(
        description=InferenceMaxSubCommand.__doc__
    )
    # TODO: add support for throughput/latency SLA subcommand
    # throughput_latency: CliSubCommand[ThroughputLatencySLASubCommand] = Field(
    #     description=ThroughputLatencySLASubCommand.__doc__
    # )

    def run(self) -> None:
        """Main entrypoint for the trtllm-configure CLI tool."""
        subcommand: TRTLLMConfigureSubCommand = get_subcommand(self)

        config = subcommand.get_config()

        # exclude_unset and exclude_defaults are explicitly used to avoid including default values
        config_dict = config.model_dump(exclude_unset=True, exclude_defaults=True)

Reviewer comment:
nit: do we want to exclude unset and default? I can think of a scenario where the defaults change in newer versions and suddenly you get perf changes on the same config, compared to when you ran it on the version where it was originally generated. However, being overly explicit also has its problems, so I'm not sure what's the best option here. What do you think?

Author (Collaborator) replied:
Good point. In my second pass at this, I actually changed get_config to return a dict and removed all references to LlmArgs within the trtllm-configure code itself. We should still test that trtllm-configure produces configs compatible with LlmArgs in CI, but I opted to make this change because we eventually want to be able to decouple and use this tool without TRTLLM installed, and at that point importing LlmArgs won't be possible anyway.

re: perf changes due to defaults changing, I think that once we add versioning as an explicit part of the user's constraints, we should return different configs for each version, which should solve this problem. WDYT?

Reviewer replied:
Yes, that makes sense. Thanks!

logger.info(f"Optimized config: \n\n{yaml.safe_dump(config_dict)}")

if subcommand.output is None:
logger.info(
"No output file specified. To write the optimized config to a file, use the --output / -o flag."
)
else:
with open(subcommand.output, "w") as f:
f.write(yaml.safe_dump(config_dict))
logger.info(f"Optimized config written to {subcommand.output}")
logger.info("To serve the model with optimized settings, run the following command:")
logger.info(f" trtllm-serve {subcommand.model} --config {subcommand.output}")


def main():
TRTLLMConfigure().run()


if __name__ == "__main__":
main()
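For readers unfamiliar with the pydantic-settings CLI machinery used above (CliSubCommand, get_subcommand, cli_parse_args), here is a minimal self-contained sketch of the same subcommand pattern. The DemoCli and Greet names are invented for illustration and are not part of this PR.

from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings, CliSubCommand, SettingsConfigDict, get_subcommand


class Greet(BaseModel):
    """Print a greeting for the given name."""

    name: str = Field(description="Name to greet")


class DemoCli(BaseSettings):
    """Demo CLI mirroring the TRTLLMConfigure subcommand pattern."""

    model_config = SettingsConfigDict(
        cli_parse_args=True,  # parse sys.argv when the settings object is created
        cli_prog_name="demo-cli",
        cli_enforce_required=True,  # required model fields become required CLI options
    )

    greet: CliSubCommand[Greet] = Field(description=Greet.__doc__)

    def run(self) -> None:
        # get_subcommand returns the populated model of whichever subcommand was invoked
        subcommand = get_subcommand(self)
        print(f"Hello, {subcommand.name}!")


if __name__ == "__main__":
    # Example invocation: python demo_cli.py greet --name world
    DemoCli().run()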
33 changes: 33 additions & 0 deletions tensorrt_llm/configure/profile.py
@@ -0,0 +1,33 @@
from abc import ABC, abstractmethod

from pydantic import BaseModel, validate_call

from tensorrt_llm.configure.scenario import BenchmarkScenario, ThroughputLatencySLAScenario
from tensorrt_llm.llmapi.llm_args import LlmArgs


class BaseProfile(BaseModel, ABC):
    """Base class for all profiles.

    A profile defines a particular strategy used to find an optimized config for a given scenario
    (e.g. database lookup, heuristics, etc.)

    Each profile is compatible with a specific scenario type.
    """

    @abstractmethod
    def get_config(self) -> LlmArgs: ...


class InferenceMaxProfile(BaseProfile, BenchmarkScenario):
    @validate_call
    def get_config(self) -> LlmArgs:
        # TODO: add logic to retrieve optimal recipe from database
        return LlmArgs()


class ThroughputLatencySLAProfile(BaseProfile, ThroughputLatencySLAScenario):
    @validate_call
    def get_config(self) -> LlmArgs:
        # TODO: add logic to retrieve optimal recipe
        return LlmArgs()
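The get_config implementations above are left as TODOs. As a rough illustration only, a recipe lookup could eventually look something like the sketch below; the _RECIPE_DB table, its key structure, and the example override value are all invented here and do not reflect the actual database planned for this tool.

from tensorrt_llm.llmapi.llm_args import LlmArgs

# Hypothetical, illustrative recipe table: keys and values are invented examples,
# not real tuned configurations.
_RECIPE_DB = {
    ("meta-llama/Llama-3.1-8B", "H200_SXM", 1, 64): {
        "kv_cache_free_gpu_memory_fraction": 0.9,
    },
}


def lookup_recipe(model: str, gpu: str, tensor_parallel_size: int, concurrency: int) -> LlmArgs:
    """Return LlmArgs with overrides from the matching recipe, falling back to defaults."""
    overrides = _RECIPE_DB.get((model, gpu, tensor_parallel_size, concurrency), {})
    args = LlmArgs()
    # Mirror the attribute-assignment pattern used in the PR's unit test.
    for field_name, value in overrides.items():
        setattr(args, field_name, value)
    return args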
75 changes: 75 additions & 0 deletions tensorrt_llm/configure/scenario.py
@@ -0,0 +1,75 @@
from enum import StrEnum
from typing import Optional

from pydantic import (
    AliasChoices,
    BaseModel,
    Field,
    NonNegativeInt,
    PositiveFloat,
    PositiveInt,
    model_validator,
)


class GPU(StrEnum):
    GB200 = "GB200"
    H200_SXM = "H200_SXM"


class BaseScenario(BaseModel):
    """Base class for all scenarios containing common fields.

    A scenario fully defines a specific inference workload and the goals for optimization
    (e.g. SLA targets).
    """

    model: str = Field(description="HuggingFace ID of the model being deployed")
    gpu: GPU = Field(description="GPU SKU used in the deployment")
    num_gpus: int = Field(description="Number of GPUs available in the deployment")


class BenchmarkScenario(BaseScenario):
    isl: NonNegativeInt = Field(
        description="Target input sequence length",
        validation_alias=AliasChoices("isl", "target_isl", "input_sequence_length"),
    )
    osl: NonNegativeInt = Field(
        description="Target output sequence length",
        validation_alias=AliasChoices("osl", "target_osl", "output_sequence_length"),
    )
    concurrency: PositiveInt = Field(
        description="Target number of concurrent requests",
        validation_alias=AliasChoices("concurrency", "target_concurrency"),
    )
    # TODO: make this optional and add logic to choose best parallelization mapping automatically
    tensor_parallel_size: PositiveInt = Field(
        description="Specific tensor parallel size that should be used",
        validation_alias=AliasChoices("tensor_parallel_size", "tp"),
    )


class ThroughputLatencySLAScenario(BaseScenario):
    tps_per_gpu: PositiveFloat = Field(
        description="Target throughput per GPU in tokens per second",
        validation_alias=AliasChoices("tps_per_gpu", "target_tps_per_gpu", "min_tps_per_gpu"),
    )

Reviewer comment:
nit: shouldn't these all be optional, with at least one specified? For example, low-concurrency workloads with a fixed number of GPUs where a user is only interested in TTFT and ITL.

Author (Collaborator) replied:
Good catch, yup you're right. Fixed.

    tps_per_user: Optional[PositiveFloat] = Field(
        default=None,
        description="Target throughput per user in tokens per second. Mutually exclusive with target time to first "
        "token.",
        validation_alias=AliasChoices("tps_per_user", "target_tps_per_user", "min_tps_per_user"),
    )
    ttft: Optional[PositiveFloat] = Field(
        default=None,
        description="Target time to first token in seconds. Mutually exclusive with target throughput per user.",
        validation_alias=AliasChoices("ttft", "target_ttft", "max_ttft"),
    )

    @model_validator(mode="after")
    def validate_mutually_exclusive_latency_sla(self) -> "ThroughputLatencySLAScenario":
        if self.tps_per_user is not None and self.ttft is not None:
            raise ValueError(
                "Target throughput per user and target time to first token cannot be specified together."
            )
        return self
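As a quick usage illustration of the mutual-exclusion check above (the model ID and target values below are arbitrary examples, not recommendations):

from pydantic import ValidationError

from tensorrt_llm.configure.scenario import ThroughputLatencySLAScenario

# Valid: only one of tps_per_user / ttft is supplied alongside the throughput target.
scenario = ThroughputLatencySLAScenario(
    model="meta-llama/Llama-3.1-8B",
    gpu="H200_SXM",
    num_gpus=1,
    tps_per_gpu=1000.0,
    ttft=0.5,
)

# Invalid: supplying both latency targets trips validate_mutually_exclusive_latency_sla.
try:
    ThroughputLatencySLAScenario(
        model="meta-llama/Llama-3.1-8B",
        gpu="H200_SXM",
        num_gpus=1,
        tps_per_gpu=1000.0,
        tps_per_user=50.0,
        ttft=0.5,
    )
except ValidationError as err:
    print(err)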
38 changes: 38 additions & 0 deletions tests/unittest/configure/test_configure.py
@@ -0,0 +1,38 @@
from pathlib import Path
from unittest.mock import patch

import yaml

from tensorrt_llm.configure.cli import InferenceMaxSubCommand, TRTLLMConfigure
from tensorrt_llm.llmapi.llm_args import LlmArgs


def test_trtllm_configure_subcommand_basic(tmp_path: Path):
    output_path = tmp_path / "test_config.yaml"

    mock_config = LlmArgs()
    mock_config.kv_cache_free_gpu_memory_fraction = 0.9

    cmd = InferenceMaxSubCommand(
        model="meta-llama/Llama-3.1-8B",
        gpu="H200_SXM",
        num_gpus=1,
        isl=1000,
        osl=2000,
        concurrency=64,
        tensor_parallel_size=1,
        output=output_path,
    )

    trtllm_configure = TRTLLMConfigure(inferencemax=cmd)

    # Mock get_config to return our mock config
    with patch.object(cmd, "get_config", return_value=mock_config):
        trtllm_configure.run()

    assert output_path.exists()
    with open(output_path, "r") as f:
        loaded_config = yaml.safe_load(f)

    assert "kv_cache_free_gpu_memory_fraction" in loaded_config
    assert loaded_config["kv_cache_free_gpu_memory_fraction"] == 0.9
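A possible companion test, not part of this PR, could cover the CommonOptions output-path validation in the same style; pytest and the error message wording from validate_output are assumed here.

from pathlib import Path

import pytest
from pydantic import ValidationError

from tensorrt_llm.configure.cli import InferenceMaxSubCommand


def test_output_must_be_yaml(tmp_path: Path):
    # The validator in CommonOptions rejects non-.yaml output paths.
    with pytest.raises(ValidationError, match="must be a YAML file"):
        InferenceMaxSubCommand(
            model="meta-llama/Llama-3.1-8B",
            gpu="H200_SXM",
            num_gpus=1,
            isl=1000,
            osl=2000,
            concurrency=64,
            tensor_parallel_size=1,
            output=tmp_path / "config.json",  # wrong suffix triggers the error
        )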