Merged
21 commits
9171f88
Remove deprecated lora args from BaseLlmArgs, using peft_cache_config…
amitz-nv Jul 29, 2025
07cde29
Enabled use of LoraConfig in TRT_python flow, added tests of expected…
amitz-nv Jul 29, 2025
eabe716
Improve comments in tests
amitz-nv Jul 29, 2025
d1a896f
Correct mistake in PeftCacheConfig.num_device_module_layer description
amitz-nv Jul 29, 2025
e90872a
Add validation of unsupported field in peft cache manager
amitz-nv Jul 29, 2025
7e4e37c
Fix docstring line length
amitz-nv Jul 29, 2025
004eaf9
Fix validate_peft_cache_config
amitz-nv Jul 29, 2025
1afafa7
Fix validate_peft_cache_config formatting
amitz-nv Jul 29, 2025
c486af2
Fix lora_prefetch_dir description and 'unsupported warning' message, …
amitz-nv Jul 29, 2025
138c4b1
Fix tests to configure lora cache size by number of adapters for test…
amitz-nv Jul 29, 2025
e26ca0a
Fix tests to API update - use LoraConfig instead of base LLM args for…
amitz-nv Jul 29, 2025
ef99dd2
Fix tests to explicitly configure lora_config's max_loras and max_cpu…
amitz-nv Jul 29, 2025
797715e
Define default values in PeftCacheConfig model class for device_cache…
amitz-nv Jul 29, 2025
53b4233
Add default value to description
amitz-nv Jul 29, 2025
0d51a80
Fix PeftCacheConfig.create_from_pybind after changing python fields t…
amitz-nv Jul 29, 2025
e0fcbeb
Fix examples/llm-api/llm_multilora.py - use one LoraConfig
amitz-nv Jul 29, 2025
61a994b
Fix examples/llm-api/llm_multilora.py to not use BuildConfig that's i…
amitz-nv Jul 29, 2025
191a0ed
Changed create_from_pybind method to be a more generic classmethod in…
amitz-nv Jul 29, 2025
8cca194
Minor docstring fix
amitz-nv Jul 29, 2025
391d0f9
Fix rename
amitz-nv Jul 29, 2025
bce06ad
Fix test_ptp_quickstart_multimodal_phi4mm - for stability set lora ca…
amitz-nv Jul 29, 2025
11 changes: 5 additions & 6 deletions examples/llm-api/llm_multilora.py
@@ -5,7 +5,6 @@

from tensorrt_llm import LLM
from tensorrt_llm.executor import LoRARequest
from tensorrt_llm.llmapi import BuildConfig
from tensorrt_llm.lora_manager import LoraConfig


@@ -19,12 +18,12 @@ def main():

# Currently, we need to pass at least one lora_dir to LLM constructor via build_config.lora_config.
# This is necessary because it requires some configuration in the lora_dir to build the engine with LoRA support.
build_config = BuildConfig()
build_config.lora_config = LoraConfig(lora_dir=[lora_dir1])
lora_config = LoraConfig(lora_dir=[lora_dir1],
max_lora_rank=64,
max_loras=3,
max_cpu_loras=3)
llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
enable_lora=True,
max_lora_rank=64,
build_config=build_config)
lora_config=lora_config)

# Sample prompts
prompts = [
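
Not part of the diff: a minimal usage sketch of the reworked example, assuming the same TinyLlama checkpoint and a placeholder adapter directory; the `lora_request` keyword on `generate` is inferred from the `LoRARequest` import already present in the example.

```python
# Sketch only - the adapter path is a placeholder, not taken from the PR.
from tensorrt_llm import LLM
from tensorrt_llm.executor import LoRARequest
from tensorrt_llm.lora_manager import LoraConfig

lora_config = LoraConfig(lora_dir=["<path/to/lora_adapter>"],
                         max_lora_rank=64,
                         max_loras=3,
                         max_cpu_loras=3)
llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
          lora_config=lora_config)

# Each request names the adapter it should run with.
outputs = llm.generate(
    ["What is the capital city of France?"],
    lora_request=[
        LoRARequest(lora_name="adapter-0",
                    lora_int_id=0,
                    lora_path="<path/to/lora_adapter>")
    ])
print(outputs[0].outputs[0].text)
```
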
3 changes: 3 additions & 0 deletions examples/llm-api/quickstart_multimodal.py
@@ -148,6 +148,9 @@ def main():
models_module = importlib.import_module('tensorrt_llm._torch.models')
model_class = getattr(models_module, args.auto_model_name)
lora_config = model_class.lora_config(args.model_dir)
# For stability - explicitly set the LoRA GPU cache & CPU cache to have space for 2 adapters
lora_config.max_loras = 2
lora_config.max_cpu_loras = 2

llm, sampling_params = setup_llm(args, lora_config=lora_config)

8 changes: 4 additions & 4 deletions tensorrt_llm/_torch/models/modeling_phi4mm.py
@@ -271,16 +271,16 @@ def lora_request(num_requests: int, modality: str, base_model_dir: str):
if modality == "image" or modality == "image_audio":
lora_request = [
LoRARequest(
lora_name=f"vision-lora-{i}",
lora_int_id=i,
lora_name="vision-lora",
lora_int_id=0,
lora_path=f"{base_model_dir}/vision-lora",
) for i in range(num_requests)
]
elif modality == "audio":
lora_request = [
LoRARequest(
lora_name=f"speech-lora-{i}",
lora_int_id=i,
lora_name="speech-lora",
lora_int_id=1,
lora_path=f"{base_model_dir}/speech-lora",
) for i in range(num_requests)
]
16 changes: 11 additions & 5 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -11,6 +11,7 @@
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
from tensorrt_llm.bindings.executor import DecodingMode, ExecutorConfig
from tensorrt_llm.llmapi.llm_args import PeftCacheConfig
from tensorrt_llm.logger import logger
from tensorrt_llm.lora_manager import (LoraConfig,
get_default_trtllm_modules_to_hf_modules,
@@ -481,12 +482,17 @@ def create_py_executor_instance(
num_lora_modules = model_engine.model.model_config.pretrained_config.num_hidden_layers * \
len(lora_config.lora_target_modules + lora_config.missing_qkv_modules)

executor_config.peft_cache_config = trtllm.PeftCacheConfig(
num_device_module_layer=max_lora_rank * num_lora_modules *
lora_config.max_loras,
num_host_module_layer=max_lora_rank * num_lora_modules *
lora_config.max_cpu_loras,
peft_cache_config_model = PeftCacheConfig.from_pybind(
executor_config.peft_cache_config
) if executor_config.peft_cache_config is not None else PeftCacheConfig(
)
if lora_config.max_loras is not None:
peft_cache_config_model.num_device_module_layer = \
max_lora_rank * num_lora_modules * lora_config.max_loras
if lora_config.max_cpu_loras is not None:
peft_cache_config_model.num_host_module_layer = \
max_lora_rank * num_lora_modules * lora_config.max_cpu_loras
executor_config.peft_cache_config = peft_cache_config_model._to_pybind()

from tensorrt_llm.bindings import WorldConfig
world_config = WorldConfig(
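
The override logic above is simple arithmetic; a worked sketch with hypothetical sizes (rank, layer, and module counts are illustrative, not taken from the PR):

```python
# Hypothetical numbers, for illustration only.
max_lora_rank = 64
num_hidden_layers = 32
modules_per_layer = 7        # len(lora_target_modules + missing_qkv_modules)
num_lora_modules = num_hidden_layers * modules_per_layer     # 224

max_loras = 3                # adapters to keep in the device (GPU) cache
max_cpu_loras = 4            # adapters to keep in the host (CPU) cache

# Mirrors the overrides applied in create_py_executor_instance:
num_device_module_layer = max_lora_rank * num_lora_modules * max_loras      # 43008
num_host_module_layer = max_lora_rank * num_lora_modules * max_cpu_loras    # 57344
```
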
34 changes: 25 additions & 9 deletions tensorrt_llm/llmapi/llm.py
@@ -31,8 +31,8 @@
from ..logger import logger
from ..sampling_params import SamplingParams
from .llm_args import (TORCH_LLMARGS_EXPLICIT_DOCSTRING,
TRT_LLMARGS_EXPLICIT_DOCSTRING, PybindMirror,
TorchLlmArgs, TrtLlmArgs)
TRT_LLMARGS_EXPLICIT_DOCSTRING, PeftCacheConfig,
PybindMirror, TorchLlmArgs, TrtLlmArgs)
from .llm_utils import (CachedModelLoader, KvCacheRetentionConfig,
LlmBuildStats, ModelLoader, _ModelRuntimeContext)
from .mpi_session import MpiPoolSession, external_mpi_comm_available
@@ -807,19 +807,35 @@ def _build_model(self):
if self.args.peft_cache_config is not None:
self._executor_config.peft_cache_config = PybindMirror.maybe_to_pybind(
self.args.peft_cache_config)
elif self.args.build_config.plugin_config.lora_plugin:

lora_config = None
if self.args.build_config.plugin_config.lora_plugin:
engine_config = EngineConfig.from_json_file(self._engine_dir /
"config.json")
lora_config = engine_config.build_config.lora_config
if self.args.lora_config is not None:
logger.info(
"Overriding lora_config from engine with lora_config from LLM args"
)
lora_config = self.args.lora_config

max_lora_rank = lora_config.max_lora_rank
num_lora_modules = engine_config.pretrained_config.num_hidden_layers * \
len(lora_config.lora_target_modules + lora_config.missing_qkv_modules)
self._executor_config.peft_cache_config = tllm.PeftCacheConfig(
num_device_module_layer=max_lora_rank * num_lora_modules *
self.args.max_loras,
num_host_module_layer=max_lora_rank * num_lora_modules *
self.args.max_cpu_loras,

peft_cache_config_model = PeftCacheConfig.from_pybind(
self._executor_config.peft_cache_config
) if self._executor_config.peft_cache_config is not None else PeftCacheConfig(
)
if lora_config.max_loras is not None:
peft_cache_config_model.num_device_module_layer = \
max_lora_rank * num_lora_modules * lora_config.max_loras
if lora_config.max_cpu_loras is not None:
peft_cache_config_model.num_host_module_layer = \
max_lora_rank * num_lora_modules * lora_config.max_cpu_loras
self._executor_config.peft_cache_config = peft_cache_config_model._to_pybind(
)

if self.args.decoding_config is not None:
self._executor_config.decoding_config = self.args.decoding_config
if self.args.guided_decoding_backend == 'xgrammar':
@@ -860,7 +876,7 @@ def _build_model(self):
postprocess_tokenizer_dir=self.args.postprocess_tokenizer_dir,
),
is_llm_executor=True,
lora_config=self.args.lora_config)
lora_config=lora_config)


@append_docstring(TORCH_LLM_DOCSTRING)
125 changes: 82 additions & 43 deletions tensorrt_llm/llmapi/llm_args.py
@@ -3,12 +3,13 @@
import json
import math
import os
import types
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum, EnumMeta
from pathlib import Path
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Literal, Optional,
TypeAlias, Union)
Type, TypeAlias, TypeVar, Union, get_args, get_origin)

import torch
import yaml
@@ -61,6 +62,8 @@

# TODO[chunweiy]: move the following symbols back to utils scope, and remove the following import

TypeBaseModel = TypeVar("T", bound=BaseModel)


def Field(default: Any = ...,
*,
@@ -598,6 +601,62 @@ def pybind_equals(obj0, obj1):
return False
return True

@classmethod
def from_pybind(cls: Type[TypeBaseModel],
pybind_instance: "PybindMirror") -> TypeBaseModel:
"""Construct an instance of the given class from the fields in the given
pybind class instance.

Args:
cls: Type of the class to construct, must be a subclass of pydantic
BaseModel
pybind_instance: Instance of the pybind class to construct from its
fields

Notes:
When a field value is None in the pybind class, but it's not
optional and has a default value in the BaseModel class, it would
get the default value defined in the BaseModel class.

Returns:
Instance of the given class, populated with the fields of the given
pybind instance
""" # noqa: D205
assert issubclass(cls, BaseModel)

# Some of the fields are optional in the C++ class but in python they aren't
# optional and have a default value, so copy the value from C++ instance
# only if it has a value, so otherwise the default value defined in the
# python class would be set.
def _is_optional_type(annotation: Any) -> bool:
"""Returns True if a type annotation represents an Optional type
(Optional[X]) or a Union type that includes None (Union[X, Y, None]
or X | Y | None).
""" # noqa: D205
origin = get_origin(annotation)
args = get_args(annotation)

# Union is for Optional[x]
# UnionType is for the new | operation in Python 3.10+
return (origin is Union
or origin is types.UnionType) and type(None) in args

fields_non_optional_with_default_value_in_basemodel = {
field_name
for field_name, field_info in cls.model_fields.items()
if not (_is_optional_type(field_info.annotation)
and field_info.is_required())
}

kwargs = {}
cpp_fields = PybindMirror.get_pybind_variable_fields(
type(pybind_instance))
for field_name in cpp_fields:
field_value = getattr(pybind_instance, field_name)
if field_value is not None or field_name not in fields_non_optional_with_default_value_in_basemodel:
kwargs[field_name] = field_value
return cls(**kwargs)


class PybindMirrorMeta(type(PybindMirror)):
pass
@@ -695,11 +754,12 @@ class PeftCacheConfig(StrictBaseModel, PybindMirror):
default=0,
description=
"number of max sized 1-layer 1-module adapterSize=1 sets of weights that can be stored in host cache"
)
", affects host cache size and overrides value of host_cache_size")
num_device_module_layer: int = Field(
default=0,
description=
"number of max sized 1-layer 1-module sets of weights that can be stored in host cache"
"number of max sized 1-layer 1-module sets of weights that can be stored in device cache"
", affects device cache size and overrides value of device_cache_percent"
)
optimal_adapter_size: int = Field(
default=
@@ -726,15 +786,17 @@ class PeftCacheConfig(StrictBaseModel, PybindMirror):
max_pages_per_block_device: int = Field(
default=8,
description="Number of cache pages per allocation block (device)")
device_cache_percent: Optional[float] = Field(
default=None,
description="percent of memory after engine load to use for cache")
host_cache_size: Optional[int] = Field(
default=None, description="size in bytes to use for host cache")
device_cache_percent: float = Field(
default=0.02,
description=
"Proportion of free device memory after engine load to use for cache, as a fraction from 0 to 1"
)
host_cache_size: int = Field(
default=1024**3, description="size in bytes to use for host cache")
lora_prefetch_dir: Optional[str] = Field(
default=None,
description=
"folder to store the LoRA weights we hope to load during engine initialization"
"folder to store the LoRA weights we hope to load during engine initialization, currently not supported"
)

def _to_pybind(self):
@@ -1084,27 +1146,6 @@ class BaseLlmArgs(StrictBaseModel):
# LoRA arguments
enable_lora: bool = Field(default=False, description="Enable LoRA.")

max_lora_rank: Optional[int] = Field(
default=None,
description="The maximum LoRA rank.",
deprecated="Use lora_config.max_lora_rank instead.",
status="deprecated",
)

max_loras: int = Field(
default=4,
description="The maximum number of LoRA.",
deprecated="Use lora_config.max_loras instead.",
status="deprecated",
)

max_cpu_loras: int = Field(
default=4,
description="The maximum number of LoRA on CPU.",
deprecated="Use lora_config.max_cpu_loras instead.",
status="deprecated",
)

lora_config: Optional[LoraConfig] = Field(
default=None, description="LoRA configuration for the model.")

@@ -1495,10 +1536,10 @@ def validate_build_config_remaining(self):
if self.parallel_config._world_size == 1 and self.build_config:
self.build_config.plugin_config.nccl_plugin = None

if self.enable_lora and self.lora_config is None and self.backend != 'pytorch':
if self.enable_lora and self.backend != 'pytorch':
self.build_config.plugin_config.lora_plugin = 'auto'
if self.max_lora_rank is not None:
self.build_config.lora_config.max_lora_rank = self.max_lora_rank
if self.lora_config is not None:
self.build_config.lora_config.max_lora_rank = self.lora_config.max_lora_rank

if hasattr(self,
'enable_prompt_adapter') and self.enable_prompt_adapter:
@@ -1602,16 +1643,6 @@ def validate_speculative_config(self):
@model_validator(mode="after")
def validate_lora_config_consistency(self):
if self.lora_config:
if self.max_lora_rank is not None:
logger.warning(
"max_lora_rank is ignored when lora_config is provided.")
if self.max_loras != self.lora_config.max_loras:
logger.warning(
"max_loras is ignored when lora_config is provided.")
if self.max_cpu_loras != self.lora_config.max_cpu_loras:
logger.warning(
"max_cpu_loras is ignored when lora_config is provided.")

if len(self.lora_config.lora_dir) == 0:
# TODO [TRTLLM-5173]
logger.warning(
Expand All @@ -1638,6 +1669,14 @@ def validate_lora_config_consistency(self):
default_trtllm_modules_to_hf_modules.keys())
return self

@model_validator(mode="after")
def validate_peft_cache_config(self):
if self.peft_cache_config is not None and self.peft_cache_config.lora_prefetch_dir is not None:
raise ValueError(
f"lora_prefetch_dir was set to '{self.peft_cache_config.lora_prefetch_dir}' "
"while LoRA prefetch is not supported")
return self

def _update_plugin_config(self, key: str, value: Any):
setattr(self.build_config.plugin_config, key, value)

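
A hedged sketch of the new `from_pybind` round trip introduced above (values are illustrative). Per the docstring, pybind fields holding `None` fall back to the pydantic defaults, e.g. `device_cache_percent=0.02` and `host_cache_size=1024**3`.

```python
# Sketch only - illustrates the classmethod added in this PR; values are made up.
from tensorrt_llm.llmapi.llm_args import PeftCacheConfig

cfg = PeftCacheConfig(num_device_module_layer=43008,
                      num_host_module_layer=57344)
cfg_pybind = cfg._to_pybind()                    # C++ executor-side config
cfg_back = PeftCacheConfig.from_pybind(cfg_pybind)

assert cfg_back.num_device_module_layer == 43008
# Any optional C++ field that is None comes back as the BaseModel default.
```
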
4 changes: 2 additions & 2 deletions tensorrt_llm/lora_manager.py
@@ -203,8 +203,8 @@ class LoraConfig(DictConversion):
max_lora_rank: int = 64
lora_target_modules: List[str] = field(default_factory=list)
trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict)
max_loras: int = 4
max_cpu_loras: int = 4
max_loras: int | None = None
max_cpu_loras: int | None = None

def __post_init__(self):
assert self.lora_ckpt_source in ["hf", "nemo"], (
4 changes: 3 additions & 1 deletion tests/unittest/llmapi/apps/_test_openai_lora.py
@@ -36,7 +36,9 @@ def temp_extra_llm_api_options_file():
extra_llm_api_options_dict = {
"lora_config": {
"lora_target_modules": ['attn_q', 'attn_k', 'attn_v'],
"max_lora_rank": 8
"max_lora_rank": 8,
"max_loras": 4,
"max_cpu_loras": 4,
}
}

4 changes: 3 additions & 1 deletion tests/unittest/llmapi/apps/_test_trtllm_serve_lora.py
@@ -25,7 +25,9 @@ def temp_extra_llm_api_options_file():
extra_llm_api_options_dict = {
"lora_config": {
"lora_target_modules": ['attn_q', 'attn_k', 'attn_v'],
"max_lora_rank": 8
"max_lora_rank": 8,
"max_loras": 4,
"max_cpu_loras": 4,
}
}
