From 9171f886641f0d255ca27ba547bf59a5af504c51 Mon Sep 17 00:00:00 2001 From: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> Date: Tue, 29 Jul 2025 13:20:18 +0000 Subject: [PATCH 01/21] Remove deprecated lora args from BaseLlmArgs, using peft_cache_config and allowing lora_config.max_loras and lora_config.max_cpu_loras to override it, changed their default value to None Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/_util.py | 16 +++++--- tensorrt_llm/llmapi/llm.py | 34 +++++++++++----- tensorrt_llm/llmapi/llm_args.py | 52 ++++++++++--------------- tensorrt_llm/lora_manager.py | 4 +- 4 files changed, 58 insertions(+), 48 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 4754e693fc5..24b9fc12f0b 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -11,6 +11,7 @@ from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str from tensorrt_llm.bindings.executor import DecodingMode, ExecutorConfig +from tensorrt_llm.llmapi.llm_args import PeftCacheConfig from tensorrt_llm.logger import logger from tensorrt_llm.lora_manager import (LoraConfig, get_default_trtllm_modules_to_hf_modules, @@ -481,12 +482,17 @@ def create_py_executor_instance( num_lora_modules = model_engine.model.model_config.pretrained_config.num_hidden_layers * \ len(lora_config.lora_target_modules + lora_config.missing_qkv_modules) - executor_config.peft_cache_config = trtllm.PeftCacheConfig( - num_device_module_layer=max_lora_rank * num_lora_modules * - lora_config.max_loras, - num_host_module_layer=max_lora_rank * num_lora_modules * - lora_config.max_cpu_loras, + peft_cache_config_model = PeftCacheConfig.create_from_pybind( + executor_config.peft_cache_config + ) if executor_config.peft_cache_config is not None else PeftCacheConfig( ) + if lora_config.max_loras is not None: + peft_cache_config_model.num_device_module_layer = \ + max_lora_rank * num_lora_modules * lora_config.max_loras + if lora_config.max_cpu_loras is not None: + peft_cache_config_model.num_host_module_layer = \ + max_lora_rank * num_lora_modules * lora_config.max_cpu_loras + executor_config.peft_cache_config = peft_cache_config_model._to_pybind() from tensorrt_llm.bindings import WorldConfig world_config = WorldConfig( diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index 0d1a1e80201..554e1b739df 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -31,8 +31,8 @@ from ..logger import logger from ..sampling_params import SamplingParams from .llm_args import (TORCH_LLMARGS_EXPLICIT_DOCSTRING, - TRT_LLMARGS_EXPLICIT_DOCSTRING, PybindMirror, - TorchLlmArgs, TrtLlmArgs) + TRT_LLMARGS_EXPLICIT_DOCSTRING, PeftCacheConfig, + PybindMirror, TorchLlmArgs, TrtLlmArgs) from .llm_utils import (CachedModelLoader, KvCacheRetentionConfig, LlmBuildStats, ModelLoader, _ModelRuntimeContext) from .mpi_session import MpiPoolSession, external_mpi_comm_available @@ -807,19 +807,35 @@ def _build_model(self): if self.args.peft_cache_config is not None: self._executor_config.peft_cache_config = PybindMirror.maybe_to_pybind( self.args.peft_cache_config) - elif self.args.build_config.plugin_config.lora_plugin: + + lora_config = None + if self.args.build_config.plugin_config.lora_plugin: engine_config = EngineConfig.from_json_file(self._engine_dir / "config.json") lora_config = 
engine_config.build_config.lora_config + if self.args.lora_config is not None: + logger.info( + "Overriding lora_config from engine with lora_config from LLM args" + ) + lora_config = self.args.lora_config + max_lora_rank = lora_config.max_lora_rank num_lora_modules = engine_config.pretrained_config.num_hidden_layers * \ len(lora_config.lora_target_modules + lora_config.missing_qkv_modules) - self._executor_config.peft_cache_config = tllm.PeftCacheConfig( - num_device_module_layer=max_lora_rank * num_lora_modules * - self.args.max_loras, - num_host_module_layer=max_lora_rank * num_lora_modules * - self.args.max_cpu_loras, + + peft_cache_config_model = PeftCacheConfig.create_from_pybind( + self._executor_config.peft_cache_config + ) if self._executor_config.peft_cache_config is not None else PeftCacheConfig( + ) + if lora_config.max_loras is not None: + peft_cache_config_model.num_device_module_layer = \ + max_lora_rank * num_lora_modules * lora_config.max_loras + if lora_config.max_cpu_loras is not None: + peft_cache_config_model.num_host_module_layer = \ + max_lora_rank * num_lora_modules * lora_config.max_cpu_loras + self._executor_config.peft_cache_config = peft_cache_config_model._to_pybind( ) + if self.args.decoding_config is not None: self._executor_config.decoding_config = self.args.decoding_config if self.args.guided_decoding_backend == 'xgrammar': @@ -860,7 +876,7 @@ def _build_model(self): postprocess_tokenizer_dir=self.args.postprocess_tokenizer_dir, ), is_llm_executor=True, - lora_config=self.args.lora_config) + lora_config=lora_config) @append_docstring(TORCH_LLM_DOCSTRING) diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 5406e7693d4..f4bed72294a 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -695,11 +695,12 @@ class PeftCacheConfig(StrictBaseModel, PybindMirror): default=0, description= "number of max sized 1-layer 1-module adapterSize=1 sets of weights that can be stored in host cache" - ) + ", affects host cache size and overrides value of host_cache_size") num_device_module_layer: int = Field( default=0, description= "number of max sized 1-layer 1-module sets of weights that can be stored in host cache" + ", affects device cache size and overrides value of device_cache_percent" ) optimal_adapter_size: int = Field( default= @@ -752,6 +753,24 @@ def _to_pybind(self): host_cache_size=self.host_cache_size, lora_prefetch_dir=self.lora_prefetch_dir) + @staticmethod + def create_from_pybind( + peft_cache_config: _PeftCacheConfig) -> "PeftCacheConfig": + return PeftCacheConfig( + num_host_module_layer=peft_cache_config.num_host_module_layer, + num_device_module_layer=peft_cache_config.num_device_module_layer, + optimal_adapter_size=peft_cache_config.optimal_adapter_size, + max_adapter_size=peft_cache_config.max_adapter_size, + num_put_workers=peft_cache_config.num_put_workers, + num_ensure_workers=peft_cache_config.num_ensure_workers, + num_copy_streams=peft_cache_config.num_copy_streams, + max_pages_per_block_host=peft_cache_config.max_pages_per_block_host, + max_pages_per_block_device=peft_cache_config. 
+ max_pages_per_block_device, + device_cache_percent=peft_cache_config.device_cache_percent, + host_cache_size=peft_cache_config.host_cache_size, + lora_prefetch_dir=peft_cache_config.lora_prefetch_dir) + @PybindMirror.mirror_pybind_fields(_LookaheadDecodingConfig) class LookaheadDecodingConfig(DecodingBaseConfig, PybindMirror): @@ -1084,27 +1103,6 @@ class BaseLlmArgs(StrictBaseModel): # LoRA arguments enable_lora: bool = Field(default=False, description="Enable LoRA.") - max_lora_rank: Optional[int] = Field( - default=None, - description="The maximum LoRA rank.", - deprecated="Use lora_config.max_lora_rank instead.", - status="deprecated", - ) - - max_loras: int = Field( - default=4, - description="The maximum number of LoRA.", - deprecated="Use lora_config.max_loras instead.", - status="deprecated", - ) - - max_cpu_loras: int = Field( - default=4, - description="The maximum number of LoRA on CPU.", - deprecated="Use lora_config.max_cpu_loras instead.", - status="deprecated", - ) - lora_config: Optional[LoraConfig] = Field( default=None, description="LoRA configuration for the model.") @@ -1602,16 +1600,6 @@ def validate_speculative_config(self): @model_validator(mode="after") def validate_lora_config_consistency(self): if self.lora_config: - if self.max_lora_rank is not None: - logger.warning( - "max_lora_rank is ignored when lora_config is provided.") - if self.max_loras != self.lora_config.max_loras: - logger.warning( - "max_loras is ignored when lora_config is provided.") - if self.max_cpu_loras != self.lora_config.max_cpu_loras: - logger.warning( - "max_cpu_loras is ignored when lora_config is provided.") - if len(self.lora_config.lora_dir) == 0: # TODO [TRTLLM-5173] logger.warning( diff --git a/tensorrt_llm/lora_manager.py b/tensorrt_llm/lora_manager.py index 9f42fdad20d..9cd1b80dc6d 100644 --- a/tensorrt_llm/lora_manager.py +++ b/tensorrt_llm/lora_manager.py @@ -203,8 +203,8 @@ class LoraConfig(DictConversion): max_lora_rank: int = 64 lora_target_modules: List[str] = field(default_factory=list) trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict) - max_loras: int = 4 - max_cpu_loras: int = 4 + max_loras: int | None = None + max_cpu_loras: int | None = None def __post_init__(self): assert self.lora_ckpt_source in ["hf", "nemo"], ( From 07cde2973818996639ed645af7e36379552cf35b Mon Sep 17 00:00:00 2001 From: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> Date: Tue, 29 Jul 2025 13:20:19 +0000 Subject: [PATCH 02/21] Enabled use of LoraConfig in TRT_python flow, added tests of expected use of LoraConfig and PeftCacheConfig in LLM args Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> --- tensorrt_llm/llmapi/llm_args.py | 6 +-- tests/unittest/llmapi/test_llm.py | 64 +++++++++++++++++++++-- tests/unittest/llmapi/test_llm_pytorch.py | 57 +++++++++++++++++++- 3 files changed, 118 insertions(+), 9 deletions(-) diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index f4bed72294a..c7c530ba96c 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -1493,10 +1493,10 @@ def validate_build_config_remaining(self): if self.parallel_config._world_size == 1 and self.build_config: self.build_config.plugin_config.nccl_plugin = None - if self.enable_lora and self.lora_config is None and self.backend != 'pytorch': + if self.enable_lora and self.backend != 'pytorch': self.build_config.plugin_config.lora_plugin = 'auto' - if self.max_lora_rank is not None: - 
self.build_config.lora_config.max_lora_rank = self.max_lora_rank + if self.lora_config is not None: + self.build_config.lora_config.max_lora_rank = self.lora_config.max_lora_rank if hasattr(self, 'enable_prompt_adapter') and self.enable_prompt_adapter: diff --git a/tests/unittest/llmapi/test_llm.py b/tests/unittest/llmapi/test_llm.py index 4d7e4f127b4..90c2c6b973e 100644 --- a/tests/unittest/llmapi/test_llm.py +++ b/tests/unittest/llmapi/test_llm.py @@ -35,7 +35,8 @@ LookaheadDecodingConfig, MedusaDecodingConfig, RequestOutput) from tensorrt_llm.llmapi import TrtLlmArgs as LlmArgs -from tensorrt_llm.llmapi.llm_args import DynamicBatchConfig, SchedulerConfig +from tensorrt_llm.llmapi.llm_args import (DynamicBatchConfig, PeftCacheConfig, + SchedulerConfig) from tensorrt_llm.llmapi.llm_utils import (BuildConfig, QuantAlgo, QuantConfig, _ParallelConfig) from tensorrt_llm.llmapi.tokenizer import (TokenizerBase, TransformersTokenizer, @@ -50,7 +51,9 @@ # isort: off sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/..") from gc_utils import assert_resource_freed -from llmapi.lora_test_utils import check_llama_7b_multi_unique_lora_adapters_from_request +from llmapi.lora_test_utils import ( + check_llama_7b_multi_lora_from_request_test_harness, + check_llama_7b_multi_unique_lora_adapters_from_request) from utils.llm_data import llm_models_root from utils.util import force_ampere, similar, skip_gpu_memory_less_than_40gb, skip_pre_hopper, skip_single_gpu # isort: on @@ -1478,13 +1481,64 @@ def test_llama_7b_multi_lora_evict_load_new_adapters( lora_adapter_count_per_call, repeat_calls, repeats_per_call, + LLM, + enable_lora=True, + build_config=build_config, + fast_build=True) + + +def test_llama_7b_peft_cache_config_affects_peft_cache_size(): + """Tests that LLM arg of peft_cache_config affects the peft cache sizes. + + NOTE: The caller can't get the actual LoRA cache size without debug logs, so + to test whether it is affected by PeftCacheConfig LLM arg, a non-zero value + that's too small to contain a single adapter can be sent, which shall cause + a failure in init. + """ + # For LoRA checkpoints without finetuned embedding and lm_head, we can either: + # (1) specify lora_target_modules, or + # (2) provide a lora_dir to infer the lora_target_modules. 
+ lora_config_no_cache_size_values = LoraConfig( + lora_target_modules=['attn_q', 'attn_k', 'attn_v'], max_lora_rank=8) + build_config = BuildConfig(lora_config=lora_config_no_cache_size_values) + + # Test that init fails on PeftCacheConfig.host_cache_size too small + with pytest.raises(RuntimeError): + check_llama_7b_multi_lora_from_request_test_harness( + LLM, + enable_lora=True, + build_config=build_config, + fast_build=True, + lora_config=lora_config_no_cache_size_values, + peft_cache_config=PeftCacheConfig(host_cache_size=1)) + + # Test that init fails on PeftCacheConfig.device_cache_percent too small + with pytest.raises(RuntimeError): + check_llama_7b_multi_lora_from_request_test_harness( + LLM, + enable_lora=True, + build_config=build_config, + fast_build=True, + lora_config=lora_config_no_cache_size_values, + peft_cache_config=PeftCacheConfig(device_cache_percent=0.000001)) + + +def test_llama_7b_lora_config_overrides_peft_cache_config(): + """Tests that cache size args in lora_config LLM arg override the cache size parameters in peft_cache_config LLM arg.""" + build_config = BuildConfig(lora_config=LoraConfig( + lora_target_modules=['attn_q', 'attn_k', 'attn_v'], max_lora_rank=8)) + check_llama_7b_multi_lora_from_request_test_harness( LLM, enable_lora=True, build_config=build_config, fast_build=True, - max_lora_rank=8, - max_loras=max_loras, - max_cpu_loras=max_cpu_loras) + lora_config=LoraConfig( + lora_target_modules=['attn_q', 'attn_k', 'attn_v'], + max_lora_rank=8, + max_loras=2, + max_cpu_loras=2), + peft_cache_config=PeftCacheConfig(host_cache_size=1, + device_cache_percent=0.000001)) @skip_gpu_memory_less_than_40gb diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py index 7e890693e50..fbafd9dab22 100644 --- a/tests/unittest/llmapi/test_llm_pytorch.py +++ b/tests/unittest/llmapi/test_llm_pytorch.py @@ -1,11 +1,15 @@ import pytest from tensorrt_llm import LLM +from tensorrt_llm.llmapi.llm_args import PeftCacheConfig from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer from tensorrt_llm.sampling_params import SamplingParams # isort: off -from .lora_test_utils import check_llama_7b_multi_unique_lora_adapters_from_request, create_mock_nemo_lora_checkpoint +from .lora_test_utils import ( + check_llama_7b_multi_lora_from_request_test_harness, + check_llama_7b_multi_unique_lora_adapters_from_request, + create_mock_nemo_lora_checkpoint) from .test_llm import (get_model_path, global_kvcache_config, llama_model_path, llm_get_stats_async_test_harness, llm_get_stats_test_harness, prompts, @@ -288,6 +292,57 @@ def _check_contains_expected_message(stdout: str, stderr: str): assert _check_contains_expected_message(child_stdout, child_stderr) +def test_llama_7b_peft_cache_config_affects_peft_cache_size(): + """Tests that LLM arg of peft_cache_config affects the peft cache sizes. + + NOTE: The caller can't get the actual LoRA cache size without debug logs, so + to test whether it is affected by PeftCacheConfig LLM arg, a non-zero value + that's too small to contain a single adapter can be sent, which shall cause + a failure in init. + """ + # For LoRA checkpoints without finetuned embedding and lm_head, we can either: + # (1) specify lora_target_modules, or + # (2) provide a lora_dir to infer the lora_target_modules. 
+ lora_config_no_cache_size_values = LoraConfig( + lora_target_modules=['attn_q', 'attn_k', 'attn_v'], max_lora_rank=8) + + # Test that init fails on PeftCacheConfig.host_cache_size too small + with pytest.raises(RuntimeError): + check_llama_7b_multi_lora_from_request_test_harness( + LLM, + lora_config=lora_config_no_cache_size_values, + peft_cache_config=PeftCacheConfig(host_cache_size=1), + # Disable CUDA graph + # TODO: remove this once we have a proper fix for CUDA graph in LoRA + cuda_graph_config=None) + + # Test that init fails on PeftCacheConfig.device_cache_percent too small + with pytest.raises(RuntimeError): + check_llama_7b_multi_lora_from_request_test_harness( + LLM, + lora_config=lora_config_no_cache_size_values, + peft_cache_config=PeftCacheConfig(device_cache_percent=0.000001), + # Disable CUDA graph + # TODO: remove this once we have a proper fix for CUDA graph in LoRA + cuda_graph_config=None) + + +def test_llama_7b_lora_config_overrides_peft_cache_config(): + """Tests that cache size args in lora_config LLM arg override the cache size parameters in peft_cache_config LLM arg.""" + check_llama_7b_multi_lora_from_request_test_harness( + LLM, + lora_config=LoraConfig( + lora_target_modules=['attn_q', 'attn_k', 'attn_v'], + max_lora_rank=8, + max_loras=2, + max_cpu_loras=2), + peft_cache_config=PeftCacheConfig(host_cache_size=1, + device_cache_percent=0.000001), + # Disable CUDA graph + # TODO: remove this once we have a proper fix for CUDA graph in LoRA + cuda_graph_config=None) + + # TODO smor: currently Nemotron-Super-49B-v1 with LoRA memory consumption is overly high # https://jirasw.nvidia.com/browse/TRTLLM-5045 @pytest.mark.skip(reason="https://nvbugs/5401210") From eabe7168b0dd2083d730fe805df8ba3e6ea72ffe Mon Sep 17 00:00:00 2001 From: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> Date: Tue, 29 Jul 2025 13:20:19 +0000 Subject: [PATCH 03/21] Improve comments in tests Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> --- tests/unittest/llmapi/test_llm.py | 11 +++++------ tests/unittest/llmapi/test_llm_pytorch.py | 11 +++++------ 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/tests/unittest/llmapi/test_llm.py b/tests/unittest/llmapi/test_llm.py index 90c2c6b973e..58106831285 100644 --- a/tests/unittest/llmapi/test_llm.py +++ b/tests/unittest/llmapi/test_llm.py @@ -1490,10 +1490,9 @@ def test_llama_7b_multi_lora_evict_load_new_adapters( def test_llama_7b_peft_cache_config_affects_peft_cache_size(): """Tests that LLM arg of peft_cache_config affects the peft cache sizes. - NOTE: The caller can't get the actual LoRA cache size without debug logs, so - to test whether it is affected by PeftCacheConfig LLM arg, a non-zero value - that's too small to contain a single adapter can be sent, which shall cause - a failure in init. + NOTE: The caller can't get the actual LoRA cache sizes, so we instead we + test that it fails when configured with a value too small to contain a + single adapter. 
""" # For LoRA checkpoints without finetuned embedding and lm_head, we can either: # (1) specify lora_target_modules, or @@ -1502,7 +1501,7 @@ def test_llama_7b_peft_cache_config_affects_peft_cache_size(): lora_target_modules=['attn_q', 'attn_k', 'attn_v'], max_lora_rank=8) build_config = BuildConfig(lora_config=lora_config_no_cache_size_values) - # Test that init fails on PeftCacheConfig.host_cache_size too small + # Test that too small PeftCacheConfig.host_cache_size causes failure with pytest.raises(RuntimeError): check_llama_7b_multi_lora_from_request_test_harness( LLM, @@ -1512,7 +1511,7 @@ def test_llama_7b_peft_cache_config_affects_peft_cache_size(): lora_config=lora_config_no_cache_size_values, peft_cache_config=PeftCacheConfig(host_cache_size=1)) - # Test that init fails on PeftCacheConfig.device_cache_percent too small + # Test that too small PeftCacheConfig.device_cache_percent causes failure with pytest.raises(RuntimeError): check_llama_7b_multi_lora_from_request_test_harness( LLM, diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py index fbafd9dab22..315b3648c34 100644 --- a/tests/unittest/llmapi/test_llm_pytorch.py +++ b/tests/unittest/llmapi/test_llm_pytorch.py @@ -295,10 +295,9 @@ def _check_contains_expected_message(stdout: str, stderr: str): def test_llama_7b_peft_cache_config_affects_peft_cache_size(): """Tests that LLM arg of peft_cache_config affects the peft cache sizes. - NOTE: The caller can't get the actual LoRA cache size without debug logs, so - to test whether it is affected by PeftCacheConfig LLM arg, a non-zero value - that's too small to contain a single adapter can be sent, which shall cause - a failure in init. + NOTE: The caller can't get the actual LoRA cache sizes, so we instead we + test that it fails when configured with a value too small to contain a + single adapter. 
""" # For LoRA checkpoints without finetuned embedding and lm_head, we can either: # (1) specify lora_target_modules, or @@ -306,7 +305,7 @@ def test_llama_7b_peft_cache_config_affects_peft_cache_size(): lora_config_no_cache_size_values = LoraConfig( lora_target_modules=['attn_q', 'attn_k', 'attn_v'], max_lora_rank=8) - # Test that init fails on PeftCacheConfig.host_cache_size too small + # Test that too small PeftCacheConfig.host_cache_size causes failure with pytest.raises(RuntimeError): check_llama_7b_multi_lora_from_request_test_harness( LLM, @@ -316,7 +315,7 @@ def test_llama_7b_peft_cache_config_affects_peft_cache_size(): # TODO: remove this once we have a proper fix for CUDA graph in LoRA cuda_graph_config=None) - # Test that init fails on PeftCacheConfig.device_cache_percent too small + # Test that too small PeftCacheConfig.device_cache_percent causes failure with pytest.raises(RuntimeError): check_llama_7b_multi_lora_from_request_test_harness( LLM, From d1a896f52afa7353ddeaa09d1f794bf0ec660433 Mon Sep 17 00:00:00 2001 From: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> Date: Tue, 29 Jul 2025 13:20:20 +0000 Subject: [PATCH 04/21] Correct mistake in PeftCacheConfig.num_device_module_layer description Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> --- tensorrt_llm/llmapi/llm_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index c7c530ba96c..82e707c7d9e 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -699,7 +699,7 @@ class PeftCacheConfig(StrictBaseModel, PybindMirror): num_device_module_layer: int = Field( default=0, description= - "number of max sized 1-layer 1-module sets of weights that can be stored in host cache" + "number of max sized 1-layer 1-module sets of weights that can be stored in device cache" ", affects device cache size and overrides value of device_cache_percent" ) optimal_adapter_size: int = Field( From e90872a7bf93b8cbcb16091d41de36da146d3d48 Mon Sep 17 00:00:00 2001 From: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> Date: Tue, 29 Jul 2025 13:20:20 +0000 Subject: [PATCH 05/21] Add validation of unsupported field in peft cache manager Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> --- tensorrt_llm/llmapi/llm_args.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 82e707c7d9e..d59c9eab3a8 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -736,7 +736,7 @@ class PeftCacheConfig(StrictBaseModel, PybindMirror): default=None, description= "folder to store the LoRA weights we hope to load during engine initialization" - ) + ", not supported with pytorch backend") def _to_pybind(self): return _PeftCacheConfig( @@ -1626,6 +1626,13 @@ def validate_lora_config_consistency(self): default_trtllm_modules_to_hf_modules.keys()) return self + @model_validator(mode="after") + def validate_peft_cache_config(self): + if self.backend == "pytorch" and self.peft_cache_config.lora_prefetch_dir is not None: + logger.warning( + "LoRA prefetch is not supported with pytorch backend") + return self + def _update_plugin_config(self, key: str, value: Any): setattr(self.build_config.plugin_config, key, value) From 7e4e37c0adf658f1bbaa0803a500a012a4f4b79e Mon Sep 17 00:00:00 2001 From: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> Date: Tue, 29 Jul 2025 
13:20:21 +0000 Subject: [PATCH 06/21] Fix docstring line length Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> --- tests/unittest/llmapi/test_llm.py | 4 +++- tests/unittest/llmapi/test_llm_pytorch.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/unittest/llmapi/test_llm.py b/tests/unittest/llmapi/test_llm.py index 58106831285..c04085e033f 100644 --- a/tests/unittest/llmapi/test_llm.py +++ b/tests/unittest/llmapi/test_llm.py @@ -1523,7 +1523,9 @@ def test_llama_7b_peft_cache_config_affects_peft_cache_size(): def test_llama_7b_lora_config_overrides_peft_cache_config(): - """Tests that cache size args in lora_config LLM arg override the cache size parameters in peft_cache_config LLM arg.""" + """Tests that cache size args in lora_config LLM arg override the cache size + parameters in peft_cache_config LLM arg. + """ # noqa: D205 build_config = BuildConfig(lora_config=LoraConfig( lora_target_modules=['attn_q', 'attn_k', 'attn_v'], max_lora_rank=8)) check_llama_7b_multi_lora_from_request_test_harness( diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py index 315b3648c34..693c4306acf 100644 --- a/tests/unittest/llmapi/test_llm_pytorch.py +++ b/tests/unittest/llmapi/test_llm_pytorch.py @@ -327,7 +327,9 @@ def test_llama_7b_peft_cache_config_affects_peft_cache_size(): def test_llama_7b_lora_config_overrides_peft_cache_config(): - """Tests that cache size args in lora_config LLM arg override the cache size parameters in peft_cache_config LLM arg.""" + """Tests that cache size args in lora_config LLM arg override the cache size + parameters in peft_cache_config LLM arg. + """ # noqa: D205 check_llama_7b_multi_lora_from_request_test_harness( LLM, lora_config=LoraConfig( From 004eaf98946a9a9c3fcc7a0a745e0bfa157e9811 Mon Sep 17 00:00:00 2001 From: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> Date: Tue, 29 Jul 2025 13:20:22 +0000 Subject: [PATCH 07/21] Fix validate_peft_cache_config Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> --- tensorrt_llm/llmapi/llm_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index d59c9eab3a8..1cee696348a 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -1628,7 +1628,7 @@ def validate_lora_config_consistency(self): @model_validator(mode="after") def validate_peft_cache_config(self): - if self.backend == "pytorch" and self.peft_cache_config.lora_prefetch_dir is not None: + if self.backend == "pytorch" and self.peft_cache_config is not None and self.peft_cache_config.lora_prefetch_dir is not None: logger.warning( "LoRA prefetch is not supported with pytorch backend") return self From 1afafa75c9a68a656c999159946aff480e70e44c Mon Sep 17 00:00:00 2001 From: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> Date: Tue, 29 Jul 2025 13:20:22 +0000 Subject: [PATCH 08/21] Fix validate_peft_cache_config formatting Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> --- tensorrt_llm/llmapi/llm_args.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 1cee696348a..e1f4003335d 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -1628,9 +1628,10 @@ def validate_lora_config_consistency(self): @model_validator(mode="after") def validate_peft_cache_config(self): - if 
self.backend == "pytorch" and self.peft_cache_config is not None and self.peft_cache_config.lora_prefetch_dir is not None: - logger.warning( - "LoRA prefetch is not supported with pytorch backend") + if self.backend == "pytorch" and self.peft_cache_config is not None: + if self.peft_cache_config.lora_prefetch_dir is not None: + logger.warning( + "LoRA prefetch is not supported with pytorch backend") return self def _update_plugin_config(self, key: str, value: Any): From c486af2939e4212afce8a39adc8147c496733b3a Mon Sep 17 00:00:00 2001 From: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> Date: Tue, 29 Jul 2025 13:20:23 +0000 Subject: [PATCH 09/21] Fix lora_prefetch_dir description and 'unsupported warning' message, clarify test Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> --- tensorrt_llm/llmapi/llm_args.py | 10 ++++------ tests/unittest/llmapi/test_llm.py | 10 ++++++---- tests/unittest/llmapi/test_llm_pytorch.py | 10 ++++++---- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index e1f4003335d..3f32a3e7fa8 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -735,8 +735,8 @@ class PeftCacheConfig(StrictBaseModel, PybindMirror): lora_prefetch_dir: Optional[str] = Field( default=None, description= - "folder to store the LoRA weights we hope to load during engine initialization" - ", not supported with pytorch backend") + "folder to store the LoRA weights we hope to load during engine initialization, currently not supported" + ) def _to_pybind(self): return _PeftCacheConfig( @@ -1628,10 +1628,8 @@ def validate_lora_config_consistency(self): @model_validator(mode="after") def validate_peft_cache_config(self): - if self.backend == "pytorch" and self.peft_cache_config is not None: - if self.peft_cache_config.lora_prefetch_dir is not None: - logger.warning( - "LoRA prefetch is not supported with pytorch backend") + if self.peft_cache_config is not None and self.peft_cache_config.lora_prefetch_dir is not None: + logger.warning("LoRA prefetch is not supported") return self def _update_plugin_config(self, key: str, value: Any): diff --git a/tests/unittest/llmapi/test_llm.py b/tests/unittest/llmapi/test_llm.py index c04085e033f..ef01a93d67c 100644 --- a/tests/unittest/llmapi/test_llm.py +++ b/tests/unittest/llmapi/test_llm.py @@ -1509,7 +1509,8 @@ def test_llama_7b_peft_cache_config_affects_peft_cache_size(): build_config=build_config, fast_build=True, lora_config=lora_config_no_cache_size_values, - peft_cache_config=PeftCacheConfig(host_cache_size=1)) + peft_cache_config=PeftCacheConfig( + host_cache_size=1)) # size in bytes # Test that too small PeftCacheConfig.device_cache_percent causes failure with pytest.raises(RuntimeError): @@ -1519,7 +1520,7 @@ def test_llama_7b_peft_cache_config_affects_peft_cache_size(): build_config=build_config, fast_build=True, lora_config=lora_config_no_cache_size_values, - peft_cache_config=PeftCacheConfig(device_cache_percent=0.000001)) + peft_cache_config=PeftCacheConfig(device_cache_percent=0.0000001)) def test_llama_7b_lora_config_overrides_peft_cache_config(): @@ -1538,8 +1539,9 @@ def test_llama_7b_lora_config_overrides_peft_cache_config(): max_lora_rank=8, max_loras=2, max_cpu_loras=2), - peft_cache_config=PeftCacheConfig(host_cache_size=1, - device_cache_percent=0.000001)) + peft_cache_config=PeftCacheConfig( + host_cache_size=1, # size in bytes + device_cache_percent=0.0000001)) @skip_gpu_memory_less_than_40gb 
diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py index 693c4306acf..39bdd7a1c85 100644 --- a/tests/unittest/llmapi/test_llm_pytorch.py +++ b/tests/unittest/llmapi/test_llm_pytorch.py @@ -310,7 +310,8 @@ def test_llama_7b_peft_cache_config_affects_peft_cache_size(): check_llama_7b_multi_lora_from_request_test_harness( LLM, lora_config=lora_config_no_cache_size_values, - peft_cache_config=PeftCacheConfig(host_cache_size=1), + peft_cache_config=PeftCacheConfig( + host_cache_size=1), # size in bytes # Disable CUDA graph # TODO: remove this once we have a proper fix for CUDA graph in LoRA cuda_graph_config=None) @@ -320,7 +321,7 @@ def test_llama_7b_peft_cache_config_affects_peft_cache_size(): check_llama_7b_multi_lora_from_request_test_harness( LLM, lora_config=lora_config_no_cache_size_values, - peft_cache_config=PeftCacheConfig(device_cache_percent=0.000001), + peft_cache_config=PeftCacheConfig(device_cache_percent=0.0000001), # Disable CUDA graph # TODO: remove this once we have a proper fix for CUDA graph in LoRA cuda_graph_config=None) @@ -337,8 +338,9 @@ def test_llama_7b_lora_config_overrides_peft_cache_config(): max_lora_rank=8, max_loras=2, max_cpu_loras=2), - peft_cache_config=PeftCacheConfig(host_cache_size=1, - device_cache_percent=0.000001), + peft_cache_config=PeftCacheConfig( + host_cache_size=1, # size in bytes + device_cache_percent=0.0000001), # Disable CUDA graph # TODO: remove this once we have a proper fix for CUDA graph in LoRA cuda_graph_config=None) From 138c4b1571c1e90cc1ffd547a4fa6f168eeefd83 Mon Sep 17 00:00:00 2001 From: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> Date: Tue, 29 Jul 2025 13:20:23 +0000 Subject: [PATCH 10/21] Fix tests to configure lora cache size by number of adapters for test stability Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> --- tests/unittest/llmapi/test_llm.py | 4 ++-- tests/unittest/llmapi/test_llm_pytorch.py | 14 ++++++++++---- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/tests/unittest/llmapi/test_llm.py b/tests/unittest/llmapi/test_llm.py index ef01a93d67c..7f05e6e0e1f 100644 --- a/tests/unittest/llmapi/test_llm.py +++ b/tests/unittest/llmapi/test_llm.py @@ -1427,11 +1427,11 @@ def llama_v2_13b_lora_from_dir_test_harness(**llm_kwargs): hf_lora_dir = get_model_path("llama-models-v2/chinese-llama-2-lora-13b") # For LoRA checkpoints with finetuned embedding and lm_head, lora_dir must be provided at build time. 
- build_config = BuildConfig(lora_config=LoraConfig(lora_dir=[hf_lora_dir])) + build_config = BuildConfig(lora_config=LoraConfig( + lora_dir=[hf_lora_dir], max_lora_rank=64, max_loras=2, max_cpu_loras=2)) llm = LLM(hf_model_dir, tokenizer=hf_lora_dir, enable_lora=True, - max_lora_rank=64, build_config=build_config, fast_build=True, **llm_kwargs) diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py index 39bdd7a1c85..f9e636ec678 100644 --- a/tests/unittest/llmapi/test_llm_pytorch.py +++ b/tests/unittest/llmapi/test_llm_pytorch.py @@ -140,7 +140,9 @@ def test_llm_with_postprocess_parallel_and_result_handler(streaming): def llama_7b_lora_from_dir_test_harness(**llm_kwargs) -> None: lora_config = LoraConfig( lora_dir=[f"{llm_models_root()}/llama-models/luotuo-lora-7b-0.1"], - max_lora_rank=8) + max_lora_rank=8, + max_loras=2, + max_cpu_loras=2) llm = LLM(model=f"{llm_models_root()}/llama-models/llama-7b-hf", lora_config=lora_config, **llm_kwargs) @@ -174,7 +176,7 @@ def test_llama_7b_lora(): @skip_gpu_memory_less_than_40gb def test_llama_7b_lora_default_modules() -> None: - lora_config = LoraConfig(max_lora_rank=64) + lora_config = LoraConfig(max_lora_rank=64, max_loras=2, max_cpu_loras=2) hf_model_dir = f"{llm_models_root()}/llama-models/llama-7b-hf" @@ -418,7 +420,9 @@ def test_codellama_fp8_with_bf16_lora() -> None: lora_config = LoraConfig(lora_dir=lora_paths, lora_target_modules=target_modules, - max_lora_rank=8) + max_lora_rank=8, + max_loras=2, + max_cpu_loras=2) llm = LLM(model_dir, quant_config=quant_config, lora_config=lora_config) @@ -468,7 +472,9 @@ def test_bielik_11b_v2_2_instruct_multi_lora() -> None: trtllm_lora_config = LoraConfig(lora_dir=lora_paths, lora_target_modules=target_modules, - max_lora_rank=8) + max_lora_rank=8, + max_loras=2, + max_cpu_loras=2) llm = LLM(model_dir, lora_config=trtllm_lora_config) prompts = [ From e26ca0ab8c2a3c80af701e40aebd3192dc836b32 Mon Sep 17 00:00:00 2001 From: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> Date: Tue, 29 Jul 2025 13:20:24 +0000 Subject: [PATCH 11/21] Fix tests to API update - use LoraConfig instead of base LLM args for LoRA args Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> --- examples/llm-api/llm_multilora.py | 4 +++- tests/unittest/llmapi/test_llm_multi_gpu.py | 3 --- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/llm-api/llm_multilora.py b/examples/llm-api/llm_multilora.py index 4e3598d1c1b..a46606ae233 100644 --- a/examples/llm-api/llm_multilora.py +++ b/examples/llm-api/llm_multilora.py @@ -23,7 +23,9 @@ def main(): build_config.lora_config = LoraConfig(lora_dir=[lora_dir1]) llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", enable_lora=True, - max_lora_rank=64, + lora_config=LoraConfig(max_lora_rank=64, + max_loras=3, + max_cpu_loras=3), build_config=build_config) # Sample prompts diff --git a/tests/unittest/llmapi/test_llm_multi_gpu.py b/tests/unittest/llmapi/test_llm_multi_gpu.py index ecddfbe6a04..0812fea853d 100644 --- a/tests/unittest/llmapi/test_llm_multi_gpu.py +++ b/tests/unittest/llmapi/test_llm_multi_gpu.py @@ -274,9 +274,6 @@ def test_llama_7b_multi_lora_tp2(): enable_lora=True, build_config=BuildConfig(lora_config=lora_config), fast_build=True, - max_lora_rank=lora_config.max_lora_rank, - max_loras=lora_config.max_loras, - max_cpu_loras=lora_config.max_cpu_loras, kv_cache_config=global_kv_cache_config) From ef99dd227f34740188245d02e85fbb468d758d36 Mon Sep 17 00:00:00 2001 From: Amit Zuker 
<203509407+amitz-nv@users.noreply.github.com> Date: Tue, 29 Jul 2025 13:20:25 +0000 Subject: [PATCH 12/21] Fix tests to explicitly configure lora_config's max_loras and max_cpu_loras for stability Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> --- tests/unittest/llmapi/apps/_test_openai_lora.py | 4 +++- tests/unittest/llmapi/apps/_test_trtllm_serve_lora.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/unittest/llmapi/apps/_test_openai_lora.py b/tests/unittest/llmapi/apps/_test_openai_lora.py index c37a8db2b33..313304a2510 100644 --- a/tests/unittest/llmapi/apps/_test_openai_lora.py +++ b/tests/unittest/llmapi/apps/_test_openai_lora.py @@ -36,7 +36,9 @@ def temp_extra_llm_api_options_file(): extra_llm_api_options_dict = { "lora_config": { "lora_target_modules": ['attn_q', 'attn_k', 'attn_v'], - "max_lora_rank": 8 + "max_lora_rank": 8, + "max_loras": 4, + "max_cpu_loras": 4, } } diff --git a/tests/unittest/llmapi/apps/_test_trtllm_serve_lora.py b/tests/unittest/llmapi/apps/_test_trtllm_serve_lora.py index 2248250b834..e94c30662b1 100644 --- a/tests/unittest/llmapi/apps/_test_trtllm_serve_lora.py +++ b/tests/unittest/llmapi/apps/_test_trtllm_serve_lora.py @@ -25,7 +25,9 @@ def temp_extra_llm_api_options_file(): extra_llm_api_options_dict = { "lora_config": { "lora_target_modules": ['attn_q', 'attn_k', 'attn_v'], - "max_lora_rank": 8 + "max_lora_rank": 8, + "max_loras": 4, + "max_cpu_loras": 4, } } From 797715e03d97a57dcbaec7853074e049e5b56645 Mon Sep 17 00:00:00 2001 From: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> Date: Tue, 29 Jul 2025 13:20:25 +0000 Subject: [PATCH 13/21] Define default values in PeftCacheConfig model class for device_cache_percent and host_cache_size, improve device_cache_percent description Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> --- tensorrt_llm/llmapi/llm_args.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 3f32a3e7fa8..51cd921fa1a 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -727,11 +727,13 @@ class PeftCacheConfig(StrictBaseModel, PybindMirror): max_pages_per_block_device: int = Field( default=8, description="Number of cache pages per allocation block (device)") - device_cache_percent: Optional[float] = Field( - default=None, - description="percent of memory after engine load to use for cache") - host_cache_size: Optional[int] = Field( - default=None, description="size in bytes to use for host cache") + device_cache_percent: float = Field( + default=0.02, + description= + "Proportion of free device memory after engine load to use for cache, as a fraction from 0 to 1" + ) + host_cache_size: int = Field( + default=1024**3, description="size in bytes to use for host cache") lora_prefetch_dir: Optional[str] = Field( default=None, description= From 53b4233e66ccd245f5c8342afa92aae070bee807 Mon Sep 17 00:00:00 2001 From: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> Date: Tue, 29 Jul 2025 13:20:26 +0000 Subject: [PATCH 14/21] Add default value to description Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> --- tensorrt_llm/llmapi/llm_args.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 51cd921fa1a..0b226c275ff 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ 
-731,9 +731,10 @@ class PeftCacheConfig(StrictBaseModel, PybindMirror): default=0.02, description= "Proportion of free device memory after engine load to use for cache, as a fraction from 0 to 1" - ) + ", defaults to 2%") host_cache_size: int = Field( - default=1024**3, description="size in bytes to use for host cache") + default=1024**3, + description="size in bytes to use for host cache, defaults to 1GiB") lora_prefetch_dir: Optional[str] = Field( default=None, description= From 0d51a80bd008b95431c695d358999320c058601a Mon Sep 17 00:00:00 2001 From: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> Date: Tue, 29 Jul 2025 13:20:26 +0000 Subject: [PATCH 15/21] Fix PeftCacheConfig.create_from_pybind after changing python fields to be non-optional with default values Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> --- tensorrt_llm/llmapi/llm_args.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 0b226c275ff..800e7ef4d00 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -759,6 +759,15 @@ def _to_pybind(self): @staticmethod def create_from_pybind( peft_cache_config: _PeftCacheConfig) -> "PeftCacheConfig": + # Some of the properties are optional in CPP but in python they have a default value and aren't optional, + # so copy their value only if they have a value in the CPP instance. + extra_kwargs = {} + if peft_cache_config.device_cache_percent is not None: + extra_kwargs[ + "device_cache_percent"] = peft_cache_config.device_cache_percent + if peft_cache_config.host_cache_size is not None: + extra_kwargs["host_cache_size"] = peft_cache_config.host_cache_size + return PeftCacheConfig( num_host_module_layer=peft_cache_config.num_host_module_layer, num_device_module_layer=peft_cache_config.num_device_module_layer, @@ -770,9 +779,8 @@ def create_from_pybind( max_pages_per_block_host=peft_cache_config.max_pages_per_block_host, max_pages_per_block_device=peft_cache_config. max_pages_per_block_device, - device_cache_percent=peft_cache_config.device_cache_percent, - host_cache_size=peft_cache_config.host_cache_size, - lora_prefetch_dir=peft_cache_config.lora_prefetch_dir) + lora_prefetch_dir=peft_cache_config.lora_prefetch_dir, + **extra_kwargs) @PybindMirror.mirror_pybind_fields(_LookaheadDecodingConfig) From e0fcbeb798f1314128f5bcbbeefaa98aacb1592d Mon Sep 17 00:00:00 2001 From: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> Date: Tue, 29 Jul 2025 13:20:27 +0000 Subject: [PATCH 16/21] Fix examples/llm-api/llm_multilora.py - use one LoraConfig Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> --- examples/llm-api/llm_multilora.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/llm-api/llm_multilora.py b/examples/llm-api/llm_multilora.py index a46606ae233..7329d9f5e9f 100644 --- a/examples/llm-api/llm_multilora.py +++ b/examples/llm-api/llm_multilora.py @@ -20,12 +20,12 @@ def main(): # Currently, we need to pass at least one lora_dir to LLM constructor via build_config.lora_config. # This is necessary because it requires some configuration in the lora_dir to build the engine with LoRA support. 
build_config = BuildConfig() - build_config.lora_config = LoraConfig(lora_dir=[lora_dir1]) + build_config.lora_config = LoraConfig(lora_dir=[lora_dir1], + max_lora_rank=64, + max_loras=3, + max_cpu_loras=3) llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", enable_lora=True, - lora_config=LoraConfig(max_lora_rank=64, - max_loras=3, - max_cpu_loras=3), build_config=build_config) # Sample prompts From 61a994bb80a57f35559581950cd10bffeb9203fe Mon Sep 17 00:00:00 2001 From: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> Date: Tue, 29 Jul 2025 13:20:28 +0000 Subject: [PATCH 17/21] Fix examples/llm-api/llm_multilora.py to not use BuildConfig that's irrelevant to pytorch backend Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> --- examples/llm-api/llm_multilora.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/examples/llm-api/llm_multilora.py b/examples/llm-api/llm_multilora.py index 7329d9f5e9f..60795b6c60a 100644 --- a/examples/llm-api/llm_multilora.py +++ b/examples/llm-api/llm_multilora.py @@ -5,7 +5,6 @@ from tensorrt_llm import LLM from tensorrt_llm.executor import LoRARequest -from tensorrt_llm.llmapi import BuildConfig from tensorrt_llm.lora_manager import LoraConfig @@ -19,14 +18,12 @@ def main(): # Currently, we need to pass at least one lora_dir to LLM constructor via build_config.lora_config. # This is necessary because it requires some configuration in the lora_dir to build the engine with LoRA support. - build_config = BuildConfig() - build_config.lora_config = LoraConfig(lora_dir=[lora_dir1], - max_lora_rank=64, - max_loras=3, - max_cpu_loras=3) + lora_config = LoraConfig(lora_dir=[lora_dir1], + max_lora_rank=64, + max_loras=3, + max_cpu_loras=3) llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", - enable_lora=True, - build_config=build_config) + lora_config=lora_config) # Sample prompts prompts = [ From 191a0ed9f95570cf854024091a460c78641cd183 Mon Sep 17 00:00:00 2001 From: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> Date: Tue, 29 Jul 2025 13:20:29 +0000 Subject: [PATCH 18/21] Changed create_from_pybind method to be a more generic classmethod in PybindMirror, updated its PeftCacheConfig tests accordingly, removed default values from description, raise exception when unused peft_cache_config.lora_prefetch_dir was set instead of writing a warning log message Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> --- tensorrt_llm/llmapi/llm_args.py | 96 +++++++++++++++++--------- tests/unittest/llmapi/test_llm_args.py | 65 +++++++++++++++-- 2 files changed, 126 insertions(+), 35 deletions(-) diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 800e7ef4d00..26fd9133dee 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -3,12 +3,13 @@ import json import math import os +import types from abc import ABC, abstractmethod from dataclasses import dataclass, field from enum import Enum, EnumMeta from pathlib import Path from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Literal, Optional, - TypeAlias, Union) + Type, TypeAlias, TypeVar, Union, get_args, get_origin) import torch import yaml @@ -61,6 +62,8 @@ # TODO[chunweiy]: move the following symbols back to utils scope, and remove the following import +TypeBaseModel = TypeVar("T", bound=BaseModel) + def Field(default: Any = ..., *, @@ -598,6 +601,62 @@ def pybind_equals(obj0, obj1): return False return True + @classmethod + def from_pybind(cls: 
Type[TypeBaseModel], + pybind_instance: "PybindMirror") -> TypeBaseModel: + """Construct an instance of the given class from the fields in the given + pybind class. + + Args: + cls: Type of the class to construct, must be a subclass of pydantic + BaseModel + pybind_instance: Instance of the pybind class to construct from its + fields + + Notes: + When a field value is None in the pybind class, but it's not + optional and has a default value in the BaseModel class, it would + get the default value defined in the BaseModel class. + + Returns: + Instance of the given class, populated with the fields of the given + pybind instance + """ # noqa: D205 + assert issubclass(cls, BaseModel) + + # Some of the fields are optional in the C++ class but in python they aren't + # optional and have a default value, so copy the value from C++ instance + # only if it has a value, so otherwise the default value defined in the + # python class would be set. + def _is_optional_type(annotation: Any) -> bool: + """Returns True if a type annotation represents an Optional type + (Optional[X]) or a Union type that includes None (Union[X, Y, None] + or X | Y | None). + """ # noqa: D205 + origin = get_origin(annotation) + args = get_args(annotation) + + # Union is for Optional[x] + # UnionType is for the new | operation in Python 3.10+ + return (origin is Union + or origin is types.UnionType) and type(None) in args + + fields_non_optional_with_default_value_in_basemodel = { + field_name + for field_name, field_info in cls.model_fields.items() + if not (_is_optional_type(field_info.annotation) + and field_info.is_required()) + } + + kwargs = {} + cpp_fields = PybindMirror.get_pybind_variable_fields( + type(pybind_instance)) + for field_name in cpp_fields: + field_value = getattr(pybind_instance, field_name) + if field_value is not None or field_name not in fields_non_optional_with_default_value_in_basemodel: + kwargs[field_name] = field_value + return cls(**kwargs) + class PybindMirrorMeta(type(PybindMirror)): pass @@ -731,10 +790,9 @@ class PeftCacheConfig(StrictBaseModel, PybindMirror): default=0.02, description= "Proportion of free device memory after engine load to use for cache, as a fraction from 0 to 1" - ", defaults to 2%") + ) host_cache_size: int = Field( - default=1024**3, - description="size in bytes to use for host cache, defaults to 1GiB") + default=1024**3, description="size in bytes to use for host cache") lora_prefetch_dir: Optional[str] = Field( default=None, description= @@ -756,32 +814,6 @@ def _to_pybind(self): host_cache_size=self.host_cache_size, lora_prefetch_dir=self.lora_prefetch_dir) - @staticmethod - def create_from_pybind( - peft_cache_config: _PeftCacheConfig) -> "PeftCacheConfig": - # Some of the properties are optional in CPP but in python they have a default value and aren't optional, - # so copy their value only if they have a value in the CPP instance. 
- extra_kwargs = {} - if peft_cache_config.device_cache_percent is not None: - extra_kwargs[ - "device_cache_percent"] = peft_cache_config.device_cache_percent - if peft_cache_config.host_cache_size is not None: - extra_kwargs["host_cache_size"] = peft_cache_config.host_cache_size - - return PeftCacheConfig( - num_host_module_layer=peft_cache_config.num_host_module_layer, - num_device_module_layer=peft_cache_config.num_device_module_layer, - optimal_adapter_size=peft_cache_config.optimal_adapter_size, - max_adapter_size=peft_cache_config.max_adapter_size, - num_put_workers=peft_cache_config.num_put_workers, - num_ensure_workers=peft_cache_config.num_ensure_workers, - num_copy_streams=peft_cache_config.num_copy_streams, - max_pages_per_block_host=peft_cache_config.max_pages_per_block_host, - max_pages_per_block_device=peft_cache_config. - max_pages_per_block_device, - lora_prefetch_dir=peft_cache_config.lora_prefetch_dir, - **extra_kwargs) - @PybindMirror.mirror_pybind_fields(_LookaheadDecodingConfig) class LookaheadDecodingConfig(DecodingBaseConfig, PybindMirror): @@ -1640,7 +1672,9 @@ def validate_lora_config_consistency(self): @model_validator(mode="after") def validate_peft_cache_config(self): if self.peft_cache_config is not None and self.peft_cache_config.lora_prefetch_dir is not None: - logger.warning("LoRA prefetch is not supported") + raise ValueError( + f"lora_prefetch_dir was set to '{self.peft_cache_config.lora_prefetch_dir}' " + "while LoRA prefetch is not supported") return self def _update_plugin_config(self, key: str, value: Any): diff --git a/tests/unittest/llmapi/test_llm_args.py b/tests/unittest/llmapi/test_llm_args.py index d6990ac745c..acb831837cd 100644 --- a/tests/unittest/llmapi/test_llm_args.py +++ b/tests/unittest/llmapi/test_llm_args.py @@ -223,10 +223,6 @@ def test_SchedulerConfig_declaration(): config.dynamic_batch_config._to_pybind()) -def test_PeftCacheConfig_default_values(): - check_defaults(PeftCacheConfig, tle.PeftCacheConfig) - - def test_PeftCacheConfig_declaration(): config = PeftCacheConfig(num_host_module_layer=1, num_device_module_layer=1, @@ -256,6 +252,67 @@ def test_PeftCacheConfig_declaration(): assert pybind_config.lora_prefetch_dir == "." +def test_PeftCacheConfig_from_pybind(): + pybind_config = tle.PeftCacheConfig(num_host_module_layer=1, + num_device_module_layer=1, + optimal_adapter_size=64, + max_adapter_size=128, + num_put_workers=1, + num_ensure_workers=1, + num_copy_streams=1, + max_pages_per_block_host=24, + max_pages_per_block_device=8, + device_cache_percent=0.5, + host_cache_size=1024, + lora_prefetch_dir=".") + + config = PeftCacheConfig.from_pybind(pybind_config) + assert config.num_host_module_layer == 1 + assert config.num_device_module_layer == 1 + assert config.optimal_adapter_size == 64 + assert config.max_adapter_size == 128 + assert config.num_put_workers == 1 + assert config.num_ensure_workers == 1 + assert config.num_copy_streams == 1 + assert config.max_pages_per_block_host == 24 + assert config.max_pages_per_block_device == 8 + assert config.device_cache_percent == 0.5 + assert config.host_cache_size == 1024 + assert config.lora_prefetch_dir == "." 
+ + +def test_PeftCacheConfig_from_pybind_gets_python_only_default_values_when_none( +): + pybind_config = tle.PeftCacheConfig(num_host_module_layer=1, + num_device_module_layer=1, + optimal_adapter_size=64, + max_adapter_size=128, + num_put_workers=1, + num_ensure_workers=1, + num_copy_streams=1, + max_pages_per_block_host=24, + max_pages_per_block_device=8, + device_cache_percent=None, + host_cache_size=None, + lora_prefetch_dir=".") + + config = PeftCacheConfig.from_pybind(pybind_config) + assert config.num_host_module_layer == 1 + assert config.num_device_module_layer == 1 + assert config.optimal_adapter_size == 64 + assert config.max_adapter_size == 128 + assert config.num_put_workers == 1 + assert config.num_ensure_workers == 1 + assert config.num_copy_streams == 1 + assert config.max_pages_per_block_host == 24 + assert config.max_pages_per_block_device == 8 + assert config.device_cache_percent == PeftCacheConfig.model_fields[ + "device_cache_percent"].default + assert config.host_cache_size == PeftCacheConfig.model_fields[ + "host_cache_size"].default + assert config.lora_prefetch_dir == "." + + def test_update_llm_args_with_extra_dict_with_nested_dict(): llm_api_args_dict = { "model": From 8cca1946952672630cb336b4a73a7ead971f3e16 Mon Sep 17 00:00:00 2001 From: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> Date: Tue, 29 Jul 2025 13:20:29 +0000 Subject: [PATCH 19/21] Minor docstring fix Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> --- tensorrt_llm/llmapi/llm_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 26fd9133dee..d5b626321b3 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -605,7 +605,7 @@ def pybind_equals(obj0, obj1): def from_pybind(cls: Type[TypeBaseModel], pybind_instance: "PybindMirror") -> TypeBaseModel: """Construct an instance of the given class from the fields in the given - pybind class. + pybind class instance. 
Args: cls: Type of the class to construct, must be a subclass of pydantic From 391d0f9d6289c5e658592d2cb96c5dcf358e7e81 Mon Sep 17 00:00:00 2001 From: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> Date: Tue, 29 Jul 2025 13:20:30 +0000 Subject: [PATCH 20/21] Fix rename Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/_util.py | 2 +- tensorrt_llm/llmapi/llm.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 24b9fc12f0b..04ff612670b 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -482,7 +482,7 @@ def create_py_executor_instance( num_lora_modules = model_engine.model.model_config.pretrained_config.num_hidden_layers * \ len(lora_config.lora_target_modules + lora_config.missing_qkv_modules) - peft_cache_config_model = PeftCacheConfig.create_from_pybind( + peft_cache_config_model = PeftCacheConfig.from_pybind( executor_config.peft_cache_config ) if executor_config.peft_cache_config is not None else PeftCacheConfig( ) diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index 554e1b739df..73b576b3c8f 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -823,7 +823,7 @@ def _build_model(self): num_lora_modules = engine_config.pretrained_config.num_hidden_layers * \ len(lora_config.lora_target_modules + lora_config.missing_qkv_modules) - peft_cache_config_model = PeftCacheConfig.create_from_pybind( + peft_cache_config_model = PeftCacheConfig.from_pybind( self._executor_config.peft_cache_config ) if self._executor_config.peft_cache_config is not None else PeftCacheConfig( ) From bce06ad96b48d196804ca20224f681f7fe5a08da Mon Sep 17 00:00:00 2001 From: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> Date: Tue, 29 Jul 2025 13:20:31 +0000 Subject: [PATCH 21/21] Fix test_ptp_quickstart_multimodal_phi4mm - for stability set lora cache sizes, fix incorrect lora request creation Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> --- examples/llm-api/quickstart_multimodal.py | 3 +++ tensorrt_llm/_torch/models/modeling_phi4mm.py | 8 ++++---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/examples/llm-api/quickstart_multimodal.py b/examples/llm-api/quickstart_multimodal.py index 1f45da90eb4..fc18671ee28 100644 --- a/examples/llm-api/quickstart_multimodal.py +++ b/examples/llm-api/quickstart_multimodal.py @@ -148,6 +148,9 @@ def main(): models_module = importlib.import_module('tensorrt_llm._torch.models') model_class = getattr(models_module, args.auto_model_name) lora_config = model_class.lora_config(args.model_dir) + # For stability - explicitly set the LoRA GPU cache & CPU cache to have space for 2 adapters + lora_config.max_loras = 2 + lora_config.max_cpu_loras = 2 llm, sampling_params = setup_llm(args, lora_config=lora_config) diff --git a/tensorrt_llm/_torch/models/modeling_phi4mm.py b/tensorrt_llm/_torch/models/modeling_phi4mm.py index f0fddbef7b8..b5ad4f45203 100644 --- a/tensorrt_llm/_torch/models/modeling_phi4mm.py +++ b/tensorrt_llm/_torch/models/modeling_phi4mm.py @@ -271,16 +271,16 @@ def lora_request(num_requests: int, modality: str, base_model_dir: str): if modality == "image" or modality == "image_audio": lora_request = [ LoRARequest( - lora_name=f"vision-lora-{i}", - lora_int_id=i, + lora_name="vision-lora", + lora_int_id=0, lora_path=f"{base_model_dir}/vision-lora", ) for i in 
range(num_requests) ] elif modality == "audio": lora_request = [ LoRARequest( - lora_name=f"speech-lora-{i}", - lora_int_id=i, + lora_name="speech-lora", + lora_int_id=1, lora_path=f"{base_model_dir}/speech-lora", ) for i in range(num_requests) ]
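
---

The patches above move the LoRA cache-capacity knobs onto `LoraConfig` (`max_loras` / `max_cpu_loras`, now defaulting to `None`) and let them override the sizes otherwise derived from `peft_cache_config`. Below is a minimal, illustrative sketch of the resulting LLM-API usage; it mirrors the updated `examples/llm-api/llm_multilora.py` and the new tests, but it is not part of the patch series itself, and the adapter directory, adapter name, and prompt are placeholders.

```python
# Illustrative sketch only (not part of the diffs above).
# Assumptions: a local HF LoRA adapter directory; default sampling parameters.
from tensorrt_llm import LLM
from tensorrt_llm.executor import LoRARequest
from tensorrt_llm.llmapi.llm_args import PeftCacheConfig
from tensorrt_llm.lora_manager import LoraConfig

lora_dir = "/path/to/hf-lora-adapter"  # placeholder path

# Cache-capacity fields now live on LoraConfig; when max_loras / max_cpu_loras
# are set, they override the device/host module-layer counts that would
# otherwise be taken from peft_cache_config.
lora_config = LoraConfig(
    lora_dir=[lora_dir],
    max_lora_rank=64,
    max_loras=2,      # GPU LoRA cache sized for 2 adapters
    max_cpu_loras=2,  # host LoRA cache sized for 2 adapters
)

llm = LLM(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    lora_config=lora_config,
    # Other PEFT cache knobs can still be passed explicitly; the lora_config
    # values above take precedence for the per-adapter capacity.
    peft_cache_config=PeftCacheConfig(host_cache_size=1024**3),  # bytes
)

outputs = llm.generate(
    ["What is the capital of France?"],
    lora_request=[
        LoRARequest(lora_name="adapter-0", lora_int_id=0, lora_path=lora_dir)
    ],
)
print(outputs[0].outputs[0].text)
```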