 from tensorrt_llm.inputs.multimodal import (MultimodalParams,
                                             MultimodalRuntimeData)
 from tensorrt_llm.logger import logger
-from tensorrt_llm.lora_manager import LoraConfig, LoraModelConfig
+from tensorrt_llm.lora_manager import LoraConfig, LoraManager, LoraModelConfig
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.modeling_utils import QuantAlgo
 from tensorrt_llm.quantization.utils.fp4_utils import float4_e2m1x2
@@ -287,6 +287,16 @@ def __init__(
         )

         attn_backend = pytorch_backend_config.attn_backend
+
+        self.lora_manager: Optional[LoraManager] = None
+        if lora_config is not None:
+            self.lora_manager = LoraManager()
+
+        self.lora_prefetch_requests_list = None  # TODO smor - fix "LoRARequest" import
+        self.has_lora_prefetched = False
+        if lora_config is not None and lora_config.lora_request is not None:
+            self.lora_prefetch_requests_list = lora_config.lora_request
+
         self.model = self._load_model(
             model_path,
             mapping=self.mapping,
@@ -445,6 +455,27 @@ def set_lora_model_config(self, lora_target_modules: list[str],
             hidden_size=self.model.config.hidden_size,
             dtype=torch_dtype_to_str(self.model.config.torch_dtype))

+    def set_lora_manager_cpp_peft_cache_manager(
+            self, resource_manager: ResourceManager):
+        cpp_peft_cache_manager = resource_manager.get_resource_manager(
+            ResourceManagerType.PEFT_CACHE_MANAGER)
+        if cpp_peft_cache_manager is not None and self.lora_manager is not None:
+            self.lora_manager.set_cpp_peft_cache_manager(
+                cpp_peft_cache_manager.impl)
+
+    def prefetch_lora_dirs(self):
+        if self.lora_prefetch_requests_list is None:
+            return
+
+        for request in self.lora_prefetch_requests_list:
+            self.lora_manager.load_from_ckpt(
+                [request.path],
+                model_config=self.lora_model_config,
+                runtime_mapping=None,
+                uids=[request.adapter_id])
+
+        self.has_lora_prefetched = True
+
     @property
     def use_mrope(self):
         use_mrope = False
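
For context, a minimal sketch of the call order the two new methods above appear to assume. The engine and resource_manager objects are taken as given, the request object is hypothetical (only the .path and .adapter_id attributes used by prefetch_lora_dirs are relied on), and in practice this wiring would live in the executor setup code rather than be done by hand:

    from types import SimpleNamespace

    # Hypothetical prefetch request: any object exposing .path and .adapter_id works here.
    request = SimpleNamespace(path="/ckpts/my-lora-adapter", adapter_id=0)

    # The constructor hunk above stores lora_config.lora_request in
    # engine.lora_prefetch_requests_list; set_lora_model_config(...) is assumed to have
    # been called already so that engine.lora_model_config is populated.
    engine.lora_prefetch_requests_list = [request]

    # Attach the C++ PEFT cache, then load every prefetch request into the LoraManager.
    engine.set_lora_manager_cpp_peft_cache_manager(resource_manager)
    engine.prefetch_lora_dirs()  # sets engine.has_lora_prefetched = True
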
@@ -503,6 +534,16 @@ def warmup(self, resource_manager: ResourceManager) -> None:
         self.cuda_graph_dummy_request = None

         def get_cuda_graph_warmup_request(batch_size, draft_len):
+            lora_config = None
+            if self.has_lora_prefetched:
+                # TODO smor currently I assume a single adapter with uid 0, change this
+                uid = 0
+                from tensorrt_llm.bindings import executor as tllm
+                lora_config = tllm.LoraConfig(
+                    task_id=uid,
+                    weights=self.lora_manager.cpp_lora_weights[uid],
+                    config=self.lora_manager.cpp_lora_config[uid])
+
             # Divide by max_beam_width to get an approximation of the number of requests that can be run in parallel.
             available_blocks = kv_cache_manager.get_num_free_blocks(
             ) // self.max_beam_width
@@ -516,7 +557,10 @@ def get_cuda_graph_warmup_request(batch_size, draft_len):
                     is_gen=True,
                     max_num_draft_tokens=draft_len,
                     use_mrope=use_mrope,
-                    max_beam_width=self.max_beam_width)
+                    max_beam_width=self.max_beam_width,
+                    # TODO smor: tests assume BS1, so this will be ignored for now; need to resolve
+                    lora_request=lora_config,
+                )
                 # Divide by max_beam_width to get an approximation of the number of tokens that can be added to the final request.
                 available_tokens = kv_cache_manager.get_num_available_tokens(
                     draft_len)
@@ -530,7 +574,8 @@ def get_cuda_graph_warmup_request(batch_size, draft_len):
                     is_gen=True,
                     max_num_draft_tokens=draft_len,
                     use_mrope=use_mrope,
-                    max_beam_width=self.max_beam_width)[0]
+                    max_beam_width=self.max_beam_width,
+                    lora_request=lora_config)[0]
                 # Add the longest request before all other seq_len=1 request to simulate the padding CUDA graph case.
                 # This batch contains both the longest request and the shortest requests,
                 # it also contains the maximum number of requests and the maximum token number,
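
The TODO in the warmup hunk above hard-codes a single adapter with uid 0. A possible generalization, sketched under the assumption that cpp_lora_weights and cpp_lora_config remain dicts keyed by the prefetched adapter uids (as they are indexed above):

    from tensorrt_llm.bindings import executor as tllm

    def build_warmup_lora_configs(lora_manager):
        # One executor-side LoraConfig per prefetched adapter instead of only uid 0.
        return [
            tllm.LoraConfig(task_id=uid,
                            weights=lora_manager.cpp_lora_weights[uid],
                            config=lora_manager.cpp_lora_config[uid])
            for uid in lora_manager.cpp_lora_weights
        ]
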
@@ -926,6 +971,7 @@ def _round_up_batch_size(self, batch_size: int) -> int:
     def _maybe_get_cuda_graph(
         self,
         batch: ScheduledRequests,
+        resource_manager: Optional[ResourceManager] = None
     ) -> Optional[DecodingCUDAGraphRunner]:
         """
         Get a CUDA graph runner or return None (e.g. if CUDA graphs are disabled
@@ -972,13 +1018,60 @@ def _maybe_get_cuda_graph(
         else:
             spec_metadata = None

+        lora_params = None
+        if self.has_lora_prefetched:
+            peft_cache_manager = resource_manager.get_resource_manager(
+                ResourceManagerType.PEFT_CACHE_MANAGER)
+
+            context_requests = batch.context_requests
+            generation_requests = batch.generation_requests
+
+            if len(context_requests) > 0 and len(generation_requests) > 0:
+                raise ValueError(
+                    "Mixed context and generation requests are not tested yet with prefetched LoRA"
+                )
+
+            if len(context_requests) > 0:
+                raise ValueError("Context requests are not tested yet with prefetched LoRA")
+
+            if len(generation_requests) > 1:
+                raise ValueError("More than one generation request is not tested yet with prefetched LoRA")
+
+            generation_request = generation_requests[0]
+            # TODO smor: figure out why the LoRA tensors arrive without a leading batch dimension
+            generation_request.lora_weights = generation_request.lora_weights.reshape(
+                [1] + list(generation_request.lora_weights.shape))
+            generation_request.lora_config = generation_request.lora_config.reshape(
+                [1] + list(generation_request.lora_config.shape))
+            peft_cache_manager.impl.add_request_peft(generation_request, True)
+
+            py_lora_task_layer_module_configs = peft_cache_manager.impl.ensure_batch(
+                context_requests, generation_requests, False)
+            for req in context_requests:
+                req.py_lora_task_layer_module_configs = (
+                    py_lora_task_layer_module_configs[req.py_request_id]
+                    if req.py_request_id in py_lora_task_layer_module_configs else None)
+            for req in generation_requests:
+                req.py_lora_task_layer_module_configs = (
+                    py_lora_task_layer_module_configs[req.py_request_id]
+                    if req.py_request_id in py_lora_task_layer_module_configs else None)
+
+            # TODO smor: revisit how LoRA params are gathered from requests -
+            # _get_lora_params_from_requests expects scheduled requests, and the prefetched
+            # adapters still have to be registered with the PeftCacheManager via the
+            # resource manager, so this path is not expected to work yet.
+
+            lora_params = self._get_lora_params_from_requests(
+                batch, attn_metadata)
+            logger.debug("lora_params built successfully in _maybe_get_cuda_graph")
+
         # Initialize nested dictionary if needed
         if batch_size not in self._cuda_graphs:
             self._cuda_graphs[batch_size] = {}

         self._cuda_graphs[batch_size][draft_len] = DecodingCUDAGraphRunner(
             num_sequences_in_batch, "cuda", attn_metadata, spec_metadata,
-            self.use_mrope)
+            self.use_mrope, lora_params)
         return self._cuda_graphs[batch_size][draft_len]

     def __del__(self) -> None:
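
The per-request bookkeeping in the hunk above (add_request_peft, ensure_batch, then attaching py_lora_task_layer_module_configs) could be factored into a small helper. A sketch that only reuses calls already present in the diff:

    def _attach_lora_task_configs(self, peft_cache_manager, context_requests,
                                  generation_requests):
        """Register the batch with the C++ PEFT cache and attach the resulting
        per-request task/layer/module configs (None when a request has no entry)."""
        configs = peft_cache_manager.impl.ensure_batch(context_requests,
                                                       generation_requests, False)
        for req in (*context_requests, *generation_requests):
            req.py_lora_task_layer_module_configs = (
                configs[req.py_request_id]
                if req.py_request_id in configs else None)
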
@@ -2134,7 +2227,8 @@ def forward(
                                           gather_context_logits)
         with self._maybe_pad_batch(scheduled_requests, kv_cache_manager,
                                    spec_resource_manager) as scheduled_requests:
-            maybe_graph = self._maybe_get_cuda_graph(scheduled_requests)
+            maybe_graph = self._maybe_get_cuda_graph(
+                scheduled_requests, resource_manager=resource_manager)
             if maybe_graph is not None:
                 attn_metadata = maybe_graph.attn_metadata
                 spec_metadata = maybe_graph.spec_metadata