# and s2wrapper: https://github.com/bfshi/scaling_on_scales

import math
-from typing import List, Optional, Tuple, Callable, Dict, Any, Union
+from typing import List, Optional, Tuple

import torch
import torch.nn.functional as F


def _get_active_multimodal_params(
-    multimodal_params: List[MultimodalParams],
-) -> List[MultimodalParams]:
+    multimodal_params: List[MultimodalParams], ) -> List[MultimodalParams]:
    """
    Get active multimodal params that need encoder processing for chunk prefill.
    """
@@ -44,10 +43,12 @@ def _get_active_multimodal_params(
            continue

        # Check if embeddings are already cached
-        if (param.multimodal_data and
-            "multimodal_embedding" in param.multimodal_data and
-            param.multimodal_data["multimodal_embedding"] is not None):
-            logger.debug(f"Skipping encoder forward for param with cached multimodal_embedding")
+        if (param.multimodal_data
+                and "multimodal_embedding" in param.multimodal_data
+                and param.multimodal_data["multimodal_embedding"] is not None):
+            logger.debug(
+                f"Skipping encoder forward for param with cached multimodal_embedding"
+            )
            continue

        # This param needs encoder processing
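
# --- Illustrative sketch (not part of the diff): how the cached-embedding check
# --- above decides whether a request still needs the encoder. `FakeParam` is a
# --- hypothetical stand-in for MultimodalParams so the snippet runs on its own.
class FakeParam:

    def __init__(self, multimodal_data):
        self.multimodal_data = multimodal_data


cached = FakeParam({"multimodal_embedding": object()})  # filtered out (already cached)
uncached = FakeParam({})                                # needs encoder forward
missing = FakeParam(None)                               # needs encoder forward

for p in (cached, uncached, missing):
    needs_encoder = not (p.multimodal_data
                         and "multimodal_embedding" in p.multimodal_data
                         and p.multimodal_data["multimodal_embedding"] is not None)
    print(needs_encoder)  # False, True, True
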
@@ -65,11 +66,15 @@ def _cache_multimodal_embeddings(
    Uses torch.split for efficient tensor splitting without manual indexing.
    """
    # TODO: support multiple multimodal modalities per request
-    assert len(embeddings) == 1, "Currently only support single mm_embeds (single modality) per request"
+    assert len(
+        embeddings
+    ) == 1, "Currently only support single mm_embeds (single modality) per request"
    mm_embed = embeddings[0]

    # Collect embedding lengths for each parameter
-    embed_lengths = [param.multimodal_runtime.total_mm_tokens for param in multimodal_params]
+    embed_lengths = [
+        param.multimodal_runtime.total_mm_tokens for param in multimodal_params
+    ]

    # Validate total length matches
    total_expected = sum(embed_lengths)
@@ -83,7 +88,9 @@ def _cache_multimodal_embeddings(
    for param, embed_chunk in zip(multimodal_params, split_embeddings):
        param.multimodal_data["multimodal_embedding"] = embed_chunk

-    logger.debug(f"Cached {len(split_embeddings)} multimodal embeddings")
+    logger.debug(
+        f"Cached {len(split_embeddings)} multimodal embeddings"
+    )


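# --- Illustrative sketch (not part of the diff): the torch.split pattern that
# --- _cache_multimodal_embeddings relies on. The lengths below are made-up
# --- numbers; in the real code they come from
# --- param.multimodal_runtime.total_mm_tokens.
import torch

embed_lengths = [3, 5, 2]                      # assumed mm-token counts per request
mm_embed = torch.randn(sum(embed_lengths), 8)  # one concatenated encoder output
split_embeddings = torch.split(mm_embed, embed_lengths, dim=0)
assert [chunk.shape[0] for chunk in split_embeddings] == embed_lengths
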
def get_multimodal_embeddings(
@@ -111,9 +118,7 @@ def get_multimodal_embeddings(
        return []

    # Step 1: Find active multimodal params that need encoder processing
-    active_multimodal_params = _get_active_multimodal_params(
-        multimodal_params
-    )
+    active_multimodal_params = _get_active_multimodal_params(multimodal_params)

    # Step 2: Run encoder forward only on uncached parameters
    if active_multimodal_params:
@@ -124,19 +129,24 @@ def get_multimodal_embeddings(
            return encoder_outputs

        # Validate that multimodal_runtime has required attributes for caching
-        if (not hasattr(active_multimodal_params[0], 'multimodal_runtime') or
-            active_multimodal_params[0].multimodal_runtime is None or
-            active_multimodal_params[0].multimodal_runtime.total_mm_tokens is None):
-            logger.warning("Multimodal runtime data missing or incomplete - recomputed all embeddings")
+        if (not hasattr(active_multimodal_params[0], 'multimodal_runtime')
+                or active_multimodal_params[0].multimodal_runtime is None or
+                active_multimodal_params[0].multimodal_runtime.total_mm_tokens
+                is None):
+            logger.warning(
+                "Multimodal runtime data missing or incomplete - recomputed all embeddings"
+            )
            return encoder_outputs

        # Step 3: Cache the computed embeddings to multimodal_data["multimodal_embedding"]
-        _cache_multimodal_embeddings(
-            active_multimodal_params, encoder_outputs
-        )
+        _cache_multimodal_embeddings(active_multimodal_params, encoder_outputs)

    # Step 4: Gather all embeddings for the batch
-    all_embeddings = torch.cat([param.multimodal_data["multimodal_embedding"] for param in multimodal_params], dim=0)
+    all_embeddings = torch.cat([
+        param.multimodal_data["multimodal_embedding"]
+        for param in multimodal_params
+    ],
+                               dim=0)
    return [all_embeddings]


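# --- Illustrative sketch (not part of the diff): the cache-then-gather flow of
# --- get_multimodal_embeddings, using plain dicts in place of MultimodalParams
# --- and a made-up `toy_encoder` standing in for the real encoder forward.
import torch


def toy_encoder(params):
    # pretend each uncached request yields 4 embedding tokens of width 8
    return [torch.randn(4 * len(params), 8)]


params = [
    {"multimodal_data": {"multimodal_embedding": torch.randn(4, 8)}},  # cached
    {"multimodal_data": {}},                                           # uncached
]

active = [p for p in params if "multimodal_embedding" not in p["multimodal_data"]]
if active:
    encoder_out = toy_encoder(active)[0]
    chunks = torch.split(encoder_out, [4] * len(active), dim=0)
    for p, chunk in zip(active, chunks):
        p["multimodal_data"]["multimodal_embedding"] = chunk  # Step 3: cache

all_embeddings = torch.cat(
    [p["multimodal_data"]["multimodal_embedding"] for p in params], dim=0)  # Step 4
print(all_embeddings.shape)  # torch.Size([8, 8])
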
@@ -176,28 +186,27 @@ def find_input_mm_embeds(
        return mm_embeds

    # Calculate total tokens that need processing (both cached and current chunk)
-    total_mm_tokens = sum([
-        param.multimodal_runtime.num_mm_tokens
-        for param in multimodal_params
-    ])
+    total_mm_tokens = sum(
+        [param.multimodal_runtime.num_mm_tokens for param in multimodal_params])

    if total_mm_tokens == 0:
        # No tokens need processing, return empty list
        logger.debug(
-            "All multimodal tokens are cached or beyond current chunk, skipping vision encoder forward")
+            "All multimodal tokens are cached or beyond current chunk, skipping vision encoder forward"
+        )
        return []

    if total_mm_tokens == sum(mm_embed.shape[0] for mm_embed in mm_embeds):
        return mm_embeds

-
    current_pos = 0
    slices = []
    for param in multimodal_params:
        runtime = param.multimodal_runtime
-        slices.append((current_pos + runtime.num_unseen_mm_tokens,
-                       current_pos + runtime.num_unseen_mm_tokens + runtime.num_mm_tokens))
-        if len(mm_embeds) == 1:  # pre-concatenated mm_embeds, need global offset
+        slices.append((current_pos + runtime.num_unseen_mm_tokens, current_pos +
+                       runtime.num_unseen_mm_tokens + runtime.num_mm_tokens))
+        if len(mm_embeds
+               ) == 1:  # pre-concatenated mm_embeds, need global offset
            current_pos += runtime.total_mm_tokens

    sliced_mm_embeds = []
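
# --- Illustrative sketch (not part of the diff): the chunked-prefill slice
# --- arithmetic used by find_input_mm_embeds. `Runtime` is a hypothetical
# --- stand-in carrying the counters the real multimodal_runtime exposes here
# --- (num_unseen_mm_tokens, num_mm_tokens, total_mm_tokens).
from collections import namedtuple

Runtime = namedtuple(
    "Runtime", ["num_unseen_mm_tokens", "num_mm_tokens", "total_mm_tokens"])

# request 0: 2 of its 6 mm tokens already cached, 3 fall into this chunk
# request 1: nothing cached yet, all 4 mm tokens fall into this chunk
runtimes = [Runtime(2, 3, 6), Runtime(0, 4, 4)]

current_pos = 0
slices = []
for rt in runtimes:
    slices.append((current_pos + rt.num_unseen_mm_tokens,
                   current_pos + rt.num_unseen_mm_tokens + rt.num_mm_tokens))
    current_pos += rt.total_mm_tokens  # pre-concatenated case: global offset

print(slices)  # [(2, 5), (6, 10)] -> rows to slice out of the concatenated mm_embeds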