Skip to content

Commit 8da6645

Browse files
committed
chore: Define new get_max_placeholder_tokens
Signed-off-by: Alexandre Milesi <[email protected]>
1 parent 1912554 commit 8da6645

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

vllm/multimodal/profiling.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,11 +180,14 @@ def _get_dummy_mm_inputs(
180180
def _get_mm_num_tokens(
181181
self,
182182
mm_inputs: MultiModalInputs,
183+
mm_embeddings_only: bool = True,
183184
) -> Mapping[str, int]:
184185
placeholders_by_modality = mm_inputs["mm_placeholders"]
185186

186187
return {
187-
modality: sum(item.length for item in placeholders)
188+
modality:
189+
sum(item.get_num_embeds() if mm_embeddings_only else item.length
190+
for item in placeholders)
188191
for modality, placeholders in placeholders_by_modality.items()
189192
}
190193

@@ -257,6 +260,7 @@ def get_mm_max_tokens(
257260
self,
258261
seq_len: int,
259262
mm_counts: Optional[Mapping[str, int]] = None,
263+
mm_embeddings_only: bool = True,
260264
) -> Mapping[str, int]:
261265
if mm_counts is None:
262266
mm_counts = self.get_mm_limits()
@@ -285,4 +289,14 @@ def get_mm_max_tokens(
285289
return max_tokens_per_item
286290

287291
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
288-
return self._get_mm_num_tokens(mm_inputs)
292+
return self._get_mm_num_tokens(mm_inputs,
293+
mm_embeddings_only=mm_embeddings_only)
294+
295+
def get_max_placeholder_tokens(
296+
self,
297+
seq_len: int,
298+
mm_counts: Optional[Mapping[str, int]] = None,
299+
):
300+
return self.get_mm_max_tokens(seq_len,
301+
mm_counts,
302+
mm_embeddings_only=False)

vllm/multimodal/registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def get_max_tokens_per_item_by_modality(
129129
seq_len = model_config.max_model_len
130130
mm_limits = self.get_mm_limits_per_prompt(model_config)
131131

132-
return profiler.get_mm_max_tokens(
132+
return profiler.get_max_placeholder_tokens(
133133
seq_len,
134134
{
135135
modality: 1

0 commit comments

Comments
 (0)