Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions vllm/multimodal/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,11 +180,14 @@ def _get_dummy_mm_inputs(
def _get_mm_num_tokens(
self,
mm_inputs: MultiModalInputs,
mm_embeddings_only: bool = True,
) -> Mapping[str, int]:
placeholders_by_modality = mm_inputs["mm_placeholders"]

return {
modality: sum(item.get_num_embeds() for item in placeholders)
modality:
sum(item.get_num_embeds() if mm_embeddings_only else item.length
for item in placeholders)
for modality, placeholders in placeholders_by_modality.items()
}

Expand Down Expand Up @@ -253,10 +256,11 @@ def get_decoder_dummy_data(
multi_modal_placeholders=mm_inputs["mm_placeholders"],
)

def get_mm_max_tokens(
def _get_mm_max_tokens(
self,
seq_len: int,
mm_counts: Optional[Mapping[str, int]] = None,
mm_embeddings_only: bool = True,
) -> Mapping[str, int]:
if mm_counts is None:
mm_counts = self.get_mm_limits()
Expand Down Expand Up @@ -285,4 +289,25 @@ def get_mm_max_tokens(
return max_tokens_per_item

mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
return self._get_mm_num_tokens(mm_inputs)
return self._get_mm_num_tokens(mm_inputs,
mm_embeddings_only=mm_embeddings_only)

def get_mm_max_contiguous_tokens(
    self,
    seq_len: int,
    mm_counts: Optional[Mapping[str, int]] = None,
):
    """
    Return the maximum contiguous length of multimodal tokens per
    modality, counting image placeholder tokens together with any
    break/text tokens interleaved between image embeddings.

    <im_start> [IMG] [IMG] [IMG] <row_break> [IMG] [IMG] [IMG] <im_end>
    Returns 9, even when the number of image embeddings is 6.

    This is important to take into account when profiling and
    initializing the encoder cache size.
    """
    # Counting full placeholder lengths (not just embeddings) is what
    # distinguishes this from the embeddings-only maximum.
    return self._get_mm_max_tokens(
        seq_len,
        mm_counts,
        mm_embeddings_only=False,
    )
2 changes: 1 addition & 1 deletion vllm/multimodal/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def get_max_tokens_per_item_by_modality(
seq_len = model_config.max_model_len
mm_limits = self.get_mm_limits_per_prompt(model_config)

return profiler.get_mm_max_tokens(
return profiler.get_mm_max_contiguous_tokens(
seq_len,
{
modality: 1
Expand Down