 from tensorrt_llm.logger import logger
 
 
-def _get_active_multimodal_params(
+def _get_uncached_multimodal_params(
     multimodal_params: List[MultimodalParams], ) -> List[MultimodalParams]:
     """
-    Get active multimodal params that need encoder processing for chunk prefill.
+    Get uncached multimodal params that need encoder processing for chunk prefill.
     """
     params_to_run = []
 
@@ -63,7 +63,8 @@ def _cache_multimodal_embeddings(
 ) -> None:
     """
     Cache computed multimodal embeddings back to multimodal_data to avoid recomputation.
-    Uses torch.split for efficient tensor splitting without manual indexing.
+    Note this function only caches multimodal embeddings within the current request context,
+    mostly for chunked prefill. It does not persist embeddings across different requests or sessions.
     """
     # TODO: support multiple multimodal modalities per request
     assert len(
@@ -73,7 +74,8 @@ def _cache_multimodal_embeddings(
 
     # Collect embedding lengths for each parameter
     embed_lengths = [
-        param.multimodal_runtime.total_mm_tokens for param in multimodal_params
+        param.multimodal_runtime.total_mm_tokens_in_request
+        for param in multimodal_params
     ]
 
     # Validate total length matches
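
For context on the hunks above: the caching they feed amounts to splitting the concatenated encoder output by each request's total_mm_tokens_in_request and writing the pieces back to multimodal_data. A minimal sketch of that idea follows; _Runtime and _Param are hypothetical stand-ins for multimodal_runtime and MultimodalParams, so this is illustrative only, not the library's implementation.

# Minimal sketch of the per-request caching step, under the assumptions above.
from dataclasses import dataclass, field
from typing import List

import torch


@dataclass
class _Runtime:  # hypothetical stand-in for multimodal_runtime
    total_mm_tokens_in_request: int  # all multimodal tokens owned by the request


@dataclass
class _Param:  # hypothetical stand-in for MultimodalParams
    multimodal_runtime: _Runtime
    multimodal_data: dict = field(default_factory=dict)


def cache_embeddings_sketch(params: List[_Param],
                            encoder_output: torch.Tensor) -> None:
    # Split the concatenated encoder output by each request's token count and
    # store each piece under multimodal_data["multimodal_embedding"].
    lengths = [p.multimodal_runtime.total_mm_tokens_in_request for p in params]
    assert sum(lengths) == encoder_output.shape[0], "embedding length mismatch"
    for p, chunk in zip(params, torch.split(encoder_output, lengths, dim=0)):
        p.multimodal_data["multimodal_embedding"] = chunk


# Example: two requests with 3 and 5 multimodal tokens, hidden size 8.
params = [_Param(_Runtime(3)), _Param(_Runtime(5))]
cache_embeddings_sketch(params, torch.randn(8, 8))
print([p.multimodal_data["multimodal_embedding"].shape for p in params])
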
@@ -117,29 +119,31 @@ def get_multimodal_embeddings(
     if not multimodal_params:
         return []
 
-    # Step 1: Find active multimodal params that need encoder processing
-    active_multimodal_params = _get_active_multimodal_params(multimodal_params)
+    # Step 1: Find uncached multimodal params that need encoder processing
+    uncached_multimodal_params = _get_uncached_multimodal_params(
+        multimodal_params)
 
     # Step 2: Run encoder forward only on uncached parameters
-    if active_multimodal_params:
-        encoder_outputs = encoder_forward_fn(active_multimodal_params)
+    if uncached_multimodal_params:
+        encoder_outputs = encoder_forward_fn(uncached_multimodal_params)
 
         # TODO: support multiple multimodal modalities per request
         if len(encoder_outputs) > 1:
             return encoder_outputs
 
         # Validate that multimodal_runtime has required attributes for caching
-        if (not hasattr(active_multimodal_params[0], 'multimodal_runtime')
-                or active_multimodal_params[0].multimodal_runtime is None or
-                active_multimodal_params[0].multimodal_runtime.total_mm_tokens
-                is None):
+        if (not hasattr(uncached_multimodal_params[0], 'multimodal_runtime')
+                or uncached_multimodal_params[0].multimodal_runtime is None
+                or uncached_multimodal_params[0].multimodal_runtime.
+                total_mm_tokens_in_request is None):
             logger.warning(
                 "Multimodal runtime data missing or incomplete - recomputed all embeddings"
             )
             return encoder_outputs
 
         # Step 3: Cache the computed embeddings to multimodal_data["multimodal_embedding"]
-        _cache_multimodal_embeddings(active_multimodal_params, encoder_outputs)
+        _cache_multimodal_embeddings(uncached_multimodal_params,
+                                     encoder_outputs)
 
     # Step 4: Gather all embeddings for the batch
     all_embeddings = torch.cat([
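
The hunk above is mostly a rename (active → uncached, total_mm_tokens → total_mm_tokens_in_request), but the four-step flow it annotates is the core of the chunked-prefill path. A condensed sketch of that flow follows, assuming encoder_forward_fn returns one embedding tensor per uncached request; the real code additionally validates multimodal_runtime and falls back to returning the raw encoder outputs when that metadata is missing. The parameter objects here are simple stand-ins, not the real MultimodalParams.

# Condensed sketch of the four-step flow, under the assumptions above.
from types import SimpleNamespace
from typing import Callable, List

import torch


def get_multimodal_embeddings_sketch(
        encoder_forward_fn: Callable[[list], List[torch.Tensor]],
        multimodal_params: list) -> List[torch.Tensor]:
    if not multimodal_params:
        return []
    # Step 1: only requests without a cached embedding need the encoder.
    uncached = [
        p for p in multimodal_params
        if "multimodal_embedding" not in p.multimodal_data
    ]
    if uncached:
        # Step 2: run the encoder once for the uncached requests.
        encoder_outputs = encoder_forward_fn(uncached)
        # Step 3: cache each result under multimodal_data["multimodal_embedding"].
        for p, emb in zip(uncached, encoder_outputs):
            p.multimodal_data["multimodal_embedding"] = emb
    # Step 4: gather embeddings for the whole batch, cached or freshly computed.
    return [p.multimodal_data["multimodal_embedding"] for p in multimodal_params]


# Tiny usage example with stand-in parameter objects (hypothetical):
# two requests, 4 multimodal tokens each, hidden size 8.
params = [SimpleNamespace(multimodal_data={}) for _ in range(2)]
embs = get_multimodal_embeddings_sketch(
    lambda ps: [torch.randn(4, 8) for _ in ps], params)
print(len(embs), embs[0].shape)
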
@@ -186,8 +190,10 @@ def find_input_mm_embeds(
         return mm_embeds
 
     # Calculate total tokens that need processing (both cached and current chunk)
-    total_mm_tokens = sum(
-        [param.multimodal_runtime.num_mm_tokens for param in multimodal_params])
+    total_mm_tokens = sum([
+        param.multimodal_runtime.num_mm_tokens_in_chunk
+        for param in multimodal_params
+    ])
 
     if total_mm_tokens == 0:
         # No tokens need processing, return empty list
@@ -203,11 +209,12 @@ def find_input_mm_embeds(
     slices = []
     for param in multimodal_params:
         runtime = param.multimodal_runtime
-        slices.append((current_pos + runtime.num_unseen_mm_tokens, current_pos +
-                       runtime.num_unseen_mm_tokens + runtime.num_mm_tokens))
+        slices.append(
+            (current_pos + runtime.num_unseen_mm_tokens, current_pos +
+             runtime.num_unseen_mm_tokens + runtime.num_mm_tokens_in_chunk))
         if len(mm_embeds
                ) == 1:  # pre-concatenated mm_embeds, need global offset
-            current_pos += runtime.total_mm_tokens
+            current_pos += runtime.total_mm_tokens_in_request
 
     sliced_mm_embeds = []
     if len(mm_embeds) == 1:
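
The slicing above keeps, for each request, only the multimodal tokens that land in the current prefill chunk: skip num_unseen_mm_tokens already handled by earlier chunks, take num_mm_tokens_in_chunk, and advance the global offset by total_mm_tokens_in_request when mm_embeds is pre-concatenated. A minimal sketch of that arithmetic for the pre-concatenated case; _Runtime is a hypothetical stand-in for multimodal_runtime, not the library class.

# Minimal sketch of the chunked-prefill slicing, under the assumptions above.
from dataclasses import dataclass
from typing import List

import torch


@dataclass
class _Runtime:
    num_unseen_mm_tokens: int        # mm tokens already consumed by earlier chunks
    num_mm_tokens_in_chunk: int      # mm tokens that fall inside the current chunk
    total_mm_tokens_in_request: int  # all mm tokens owned by the request


def slice_current_chunk(mm_embeds: torch.Tensor,
                        runtimes: List[_Runtime]) -> List[torch.Tensor]:
    # Walk the pre-concatenated embeddings request by request, keeping only the
    # span that belongs to the chunk currently being prefilled.
    slices, current_pos = [], 0
    for rt in runtimes:
        start = current_pos + rt.num_unseen_mm_tokens
        slices.append(mm_embeds[start:start + rt.num_mm_tokens_in_chunk])
        current_pos += rt.total_mm_tokens_in_request  # global offset per request
    return slices


# Example: request A owns 6 mm tokens (2 already seen, 3 in this chunk);
# request B owns 4 mm tokens (0 seen, all 4 in this chunk); hidden size 8.
embeds = torch.randn(10, 8)
out = slice_current_chunk(embeds, [_Runtime(2, 3, 6), _Runtime(0, 4, 4)])
print([t.shape for t in out])  # [torch.Size([3, 8]), torch.Size([4, 8])]
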