[TRTC-1921][feat] Add trtllm-configure CLI tool and constraints/profiles schemas #9160
New file (9 lines added) — standalone entry point that launches the trtllm-configure CLI:

```python
from tensorrt_llm.configure.cli import TRTLLMConfigure


def main():
    TRTLLMConfigure().run()


if __name__ == "__main__":
    main()
```
New file (92 lines added) — `tensorrt_llm/configure/cli.py`:

```python
from pathlib import Path
from typing import Optional, Union

import yaml
from pydantic import AliasChoices, BaseModel, Field, model_validator
from pydantic_settings import BaseSettings, CliSubCommand, SettingsConfigDict, get_subcommand

from tensorrt_llm.configure.profile import InferenceMaxProfile, ThroughputLatencySLAProfile
from tensorrt_llm.logger import logger


class CommonOptions(BaseModel):
    """Common options for all subcommands of the trtllm-configure CLI tool."""

    output: Optional[Path] = Field(
        default=Path("config.yaml"),
        description="YAML file path where the optimized config will be written.",
        validation_alias=AliasChoices("output", "o"),
    )

    @model_validator(mode="after")
    def validate_output(self) -> "CommonOptions":
        """Verify that the output file is a valid YAML file path and does not already exist."""
        if self.output is not None:
            if self.output.suffix != ".yaml":
                raise ValueError(f"Output file must be a YAML file. Got '{self.output}'.")
            if self.output.exists():
                raise ValueError(f"Output file '{self.output}' already exists.")
        return self


class InferenceMaxSubCommand(InferenceMaxProfile, CommonOptions):
    """Optimize TensorRT LLM for an InferenceMax benchmark workload with a specific number of concurrent requests."""


class ThroughputLatencySLASubCommand(ThroughputLatencySLAProfile, CommonOptions):
    """Optimize TensorRT LLM to meet a throughput and latency SLA."""


TRTLLMConfigureSubCommand = Union[InferenceMaxSubCommand, ThroughputLatencySLASubCommand]


class TRTLLMConfigure(BaseSettings):
    # NOTE: the docstring below is used to generate the CLI help message
    """The trtllm-configure CLI tool allows you to optimize the configuration of TensorRT LLM for your specific
    inference scenario.
    """  # noqa: D205

    model_config = SettingsConfigDict(
        cli_parse_args=True,
        cli_prog_name="trtllm-configure",
        cli_enforce_required=True,  # Make required fields enforced at CLI level
        cli_implicit_flags=True,  # Boolean fields will be exposed as e.g. --flag and --no-flag
        cli_avoid_json=True,  # Do not expose JSON string options for nested models
    )

    inferencemax: CliSubCommand[InferenceMaxSubCommand] = Field(
        description=InferenceMaxSubCommand.__doc__
    )
    # TODO: add support for throughput/latency SLA subcommand
    # throughput_latency: CliSubCommand[ThroughputLatencySLASubCommand] = Field(
    #     description=ThroughputLatencySLASubCommand.__doc__
    # )

    def run(self) -> None:
        """Main entrypoint for the trtllm-configure CLI tool."""
        subcommand: TRTLLMConfigureSubCommand = get_subcommand(self)

        config = subcommand.get_config()

        # exclude_unset and exclude_defaults are explicitly used to avoid including default values
        config_dict = config.model_dump(exclude_unset=True, exclude_defaults=True)

        logger.info(f"Optimized config: \n\n{yaml.safe_dump(config_dict)}")

        if subcommand.output is None:
            logger.info(
                "No output file specified. To write the optimized config to a file, use the --output / -o flag."
            )
        else:
            with open(subcommand.output, "w") as f:
                f.write(yaml.safe_dump(config_dict))
            logger.info(f"Optimized config written to {subcommand.output}")
            logger.info("To serve the model with optimized settings, run the following command:")
            logger.info(f" trtllm-serve {subcommand.model} --config {subcommand.output}")


def main():
    TRTLLMConfigure().run()


if __name__ == "__main__":
    main()
```
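To make the flow above concrete, the sketch below drives the same classes programmatically instead of through argument parsing, mirroring the unit test further down in this diff. The shell flag spellings in the comment are an assumption based on the field names and aliases defined above and pydantic-settings' CLI naming; they are not stated in this PR.

```python
from tensorrt_llm.configure.cli import InferenceMaxSubCommand

# Roughly equivalent shell invocation (flag spellings assumed, derived by pydantic-settings):
#   trtllm-configure inferencemax --model meta-llama/Llama-3.1-8B --gpu H200_SXM \
#       --num_gpus 1 --isl 1000 --osl 2000 --concurrency 64 --tensor_parallel_size 1 -o config.yaml

cmd = InferenceMaxSubCommand(
    model="meta-llama/Llama-3.1-8B",
    gpu="H200_SXM",
    num_gpus=1,
    isl=1000,
    osl=2000,
    concurrency=64,
    tensor_parallel_size=1,
    output=None,  # skip writing a file; run() would then only log the config
)

# get_config() currently returns a default LlmArgs (the recipe lookup is still a TODO),
# so this only demonstrates the plumbing.
llm_args = cmd.get_config()
print(llm_args.model_dump(exclude_unset=True, exclude_defaults=True))
```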
New file (33 lines added) — `tensorrt_llm/configure/profile.py`:

```python
from abc import ABC, abstractmethod

from pydantic import BaseModel, validate_call

from tensorrt_llm.configure.scenario import BenchmarkScenario, ThroughputLatencySLAScenario
from tensorrt_llm.llmapi.llm_args import LlmArgs


class BaseProfile(BaseModel, ABC):
    """Base class for all profiles.

    A profile defines a particular strategy used to find an optimized config for a given scenario
    (e.g. database lookup, heuristics, etc.).

    Each profile is compatible with a specific scenario type.
    """

    @abstractmethod
    def get_config(self) -> LlmArgs: ...


class InferenceMaxProfile(BaseProfile, BenchmarkScenario):
    @validate_call
    def get_config(self) -> LlmArgs:
        # TODO: add logic to retrieve optimal recipe from database
        return LlmArgs()


class ThroughputLatencySLAProfile(BaseProfile, ThroughputLatencySLAScenario):
    @validate_call
    def get_config(self) -> LlmArgs:
        # TODO: add logic to retrieve optimal recipe
        return LlmArgs()
```
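The docstring above frames profiles as pluggable strategies bound to a scenario type. As an illustration of that contract only (not part of this PR), a hypothetical heuristic-based profile for benchmark scenarios could look like the sketch below; the `LlmArgs` field it sets is an assumption about the LLM API rather than anything produced by this change.

```python
from tensorrt_llm.configure.profile import BaseProfile
from tensorrt_llm.configure.scenario import BenchmarkScenario
from tensorrt_llm.llmapi.llm_args import LlmArgs


class HeuristicBenchmarkProfile(BaseProfile, BenchmarkScenario):
    """Hypothetical profile that derives a config from simple heuristics instead of a database lookup."""

    def get_config(self) -> LlmArgs:
        # Assumption: LlmArgs accepts tensor_parallel_size, as in the LLM API;
        # all other settings are left at their defaults.
        return LlmArgs(tensor_parallel_size=self.tensor_parallel_size)
```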
New file (75 lines added) — `tensorrt_llm/configure/scenario.py`:

```python
from enum import StrEnum
from typing import Optional

from pydantic import (
    AliasChoices,
    BaseModel,
    Field,
    NonNegativeInt,
    PositiveFloat,
    PositiveInt,
    model_validator,
)


class GPU(StrEnum):
    GB200 = "GB200"
    H200_SXM = "H200_SXM"


class BaseScenario(BaseModel):
    """Base class for all scenarios containing common fields.

    A scenario fully defines a specific inference workload and the goals for optimization
    (e.g. SLA targets).
    """

    model: str = Field(description="HuggingFace ID of the model being deployed")
    gpu: GPU = Field(description="GPU SKU used in the deployment")
    num_gpus: int = Field(description="Number of GPUs available in the deployment")


class BenchmarkScenario(BaseScenario):
    isl: NonNegativeInt = Field(
        description="Target input sequence length",
        validation_alias=AliasChoices("isl", "target_isl", "input_sequence_length"),
    )
    osl: NonNegativeInt = Field(
        description="Target output sequence length",
        validation_alias=AliasChoices("osl", "target_osl", "output_sequence_length"),
    )
    concurrency: PositiveInt = Field(
        description="Target number of concurrent requests",
        validation_alias=AliasChoices("concurrency", "target_concurrency"),
    )
    # TODO: make this optional and add logic to choose best parallelization mapping automatically
    tensor_parallel_size: PositiveInt = Field(
        description="Specific tensor parallel size that should be used",
        validation_alias=AliasChoices("tensor_parallel_size", "tp"),
    )


class ThroughputLatencySLAScenario(BaseScenario):
    tps_per_gpu: PositiveFloat = Field(
        description="Target throughput per GPU in tokens per second",
        validation_alias=AliasChoices("tps_per_gpu", "target_tps_per_gpu", "min_tps_per_gpu"),
    )
    tps_per_user: Optional[PositiveFloat] = Field(
        default=None,
        description="Target throughput per user in tokens per second. Mutually exclusive with target time to first "
        "token.",
        validation_alias=AliasChoices("tps_per_user", "target_tps_per_user", "min_tps_per_user"),
    )
    ttft: Optional[PositiveFloat] = Field(
        default=None,
        description="Target time to first token in seconds. Mutually exclusive with target throughput per user.",
        validation_alias=AliasChoices("ttft", "target_ttft", "max_ttft"),
    )

    @model_validator(mode="after")
    def validate_mutually_exclusive_latency_sla(self) -> "ThroughputLatencySLAScenario":
        if self.tps_per_user is not None and self.ttft is not None:
            raise ValueError(
                "Target throughput per user and target time to first token cannot be specified together."
            )
        return self
```
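A quick sketch of how these scenario models behave, based only on the fields and validator defined above:

```python
from pydantic import ValidationError

from tensorrt_llm.configure.scenario import ThroughputLatencySLAScenario

# Valid: a per-GPU throughput target plus a per-user throughput target.
scenario = ThroughputLatencySLAScenario(
    model="meta-llama/Llama-3.1-8B",
    gpu="H200_SXM",
    num_gpus=8,
    tps_per_gpu=5000,
    tps_per_user=50,
)

# Invalid: tps_per_user and ttft are mutually exclusive, so validation fails.
try:
    ThroughputLatencySLAScenario(
        model="meta-llama/Llama-3.1-8B",
        gpu="H200_SXM",
        num_gpus=8,
        tps_per_gpu=5000,
        tps_per_user=50,
        ttft=0.5,
    )
except ValidationError as e:
    print(e)
```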
New file (38 lines added) — unit test for the trtllm-configure CLI:

```python
from pathlib import Path
from unittest.mock import patch

import yaml

from tensorrt_llm.configure.cli import InferenceMaxSubCommand, TRTLLMConfigure
from tensorrt_llm.llmapi.llm_args import LlmArgs


def test_trtllm_configure_subcommand_basic(tmp_path: Path):
    output_path = tmp_path / "test_config.yaml"

    mock_config = LlmArgs()
    mock_config.kv_cache_free_gpu_memory_fraction = 0.9

    cmd = InferenceMaxSubCommand(
        model="meta-llama/Llama-3.1-8B",
        gpu="H200_SXM",
        num_gpus=1,
        isl=1000,
        osl=2000,
        concurrency=64,
        tensor_parallel_size=1,
        output=output_path,
    )

    trtllm_configure = TRTLLMConfigure(inferencemax=cmd)

    # Mock get_config to return our mock config
    with patch.object(cmd, "get_config", return_value=mock_config):
        trtllm_configure.run()

    assert output_path.exists()
    with open(output_path, "r") as f:
        loaded_config = yaml.safe_load(f)

    assert "kv_cache_free_gpu_memory_fraction" in loaded_config
    assert loaded_config["kv_cache_free_gpu_memory_fraction"] == 0.9
```