6 changes: 3 additions & 3 deletions docs/source/reference/multimodal-feature-support-matrix.md
@@ -5,9 +5,9 @@
| Gemma 3 | Yes | Yes | N/A | N/A |
| HyperCLOVA | Yes | Yes | No | No |
| VILA | Yes | No | No | No |
| LLaVA-NeXT | Yes | Yes | Yes | No |
| LLaVA-NeXT | Yes | Yes | Yes | Yes |
| Llama 4 | Yes | Yes | No | No |
| Mistral-Small-3.1 | Yes | Yes | No | No |
| Phi-4-multimodal | Yes | Yes | No | No |
| Qwen2-VL | Yes | Yes | Yes | No |
| Qwen2.5-VL | Yes | Yes | Yes | No |
| Qwen2-VL | Yes | Yes | Yes | Yes |
| Qwen2.5-VL | Yes | Yes | Yes | Yes |
27 changes: 11 additions & 16 deletions tensorrt_llm/_torch/models/modeling_llava_next.py
@@ -25,8 +25,8 @@
from ..model_config import ModelConfig
from .modeling_auto import AutoModelForCausalLM
from .modeling_clip import CLIPVisionModel
from .modeling_multimodal_utils import (find_uncached_mm_embeds,
fuse_input_embeds)
from .modeling_multimodal_utils import (find_input_mm_embeds, fuse_input_embeds,
get_multimodal_embeddings)
from .modeling_utils import (filter_weights, register_auto_model,
register_vision_encoder)

@@ -462,21 +462,16 @@ def forward(
mm_embeds = []
if len(multimodal_params) > 0:
if not DISAGG:
if multimodal_params[0].multimodal_data.get(
"multimodal_embedding", None) is not None:
mm_embeds = [
multimodal_param.multimodal_data["multimodal_embedding"]
for multimodal_param in multimodal_params
]
else:
mm_embeds = self.mm_encoder.forward(multimodal_params)
mm_embeds = find_uncached_mm_embeds(
mm_embeds, multimodal_params[:num_context_requests])
mm_embeds = get_multimodal_embeddings(
encoder_forward_fn=self.mm_encoder.forward,
multimodal_params=multimodal_params[:num_context_requests])
else:
mm_embeds = [
multimodal_param.multimodal_data["multimodal_embedding"]
for multimodal_param in multimodal_params
]
raise NotImplementedError(
"LlavaNextModel does not support disaggregated inference yet. Please unset "
f"the TLLM_MULTIMODAL_DISAGGREGATED environment variable, or set it to '0'."
)
mm_embeds = find_input_mm_embeds(
mm_embeds, multimodal_params[:num_context_requests])
input_ids, inputs_embeds = fuse_input_embeds(
self.llm.model.embed_tokens, input_ids, mm_embeds, **kwargs)

192 changes: 160 additions & 32 deletions tensorrt_llm/_torch/models/modeling_multimodal_utils.py
@@ -30,19 +30,153 @@
from tensorrt_llm.logger import logger


def find_uncached_mm_embeds(
def _get_uncached_multimodal_params(
multimodal_params: List[MultimodalParams], ) -> List[MultimodalParams]:
"""
Get the uncached multimodal params that still need encoder processing (e.g., during chunked prefill).
"""
params_to_run = []

for param in multimodal_params:
# Skip if no multimodal content
if not param.has_content():
continue

# Check if embeddings are already cached
if (param.multimodal_data
and "multimodal_embedding" in param.multimodal_data
and param.multimodal_data["multimodal_embedding"] is not None):
logger.debug(
f"Skipping encoder forward for param with cached multimodal_embedding"
)
continue

# This param needs encoder processing
params_to_run.append(param)

return params_to_run


def _cache_multimodal_embeddings(
multimodal_params: List[MultimodalParams],
embeddings: List[torch.Tensor],
) -> None:
"""
Cache computed multimodal embeddings back to multimodal_data to avoid recomputation.
Note that this function only caches multimodal embeddings within the current request context
(primarily for chunked prefill); it does not persist embeddings across different requests or sessions.
"""
# TODO: support multiple multimodal modalities per request
assert len(
embeddings
) == 1, "Currently only a single mm_embeds tensor (single modality) per request is supported"
mm_embed = embeddings[0]

# Collect embedding lengths for each parameter
embed_lengths = [
param.multimodal_runtime.total_mm_tokens_in_request -
param.multimodal_runtime.total_special_tokens_in_request
for param in multimodal_params
]

# Validate total length matches
total_expected = sum(embed_lengths)
assert len(mm_embed) == total_expected, \
f"Number of mm_embeds ({len(mm_embed)}) does not match expected total ({total_expected})"

# Use torch.split for efficient tensor splitting
split_embeddings = torch.split(mm_embed, embed_lengths, dim=0)

# Cache split embeddings to each parameter
for param, embed_chunk in zip(multimodal_params, split_embeddings):
param.multimodal_data["multimodal_embedding"] = embed_chunk

logger.debug(
f"Cached {len(split_embeddings)} multimodal embedding chunks in this iteration"
)
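A minimal, self-contained illustration of the splitting step above, using made-up sizes; it only shows the torch.split / per-request assignment semantics, not the real MultimodalParams objects:

import torch

# Two requests contributed 4 and 2 image tokens (after excluding special
# tokens) to one concatenated embedding tensor of shape [6, hidden].
mm_embed = torch.randn(6, 8)
embed_lengths = [4, 2]

# torch.split returns one chunk per request, in order, along dim 0.
chunks = torch.split(mm_embed, embed_lengths, dim=0)
assert chunks[0].shape == (4, 8) and chunks[1].shape == (2, 8)
# Each chunk would then be stored as that request's
# multimodal_data["multimodal_embedding"].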


def get_multimodal_embeddings(
encoder_forward_fn,
multimodal_params: List[MultimodalParams],
) -> List[torch.Tensor]:
"""
High-level utility that returns multimodal embeddings, either by running the encoder or by reusing cached embeddings.

This function will:
1. Identify which parameters need encoder processing
2. Run encoder forward only on uncached parameters
3. Cache newly computed embeddings (when runtime metadata is available)
4. Gather all embeddings for the batch

Args:
encoder_forward_fn: Callable that performs encoder forward pass
Should accept List[MultimodalParams] and return List[torch.Tensor]
multimodal_params: All multimodal parameters in the batch

Returns:
List of multimodal embeddings for all multimodal params in the batch
"""
if not multimodal_params:
return []

# Step 1: Find uncached multimodal params that need encoder processing
uncached_multimodal_params = _get_uncached_multimodal_params(
multimodal_params)

# Step 2: Run encoder forward only on uncached parameters
if uncached_multimodal_params:
encoder_outputs = encoder_forward_fn(uncached_multimodal_params)

# TODO: support multiple multimodal modalities per request
if len(encoder_outputs) > 1:
return encoder_outputs

# Validate that multimodal_runtime has required attributes for caching
if (not hasattr(uncached_multimodal_params[0], 'multimodal_runtime')
or uncached_multimodal_params[0].multimodal_runtime is None
or uncached_multimodal_params[0].multimodal_runtime.
total_mm_tokens_in_request is None):
logger.warning(
"Multimodal runtime data missing or incomplete - recomputed all embeddings"
)
return encoder_outputs

# Step 3: Cache the computed embeddings to multimodal_data["multimodal_embedding"]
_cache_multimodal_embeddings(uncached_multimodal_params,
encoder_outputs)

# Step 4: Gather all embeddings for the batch
all_embeddings = torch.cat([
param.multimodal_data["multimodal_embedding"]
for param in multimodal_params
],
dim=0)
return [all_embeddings]
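A hedged usage sketch of the caching behavior above across two chunked-prefill steps. The parameter and runtime objects are simplified stand-ins for MultimodalParams and its runtime data (the real classes live elsewhere in the codebase); only the interaction with get_multimodal_embeddings is illustrated:

import torch

from tensorrt_llm._torch.models.modeling_multimodal_utils import \
    get_multimodal_embeddings


class _FakeRuntime:
    # Assumed fields, mirroring the attributes read by the caching helpers.
    total_mm_tokens_in_request = 6
    total_special_tokens_in_request = 0


class _FakeParam:
    # Stand-in exposing only what get_multimodal_embeddings touches.
    def __init__(self):
        self.multimodal_data = {}
        self.multimodal_runtime = _FakeRuntime()

    def has_content(self):
        return True


def fake_encoder(params):
    # Pretend the vision encoder produced one [6, hidden] tensor for the batch.
    return [torch.randn(6, 16)]


def must_not_run(params):
    raise AssertionError("encoder should not be called on the second chunk")


param = _FakeParam()
# First chunk: nothing cached, so fake_encoder runs and the result is stored
# under multimodal_data["multimodal_embedding"].
out1 = get_multimodal_embeddings(fake_encoder, [param])
# Second chunk: the cached embedding is reused and the encoder is skipped.
out2 = get_multimodal_embeddings(must_not_run, [param])
assert torch.equal(out1[0], out2[0])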


def find_input_mm_embeds(
mm_embeds: List[torch.Tensor],
multimodal_params: List[MultimodalParams]) -> torch.Tensor:
multimodal_params: List[MultimodalParams]) -> List[torch.Tensor]:
"""
Find the uncached multimodal mm_embeds from multimodal_params for each batch.
Find the multimodal mm_embeds that still need processing, based on each request's multimodal runtime data.
Supports both KV cache reuse and chunked prefill scenarios.

Args:
- mm_embeds: List[torch.Tensor]
- multimodal_params: List[MultimodalParams]
- mm_embeds: List[torch.Tensor] - Multimodal embeddings for each batch
- multimodal_params: List[MultimodalParams] - Multimodal parameters with runtime data

Returns:
- sliced_mm_embeds: List[torch.Tensor]
When kv_cache reuse is disabled or model not enabled/support kv_cache reuse, return the full mm_embeds.
- List[torch.Tensor] - Sliced mm_embeds containing only tokens that need processing:
- For KV cache reuse: tokens that are not cached
- For chunked prefill: tokens that are in the current chunk
- For mixed scenarios: both uncached and current chunk tokens
- Empty list if all tokens are cached or beyond current chunk

Note:
- Current implementation assumes chunk prefill is disabled. To support chunk prefill, we might need to slightly modify the logic (see TODO below).
- Supports both individual batching (len(mm_embeds) == len(multimodal_params))
and pre-concatenated batching (len(mm_embeds) == 1)
- Handles chunked prefill by considering chunk boundaries and current chunk tokens
"""
# Currently supports two batching modes:
# 1. Pre-concatenated mm_embeds for each batch, i.e., len(mm_embeds) == 1
@@ -56,52 +190,46 @@ def find_uncached_mm_embeds(
# No slicing, return the full mm_embeds
return mm_embeds

total_cached_mm_tokens = sum([
param.multimodal_runtime.num_cached_mm_tokens
# Total multimodal tokens that need processing in the current chunk (excluding special tokens)
total_mm_tokens = sum([
param.multimodal_runtime.num_mm_tokens_in_chunk -
param.multimodal_runtime.num_special_tokens_in_chunk
for param in multimodal_params
])
if total_cached_mm_tokens == 0:
# No cached tokens, return the full mm_embeds
# TODO: support chunk prefill for multimodal, then we need to extract full mm_embeds for each CHUNK
logger.debug(
"No multimodal cached tokens can be reused, return the full mm_embeds"
)
return mm_embeds

if total_cached_mm_tokens == sum([
param.multimodal_runtime.total_mm_tokens
for param in multimodal_params
]):
# All tokens are cached, return empty list
if total_mm_tokens == 0:
# No tokens need processing, return empty list
logger.debug(
"All multimodal tokens cached, skipping vision encoder forward")
"All multimodal tokens are cached or beyond current chunk, skipping vision encoder forward"
)
return []

# Partial caching, return the sliced mm_embeds
if total_mm_tokens == sum(mm_embed.shape[0] for mm_embed in mm_embeds):
return mm_embeds

current_pos = 0
slices = []
for param in multimodal_params:
runtime = param.multimodal_runtime
slices.append((current_pos + runtime.num_cached_mm_tokens,
current_pos + runtime.total_mm_tokens))
local_start_pos = runtime.num_unseen_mm_tokens - runtime.num_unseen_special_tokens
local_end_pos = local_start_pos + runtime.num_mm_tokens_in_chunk - runtime.num_special_tokens_in_chunk
slices.append(
(current_pos + local_start_pos, current_pos + local_end_pos))
if len(mm_embeds
) == 1: # pre-concatenated mm_embeds, need global offset
current_pos += runtime.total_mm_tokens
current_pos += runtime.total_mm_tokens_in_request
current_pos -= runtime.total_special_tokens_in_request

sliced_mm_embeds = []
if len(mm_embeds) == 1:
for start, end in slices:
sliced_mm_embeds.append(mm_embeds[0][start:end])
sliced_mm_embeds = [mm_embeds[0][start:end] for start, end in slices]
else: # slice each mm_embeds individually
for i, (start, end) in enumerate(slices):
sliced_mm_embeds.append(mm_embeds[i][start:end])

if len(mm_embeds) == 1:
sliced_mm_embeds = [torch.cat(sliced_mm_embeds, dim=0)]

logger.debug(
f"Partial caching, return sliced_mm_embeds: {sliced_mm_embeds[0].shape}"
)
return sliced_mm_embeds
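A worked example of the slicing arithmetic above, with illustrative numbers rather than real runtime objects; it reproduces only the index math, not a call to find_input_mm_embeds:

import torch

# Two requests share one pre-concatenated mm_embeds tensor of 8 rows
# (4 image tokens each, no special tokens), so len(mm_embeds) == 1.
mm_embeds = [torch.arange(8, dtype=torch.float32).unsqueeze(1)]  # shape [8, 1]

# Request 0: 2 of its 4 mm tokens were already seen in earlier chunks and
# 2 fall inside the current chunk -> local slice [2, 4).
# Request 1: nothing seen yet and 1 token is in the current chunk -> local
# slice [0, 1), shifted by request 0's 4 total mm tokens -> global [4, 5).
slices = [(2, 4), (4, 5)]
picked = torch.cat([mm_embeds[0][start:end] for start, end in slices], dim=0)
assert picked.squeeze(1).tolist() == [2.0, 3.0, 4.0]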


21 changes: 11 additions & 10 deletions tensorrt_llm/_torch/models/modeling_qwen2vl.py
@@ -21,8 +21,8 @@
from ..attention_backend import AttentionMetadata
from ..model_config import ModelConfig
from .modeling_auto import AutoModelForCausalLM
from .modeling_multimodal_utils import (find_uncached_mm_embeds,
fuse_input_embeds)
from .modeling_multimodal_utils import (find_input_mm_embeds, fuse_input_embeds,
get_multimodal_embeddings)
from .modeling_utils import register_auto_model, register_vision_encoder

DISAGG = os.getenv('TLLM_MULTIMODAL_DISAGGREGATED', '0') == '1'
@@ -280,7 +280,6 @@ def __call__(
mm_processor_kwargs)
if not mm_data:
fused_input_ids = processed_inputs['input_ids']
# Flatten the tensor to get a simple list of integers
return fused_input_ids.flatten().to(torch.int32).tolist(), {}

pixel_values = processed_inputs.get('pixel_values', None)
@@ -594,17 +593,19 @@ def forward(

if len(multimodal_params) > 0:
if not DISAGG:
mm_embeds = self.mm_encoder.forward(
multimodal_params[:num_context_requests])
mm_embeds = get_multimodal_embeddings(
encoder_forward_fn=self.mm_encoder.forward,
multimodal_params=multimodal_params[:num_context_requests])
else:
mm_embeds = [
multimodal_param.multimodal_data["multimodal_embedding"]
for multimodal_param in multimodal_params
]
raise NotImplementedError(
"Qwen2VLModel does not support disaggregated inference yet. Please unset "
f"the TLLM_MULTIMODAL_DISAGGREGATED environment variable, or set it to '0'."
)
mrope_config = self._parse_and_concat_mrope_config(
multimodal_params, num_context_requests,
num_generation_requests)
mm_embeds = find_uncached_mm_embeds(

mm_embeds = find_input_mm_embeds(
mm_embeds, multimodal_params[:num_context_requests])

if 'mrope_position_deltas' in kwargs:
9 changes: 4 additions & 5 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -1325,11 +1325,13 @@ def _prepare_tp_inputs(
num_cached_tokens_per_seq.append(past_seen_token_num)

# Multimodal
# TODO: enable chunk prefill for multimodal (maybe need to pass prompt_tokens to MultimodalRuntimeData)
py_multimodal_runtime = MultimodalRuntimeData(
mm_token_lengths=request.multimodal_lengths,
mm_token_positions=request.multimodal_positions,
num_cached_tokens=past_seen_token_num
past_seen_token_num=past_seen_token_num,
chunk_end_pos=end_compute,
special_token_offsets=request.py_multimodal_data.get(
'special_token_offsets', []),
) if request.multimodal_hashes is not None else None

multimodal_params = MultimodalParams(
@@ -1348,9 +1350,6 @@

if len(multimodal_params_list) > 0:
# discard the text token indices as it only includes context tokens at this moment
print(
f"len multimodal_params_list: {len(multimodal_params_list)} from model_engine"
)
_, mm_token_indices = self._prepare_multimodal_indices(input_ids)
else:
mm_token_indices = None
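A hedged sketch of the bookkeeping carried by the new keyword arguments in the constructor call above, with made-up numbers. The argument names match the diff, but the exact field semantics and the import path of MultimodalRuntimeData are assumptions based on this change alone:

# Assumed import path for the runtime-data container used in the call above.
from tensorrt_llm.inputs.multimodal import MultimodalRuntimeData

# Illustrative values; in _prepare_tp_inputs these come from request/scheduler state.
past_seen_token_num = 128   # prompt tokens already covered by KV cache / earlier chunks
end_compute = 256           # exclusive end position of the current prefill chunk

py_multimodal_runtime = MultimodalRuntimeData(
    mm_token_lengths=[64],       # one image contributing 64 multimodal tokens
    mm_token_positions=[100],    # those tokens start at prompt position 100
    past_seen_token_num=past_seen_token_num,
    chunk_end_pos=end_compute,
    special_token_offsets=[],    # no special placeholder tokens in this sketch
)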