@@ -289,10 +289,10 @@ def __init__(
         attn_backend = pytorch_backend_config.attn_backend

         self.lora_manager: Optional[LoraManager] = None
-        if lora_config is not None:
-            self.lora_manager = LoraManager()
-
-        self.lora_prefetch_requests_list = None  # TODO smor - fix "LoRARequest" import
+        self.lora_prefetch_requests_list = None
+        # TODO smor- do we want to get the request inside the lora config?
+        # TODO smor- what happens if you get target modules?
+        # TODO smor- answer and guard against this
         if lora_config is not None and lora_config.lora_request is not None:
             self.lora_prefetch_requests_list = lora_config.lora_request
             self.has_lora_prefetched = False
@@ -455,13 +455,11 @@ def set_lora_model_config(self, lora_target_modules: list[str],
             hidden_size=self.model.config.hidden_size,
             dtype=torch_dtype_to_str(self.model.config.torch_dtype))

-    def set_lora_manager_cpp_peft_cache_manager(
-            self, resource_manager: ResourceManager):
-        cpp_peft_cache_manager = resource_manager.get_resource_manager(
+    def set_lora_manager(self, resource_manager: ResourceManager):
+        peft_cache_manager = resource_manager.get_resource_manager(
             ResourceManagerType.PEFT_CACHE_MANAGER)
-        if cpp_peft_cache_manager is not None and self.lora_manager is not None:
-            self.lora_manager.set_cpp_peft_cache_manager(
-                cpp_peft_cache_manager.impl)
+        if peft_cache_manager is not None:
+            self.lora_manager = peft_cache_manager.get_lora_manager()

     def prefetch_lora_dirs(self):
         if self.lora_prefetch_requests_list is None:
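As a side note on the new `set_lora_manager` flow: the engine now borrows the `LoraManager` owned by the PEFT cache manager instead of constructing its own. Below is a minimal, self-contained sketch of that lookup/ownership pattern; every class here is a simplified stand-in (the real `ResourceManager`, `PeftCacheManager`, and `LoraManager` live in TensorRT-LLM), and only the behavior implied by the diff is modeled.

# Hypothetical, simplified stand-ins for the real TensorRT-LLM classes; only the
# "engine reuses the LoraManager held by the PEFT cache manager" pattern is shown.
from enum import Enum, auto
from typing import Dict, Optional


class ResourceManagerType(Enum):
    PEFT_CACHE_MANAGER = auto()


class LoraManager:
    """Placeholder for tensorrt_llm's LoraManager."""


class PeftCacheManager:
    def __init__(self) -> None:
        self._lora_manager = LoraManager()

    def get_lora_manager(self) -> LoraManager:
        return self._lora_manager


class ResourceManager:
    def __init__(self, resources: Dict[ResourceManagerType, object]) -> None:
        self._resources = resources

    def get_resource_manager(self, key: ResourceManagerType) -> Optional[object]:
        return self._resources.get(key)


class EngineSketch:
    def __init__(self) -> None:
        self.lora_manager: Optional[LoraManager] = None

    def set_lora_manager(self, resource_manager: ResourceManager) -> None:
        peft_cache_manager = resource_manager.get_resource_manager(
            ResourceManagerType.PEFT_CACHE_MANAGER)
        if peft_cache_manager is not None:
            # The engine no longer owns a LoraManager; it reuses the one held
            # by the PEFT cache manager so both see the same adapter state.
            self.lora_manager = peft_cache_manager.get_lora_manager()


engine = EngineSketch()
engine.set_lora_manager(
    ResourceManager({ResourceManagerType.PEFT_CACHE_MANAGER: PeftCacheManager()}))
assert engine.lora_manager is not None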
@@ -534,15 +532,34 @@ def warmup(self, resource_manager: ResourceManager) -> None:
         self.cuda_graph_dummy_request = None

         def get_cuda_graph_warmup_request(batch_size, draft_len):
-            lora_config = None
+            lora_configs = []
             if self.has_lora_prefetched:
-                # TODO smor currently I assume a single adapter with uid 0, change this
-                uid = 0
+                print(
+                    "SMOR, model engine, maybe get cuda graph, processing lora_params"
+                )
+                # from IPython import embed
+                # embed()
                 from tensorrt_llm.bindings import executor as tllm
-                lora_config = tllm.LoraConfig(
-                    task_id=uid,
-                    weights=self.lora_manager.cpp_lora_weights[uid],
-                    config=self.lora_manager.cpp_lora_config[uid])
+
+                # TODO smor- what happens if batch size > len(available_uids)?
+                available_uids = list(self.lora_manager.cpp_lora_weights.keys())
+                available_uids.sort()  # Ensure consistent ordering
+
+                # Create LoRA configs for each request in the batch
+                # IMPORTANT: Match request_id to the corresponding LoRA UID
+                for request_id in range(batch_size):
+                    # Use request_id as the LoRA UID (assuming they should match)
+                    # This ensures request 0 uses LoRA UID 0, request 1 uses LoRA UID 1, etc.
+                    uid = available_uids[request_id % len(available_uids)]

+                    # Get the tensors - executor LoraConfig expects 2D tensors
+                    weights = self.lora_manager.cpp_lora_weights[uid]
+                    config = self.lora_manager.cpp_lora_config[uid]
+
+                    lora_config = tllm.LoraConfig(task_id=uid,
+                                                  weights=weights,
+                                                  config=config)
+                    lora_configs.append(lora_config)

             # Divide by max_beam_width to get an approximation of the number of requests that can be run in parallel.
             available_blocks = kv_cache_manager.get_num_free_blocks(
@@ -552,14 +569,18 @@ def get_cuda_graph_warmup_request(batch_size, draft_len):
                 result.context_requests = []
                 # Add (batch_size - 1) dummy requests with seq_len=1.
                 # Should only need one more page per request.
+
+                # Use the first batch_size-1 LoRA configs for the short requests
+                short_requests_lora = lora_configs[:batch_size -
+                                                   1] if lora_configs else None
+
                 requests = kv_cache_manager.add_dummy_requests(
                     list(range(batch_size - 1)),
                     is_gen=True,
                     max_num_draft_tokens=draft_len,
                     use_mrope=use_mrope,
                     max_beam_width=self.max_beam_width,
-                    lora_request=
-                    lora_config,  # TODO smor- tests assume BS1 then this will be ignored for now, need to resolve
+                    lora_request=short_requests_lora,
                 )
                 # Divide by max_beam_width to get an approximation of the number of tokens that can be added to the final request.
                 available_tokens = kv_cache_manager.get_num_available_tokens(
@@ -568,14 +589,20 @@ def get_cuda_graph_warmup_request(batch_size, draft_len):
                 # Add one dummy request with the maximum possible sequence length.
                 # The sequence length is limited by both the max_seq_len and the number of available blocks.
                 token_num = max(1, min(available_tokens, self.max_seq_len - 1))
+
+                # Use the last LoRA config for the max sequence length request
+                lora_request_for_max = [
+                    lora_configs[batch_size - 1]
+                ] if lora_configs and len(lora_configs) >= batch_size else None
+
                 max_seq_len_request = kv_cache_manager.add_dummy_requests(
                     request_ids=[batch_size - 1],
                     token_nums=[token_num],
                     is_gen=True,
                     max_num_draft_tokens=draft_len,
                     use_mrope=use_mrope,
                     max_beam_width=self.max_beam_width,
-                    lora_request=lora_config)[0]
+                    lora_request=lora_request_for_max)[0]
                 # Add the longest request before all other seq_len=1 request to simulate the padding CUDA graph case.
                 # This batch contains both the longest request and the shortest requests,
                 # it also contains the maximum number of requests and the maximum token number,
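For readers following the warmup change: the added code assigns prefetched LoRA adapters to the dummy warmup requests round-robin by UID, giving the first batch_size - 1 configs to the seq_len=1 requests and the last one to the max-sequence-length request. A small stand-alone sketch of that mapping (plain Python, no `tllm.LoraConfig` construction; `assign_warmup_lora_uids` is a hypothetical helper written for illustration, not part of the codebase):

# Minimal illustration of the round-robin UID mapping used in the warmup path:
# request i gets available_uids[i % len(available_uids)], the first
# batch_size - 1 UIDs go to the seq_len=1 dummy requests and the last one to
# the max-sequence-length request. Pure Python, no TRT-LLM types.
from typing import Dict, List, Optional, Tuple


def assign_warmup_lora_uids(
        cpp_lora_weights: Dict[int, object],
        batch_size: int) -> Tuple[Optional[List[int]], Optional[List[int]]]:
    if not cpp_lora_weights:
        return None, None

    available_uids = sorted(cpp_lora_weights.keys())  # consistent ordering
    # Round-robin: wraps around when batch_size > number of prefetched adapters.
    uids = [available_uids[i % len(available_uids)] for i in range(batch_size)]

    short_requests = uids[:batch_size - 1]        # seq_len=1 dummy requests
    max_seq_len_request = [uids[batch_size - 1]]  # the single long request
    return short_requests, max_seq_len_request


# Example: two prefetched adapters (UIDs 0 and 1), warmup batch of 3 requests.
short, longest = assign_warmup_lora_uids({0: "w0", 1: "w1"}, batch_size=3)
print(short, longest)  # [0, 1] [0]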
@@ -693,7 +720,7 @@ def release_batch(result: ScheduledRequests | None):
             return

         with contextlib.ExitStack() as stack:
-            if self._torch_compile_enabled:
+            if self._torch_compile_enabled:  # TODO SMOR False

                 def disable_optimization(backend: Backend):
                     # Disable torch.compile optimization and fallback to eager execution
@@ -733,7 +760,7 @@ def disable_optimization(backend: Backend):
                                              resource_manager=resource_manager)
                                 torch.cuda.synchronize()

-            if self.pytorch_backend_config.enable_autotuner:
+            if self.pytorch_backend_config.enable_autotuner:  # TODO SMOR True, currently get_autotune_warmup_request isn't addressed
                 with self.no_cuda_graph(), autotune():
                     result = get_autotune_warmup_request()
                     with release_batch(result) as batch:
@@ -787,10 +814,13 @@ def disable_optimization(backend: Backend):
                             f"Run generation only CUDA graph warmup for batch size={bs}, draft_len={draft_len}"
                         )
                         self.enable_spec_decode = draft_len > 0 or self.is_draft_model
+                        print("SMOR, model engine, before forward")
+                        # from IPython import embed
+                        # embed()
                         self.forward(batch,
                                      new_tensors_device=None,
                                      resource_manager=resource_manager)
-                        torch.cuda.synchronize()
+                        torch.cuda.synchronize()  # fails here

             if self._torch_compile_piecewise_cuda_graph and self._torch_compile_enabled:
                 for seq_lens in cuda_graph_batch_sizes:
@@ -1034,16 +1064,8 @@ def _maybe_get_cuda_graph(
             if len(context_requests) > 0:
                 raise ValueError("SMOR, context requests isn't tested yet")

-            if len(generation_requests) > 1:
-                raise ValueError("SMOR, generation requests isn't tested yet")
-
-            generation_request = generation_requests[0]
-            # TODO smor I have no idea why this is happening
-            generation_request.lora_weights = generation_request.lora_weights.reshape(
-                [1] + list(generation_request.lora_weights.shape))
-            generation_request.lora_config = generation_request.lora_config.reshape(
-                [1] + list(generation_request.lora_config.shape))
-            peft_cache_manager.impl.add_request_peft(generation_request, True)
+            for generation_request in generation_requests:
+                peft_cache_manager.add_request_peft(generation_request)

             py_lora_task_layer_module_configs = peft_cache_manager.impl.ensure_batch(
                 context_requests, generation_requests, False)
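The loop above replaces the old single-request path (with its manual reshape and direct `impl.add_request_peft` call): every dummy generation request is now registered through the `PeftCacheManager` wrapper before `ensure_batch` resolves per-request layer/module configs. A rough sketch of that registration pattern, using hypothetical stand-in types rather than the real bindings:

# Stand-in sketch only: the real PeftCacheManager loads and pins adapter
# weights; here we just record which task id each request uses and return
# per-request placeholder configs keyed by request id, mirroring the shape of
# the ensure_batch result consumed above.
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class DummyGenRequest:
    py_request_id: int
    lora_task_id: Optional[int] = None
    py_lora_task_layer_module_configs: Optional[list] = None


class PeftCacheManagerSketch:
    def __init__(self) -> None:
        self._registered: Dict[int, int] = {}

    def add_request_peft(self, request: DummyGenRequest) -> None:
        if request.lora_task_id is not None:
            self._registered[request.py_request_id] = request.lora_task_id

    def ensure_batch(self, context_requests: List[DummyGenRequest],
                     generation_requests: List[DummyGenRequest]) -> Dict[int, list]:
        # Pretend every registered request resolves to one layer-module config.
        return {rid: [("layer_module_config", task_id)]
                for rid, task_id in self._registered.items()}


peft = PeftCacheManagerSketch()
gen_requests = [DummyGenRequest(py_request_id=i, lora_task_id=i % 2) for i in range(3)]
for req in gen_requests:
    peft.add_request_peft(req)

configs = peft.ensure_batch([], gen_requests)
for req in gen_requests:
    req.py_lora_task_layer_module_configs = configs.get(req.py_request_id)
print([r.py_lora_task_layer_module_configs for r in gen_requests])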
@@ -1056,14 +1078,8 @@ def _maybe_get_cuda_graph(
                     req.
                     py_request_id] if req.py_request_id in py_lora_task_layer_module_configs else None

-            # TODO smor - look at get lora params from requests
-            # You need something that isn't scheduled requests
-            # It also appears that you should make sure resource manager is called, because prefetch
-            # has to be added to peftCacheManager as well. So it still shouldn't work
-
             lora_params = self._get_lora_params_from_requests(
                 batch, attn_metadata)
-            print(f"SMOR, not failed on lora_params in maybe_get_cuda_graph")

         # Initialize nested dictionary if needed
         if batch_size not in self._cuda_graphs:
@@ -1271,7 +1287,8 @@ def _prepare_tp_inputs(
             attn_metadata: AttentionMetadata,
             spec_metadata: Optional[SpecMetadata] = None,
             new_tensors_device: Optional[SampleStateTensors] = None,
-            cache_indirection_buffer: Optional[torch.Tensor] = None):
+            cache_indirection_buffer: Optional[torch.Tensor] = None,
+            lora_params: Optional[dict] = None):
         """
         Prepare inputs for Pytorch Model.
         """
@@ -1624,8 +1641,9 @@ def previous_seq_slots_device():

         attn_metadata.prepare()

-        lora_params = self._get_lora_params_from_requests(
-            scheduled_requests, attn_metadata)
+        if lora_params is None:
+            lora_params = self._get_lora_params_from_requests(
+                scheduled_requests, attn_metadata)

         # Prepare inputs
         inputs = {
@@ -1650,6 +1668,9 @@ def previous_seq_slots_device():
                     mrope_position_deltas_list, dim=0)

         if bool(lora_params):
+            print("SMOR, model engine, before setting lora_params")
+            # from IPython import embed
+            # embed()
             inputs['lora_params'] = lora_params

         if spec_metadata is not None:
@@ -2060,7 +2081,13 @@ def _get_lora_params_from_requests(self,
         tmp_lora_params = {}

         request_list = scheduled_requests.context_requests + scheduled_requests.generation_requests
-
+        if len(request_list) == 2:
+            print(
+                "SMOR, after getting request_list in get lora params from requests, check for some order reversal"
+            )
+            # from IPython import embed
+            # embed()
+            # request_list = request_list[::-1]
         # trace all requests to get the union set of the lora params
         for request in request_list:
             if request.py_lora_task_layer_module_configs is None:
@@ -2158,14 +2185,14 @@ def _get_lora_params_from_requests(self,
         return lora_params

     @nvtx_range("_prepare_inputs")
-    def _prepare_inputs(
-            self,
-            scheduled_requests: ScheduledRequests,
-            kv_cache_manager: KVCacheManager,
-            attn_metadata: AttentionMetadata,
-            spec_metadata: Optional[SpecMetadata] = None,
-            new_tensors_device: Optional[SampleStateTensors] = None,
-            cache_indirection_buffer: Optional[torch.Tensor] = None):
+    def _prepare_inputs(self,
+                        scheduled_requests: ScheduledRequests,
+                        kv_cache_manager: KVCacheManager,
+                        attn_metadata: AttentionMetadata,
+                        spec_metadata: Optional[SpecMetadata] = None,
+                        new_tensors_device: Optional[SampleStateTensors] = None,
+                        cache_indirection_buffer: Optional[torch.Tensor] = None,
+                        lora_params: Optional[dict] = None):
         if self.mapping is not None and 'cp_type' in self.mapping.cp_config:
             cp_type = self.mapping.cp_config['cp_type']
             if 'star_attention' == cp_type:
@@ -2177,7 +2204,8 @@ def _prepare_inputs(
             return self._prepare_tp_inputs(scheduled_requests, kv_cache_manager,
                                            attn_metadata, spec_metadata,
                                            new_tensors_device,
-                                           cache_indirection_buffer)
+                                           cache_indirection_buffer,
+                                           lora_params)

     @torch.inference_mode()
     @with_model_extra_attrs(lambda self: self.model.extra_attrs)
@@ -2232,16 +2260,19 @@ def forward(
             if maybe_graph is not None:
                 attn_metadata = maybe_graph.attn_metadata
                 spec_metadata = maybe_graph.spec_metadata
+                lora_params = maybe_graph.lora_params
             else:
                 attn_metadata = self.attn_metadata
+                lora_params = None
                 if self.enable_spec_decode:
                     spec_metadata = self.spec_metadata
                 else:
                     spec_metadata = None

             inputs, gather_ids = self._prepare_inputs(
                 scheduled_requests, kv_cache_manager, attn_metadata,
-                spec_metadata, new_tensors_device, cache_indirection_buffer)
+                spec_metadata, new_tensors_device, cache_indirection_buffer,
+                lora_params)

             self.iter_counter += 1
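Taken together, these signature changes thread `lora_params` from the captured CUDA graph down through `_prepare_inputs` into `_prepare_tp_inputs`, which only recomputes them when the caller did not supply any. A condensed, illustrative sketch of that call flow follows; all classes are stand-ins, and storing `lora_params` on the captured-graph object is assumed from context rather than shown in this hunk.

# Condensed, illustrative call flow for the lora_params plumbing in this diff:
# reuse the lora_params recorded on a captured CUDA graph when replaying it,
# otherwise derive them from the scheduled batch. All types are placeholders.
from dataclasses import dataclass
from typing import Optional


@dataclass
class CapturedGraph:
    lora_params: Optional[dict]


class EngineFlowSketch:

    def _get_lora_params_from_requests(self, scheduled_requests) -> dict:
        # Placeholder for the real per-request LoRA gathering logic.
        return {"derived_from": scheduled_requests}

    def _prepare_tp_inputs(self, scheduled_requests,
                           lora_params: Optional[dict] = None) -> dict:
        # Only recompute when the caller (e.g. CUDA graph warmup/replay) did
        # not already supply the LoRA parameters.
        if lora_params is None:
            lora_params = self._get_lora_params_from_requests(scheduled_requests)
        inputs = {"scheduled": scheduled_requests}
        if lora_params:
            inputs["lora_params"] = lora_params
        return inputs

    def forward(self, scheduled_requests,
                maybe_graph: Optional[CapturedGraph]) -> dict:
        # When a captured graph exists, reuse the lora_params recorded at
        # capture time; otherwise leave them None so they get recomputed.
        lora_params = maybe_graph.lora_params if maybe_graph is not None else None
        return self._prepare_tp_inputs(scheduled_requests, lora_params)


engine = EngineFlowSketch()
print(engine.forward("batch-A", maybe_graph=None))
print(engine.forward("batch-B", maybe_graph=CapturedGraph({"task_ids": [0]})))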