 import tensorrt_llm
 import tensorrt_llm.bindings
 from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
+from tensorrt_llm.lora_manager import LoraConfig, LoraManager, LoraModelConfig
 from tensorrt_llm.sampling_params import SamplingParams

-from ..._utils import binding_dtype_size, nvtx_range
+from ..._utils import binding_dtype_size, binding_to_str_dtype, nvtx_range
 from ...logger import logger
 from ...mapping import Mapping
 from .llm_request import (LlmRequest, LlmRequestState, SamplingConfig,
@@ -1170,6 +1171,7 @@ class PeftCacheManager(BaseResourceManager):

     def __init__(self,
                  peft_cache_config: PeftCacheConfig,
+                 lora_config: LoraConfig,
                  model_config: ModelConfig,
                  world_config: WorldConfig | None = None):
         import tensorrt_llm.bindings as _tb
@@ -1200,8 +1202,36 @@ def __init__(self,
             model_config=model_config,
             world_config=world_config,
             buffer_manager=buffer_manager)
+        self._lora_config = lora_config
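+        # Python-side LoRA model config; binding_to_str_dtype maps the bindings dtype to its string name for LoraManager.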
+        self._lora_model_config = LoraModelConfig(
+            lora_config.lora_target_modules,
+            lora_config.trtllm_modules_to_hf_modules, model_config.hidden_size,
+            binding_to_str_dtype(model_config.data_type))
+        self._lora_manager = LoraManager()

     def add_request_peft(self, request: LlmRequest):
+        if request.lora_task_id is not None:
+            is_task_cached = self.impl.is_task_cached(request.lora_task_id)
+            if is_task_cached:
+                # PeftCacheManager::addRequestPeft in CPP doesn't allow having only one of [config tensor,
+                # weights tensor] without the other. Since neither is needed when the LoRA adapter is already
+                # cached, we can safely remove both from the request.
+                request.remove_lora_tensors()
+            elif request.lora_weights is None and request.py_lora_path:
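+                # Adapter is not cached and no weights were passed inline, so load it from the request's LoRA checkpoint path.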
+                self._lora_manager.load_from_ckpt(
+                    [request.py_lora_path],
+                    model_config=self._lora_model_config,
+                    runtime_mapping=None,
+                    uids=[request.lora_task_id],
+                    ckpt_source=self._lora_config.lora_ckpt_source)
+                request.lora_weights = self._lora_manager.cpp_lora_weights[
+                    request.lora_task_id]
+
+            # PeftCacheManager CPP implementation expects an extra dim at index 0
+            if request.lora_weights is not None:
+                request.lora_weights = request.lora_weights.unsqueeze(0)
+            if request.lora_config is not None:
+                request.lora_config = request.lora_config.unsqueeze(0)
         self.impl.add_request_peft(request, True)

     def ensure_batch(self,
@@ -1221,12 +1251,7 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests):
         context_batch = scheduled_batch.context_requests
         generation_batch = scheduled_batch.generation_requests
         for req in context_batch:
-            if req.lora_weights is not None and req.lora_config is not None:
-                req.lora_weights = req.lora_weights.reshape(
-                    [1] + list(req.lora_weights.shape))
-                req.lora_config = req.lora_config.reshape(
-                    [1] + list(req.lora_config.shape))
-            self.impl.add_request_peft(req, True)
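+            # add_request_peft now covers the cache check, checkpoint loading and the extra batch dim.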
+            self.add_request_peft(req)

         py_lora_task_layer_module_configs = self.impl.ensure_batch(
             context_batch, generation_batch, False)