@@ -180,11 +180,14 @@ def _get_dummy_mm_inputs(
def _get_mm_num_tokens(
    self,
    mm_inputs: MultiModalInputs,
    mm_embeddings_only: bool = True,
) -> Mapping[str, int]:
    """Return the number of placeholder tokens for each modality.

    Args:
        mm_inputs: Processed multi-modal inputs; the ``"mm_placeholders"``
            entry maps each modality to its placeholder ranges.
        mm_embeddings_only: When ``True``, count only the positions that
            are actually filled by multi-modal embeddings
            (``item.get_num_embeds()``); otherwise count the full
            placeholder span (``item.length``).
    """
    placeholder_map = mm_inputs["mm_placeholders"]

    # Pick the counting rule once, based on the flag.
    def _token_count(item) -> int:
        if mm_embeddings_only:
            return item.get_num_embeds()
        return item.length

    totals: dict = {}
    for modality, items in placeholder_map.items():
        totals[modality] = sum(_token_count(item) for item in items)
    return totals
190193
@@ -257,6 +260,7 @@ def get_mm_max_tokens(
257260 self ,
258261 seq_len : int ,
259262 mm_counts : Optional [Mapping [str , int ]] = None ,
263+ mm_embeddings_only : bool = True ,
260264 ) -> Mapping [str , int ]:
261265 if mm_counts is None :
262266 mm_counts = self .get_mm_limits ()
@@ -285,4 +289,14 @@ def get_mm_max_tokens(
285289 return max_tokens_per_item
286290
287291 mm_inputs = self ._get_dummy_mm_inputs (seq_len , mm_counts )
288- return self ._get_mm_num_tokens (mm_inputs )
292+ return self ._get_mm_num_tokens (mm_inputs ,
293+ mm_embeddings_only = mm_embeddings_only )
294+
def get_max_placeholder_tokens(
    self,
    seq_len: int,
    mm_counts: Optional[Mapping[str, int]] = None,
) -> Mapping[str, int]:
    """Return the maximum number of placeholder tokens per modality.

    Unlike ``get_mm_max_tokens`` with its default arguments, this counts
    the full placeholder span (``mm_embeddings_only=False``) — every
    placeholder position, not only those filled by embeddings.

    Args:
        seq_len: Sequence length used when building the dummy inputs.
        mm_counts: Optional per-modality item counts; when ``None`` the
            processor's multi-modal limits are used.
    """
    # Delegate to get_mm_max_tokens, forcing full-span counting.
    return self.get_mm_max_tokens(seq_len,
                                  mm_counts,
                                  mm_embeddings_only=False)
0 commit comments