
Commit 00d6692

Enable chunked prefill for qwen2vl
Signed-off-by: Chang Liu (Enterprise Products) <[email protected]>
1 parent 250f92a commit 00d6692

File tree

4 files changed: +383 -17 lines changed

tensorrt_llm/_torch/models/modeling_multimodal_utils.py

Lines changed: 113 additions & 8 deletions
@@ -17,7 +17,7 @@
 # and s2wrapper: https://github.com/bfshi/scaling_on_scales
 
 import math
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Callable, Dict, Any, Union
 
 import torch
 import torch.nn.functional as F
@@ -30,6 +30,116 @@
 from tensorrt_llm.logger import logger
 
 
+def _get_active_multimodal_params(
+    multimodal_params: List[MultimodalParams],
+) -> List[MultimodalParams]:
+    """
+    Get active multimodal params that need encoder processing for chunked prefill.
+    """
+    params_to_run = []
+
+    for param in multimodal_params:
+        # Skip if no multimodal content
+        if not param.has_content():
+            continue
+
+        # Check if embeddings are already cached
+        if (param.multimodal_data and
+                "multimodal_embedding" in param.multimodal_data and
+                param.multimodal_data["multimodal_embedding"] is not None):
+            logger.debug(f"Skipping encoder forward for param with cached multimodal_embedding")
+            continue
+
+        # This param needs encoder processing
+        params_to_run.append(param)
+
+    return params_to_run
+
+
+def _cache_multimodal_embeddings(
+    multimodal_params: List[MultimodalParams],
+    embeddings: List[torch.Tensor],
+) -> None:
+    """
+    Cache computed multimodal embeddings back to multimodal_data to avoid recomputation.
+    Uses torch.split for efficient tensor splitting without manual indexing.
+    """
+    # TODO: support multiple multimodal modalities per request
+    assert len(embeddings) == 1, "Currently only support single mm_embeds (single modality) per request"
+    mm_embed = embeddings[0]
+
+    # Collect embedding lengths for each parameter
+    embed_lengths = [param.multimodal_runtime.total_mm_tokens for param in multimodal_params]
+
+    # Validate total length matches
+    total_expected = sum(embed_lengths)
+    assert len(mm_embed) == total_expected, \
+        f"Number of mm_embeds ({len(mm_embed)}) does not match expected total ({total_expected})"
+
+    # Use torch.split for efficient tensor splitting
+    split_embeddings = torch.split(mm_embed, embed_lengths, dim=0)
+
+    # Cache split embeddings to each parameter
+    for param, embed_chunk in zip(multimodal_params, split_embeddings):
+        param.multimodal_data["multimodal_embedding"] = embed_chunk
+
+    logger.debug(f"Cached {len(split_embeddings)} multimodal embedding chunks in this iteration")
+
+
+def get_multimodal_embeddings(
+    encoder_forward_fn,
+    multimodal_params: List[MultimodalParams],
+) -> List[torch.Tensor]:
+    """
+    High-level utility to get multimodal embeddings from encoder or cached embeddings.
+
+    This function will:
+    1. Identify which parameters need encoder processing
+    2. Run encoder forward only on uncached parameters
+    3. Cache newly computed embeddings (if enabled)
+    4. Gather all embeddings for the batch
+
+    Args:
+        encoder_forward_fn: Callable that performs encoder forward pass
+            Should accept List[MultimodalParams] and return List[torch.Tensor]
+        multimodal_params: All multimodal parameters in the batch
+
+    Returns:
+        List of multimodal embeddings for all multimodal params in the batch
+    """
+    if not multimodal_params:
+        return []
+
+    # Step 1: Find active multimodal params that need encoder processing
+    active_multimodal_params = _get_active_multimodal_params(
+        multimodal_params
+    )
+
+    # Step 2: Run encoder forward only on uncached parameters
+    if active_multimodal_params:
+        encoder_outputs = encoder_forward_fn(active_multimodal_params)
+
+        # TODO: support multiple multimodal modalities per request
+        if len(encoder_outputs) > 1:
+            return encoder_outputs
+
+        # Validate that multimodal_runtime has required attributes for caching
+        if (not hasattr(active_multimodal_params[0], 'multimodal_runtime') or
+                active_multimodal_params[0].multimodal_runtime is None or
+                active_multimodal_params[0].multimodal_runtime.total_mm_tokens is None):
+            logger.warning("Multimodal runtime data missing or incomplete - recomputing all embeddings")
+            return encoder_outputs
+
+        # Step 3: Cache the computed embeddings to multimodal_data["multimodal_embedding"]
+        _cache_multimodal_embeddings(
+            active_multimodal_params, encoder_outputs
+        )
+
+    # Step 4: Gather all embeddings for the batch
+    all_embeddings = torch.cat([param.multimodal_data["multimodal_embedding"] for param in multimodal_params], dim=0)
+    return [all_embeddings]
+
+
 def find_input_mm_embeds(
     mm_embeds: List[torch.Tensor],
     multimodal_params: List[MultimodalParams]) -> List[torch.Tensor]:
@@ -66,10 +176,6 @@ def find_input_mm_embeds
         return mm_embeds
 
     # Calculate total tokens that need processing (both cached and current chunk)
-    total_unseen_mm_tokens = sum([
-        param.multimodal_runtime.num_unseen_mm_tokens
-        for param in multimodal_params
-    ])
     total_mm_tokens = sum([
         param.multimodal_runtime.num_mm_tokens
        for param in multimodal_params
@@ -92,12 +198,11 @@ def find_input_mm_embeds
         slices.append((current_pos + runtime.num_unseen_mm_tokens,
                        current_pos + runtime.num_unseen_mm_tokens + runtime.num_mm_tokens))
         if len(mm_embeds) == 1:  # pre-concatenated mm_embeds, need global offset
-            current_pos += sum(runtime.mm_token_lengths)
+            current_pos += runtime.total_mm_tokens
 
     sliced_mm_embeds = []
     if len(mm_embeds) == 1:
-        for start, end in slices:
-            sliced_mm_embeds.append(mm_embeds[0][start:end])
+        sliced_mm_embeds = [mm_embeds[0][start:end] for start, end in slices]
     else:  # slice each mm_embeds individually
         for i, (start, end) in enumerate(slices):
             sliced_mm_embeds.append(mm_embeds[i][start:end])
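
A rough usage sketch of the new caching helper (not part of the commit): it assumes get_multimodal_embeddings is in scope, e.g. imported from modeling_multimodal_utils, and uses hypothetical stub classes in place of MultimodalParams and its multimodal_runtime. The first call runs the stub encoder and caches per-request embedding chunks; a later call for the same requests reuses the cache and skips the encoder.

import torch
from dataclasses import dataclass, field
from typing import Dict

# Hypothetical stand-ins, not the real TensorRT-LLM classes; they expose only
# the attributes get_multimodal_embeddings relies on.
@dataclass
class StubRuntime:
    total_mm_tokens: int

@dataclass
class StubParam:
    multimodal_runtime: StubRuntime
    multimodal_data: Dict[str, torch.Tensor] = field(default_factory=dict)

    def has_content(self) -> bool:
        return True

def stub_encoder(params):
    # Pretend vision encoder: one 8-dim embedding per multimodal token, concatenated.
    total = sum(p.multimodal_runtime.total_mm_tokens for p in params)
    return [torch.zeros(total, 8)]

params = [StubParam(StubRuntime(total_mm_tokens=4)),
          StubParam(StubRuntime(total_mm_tokens=6))]

# First prefill chunk: encoder runs, embeddings are split and cached per request.
first = get_multimodal_embeddings(stub_encoder, params)
# Later chunk of the same requests: cached embeddings are reused, encoder is skipped.
second = get_multimodal_embeddings(stub_encoder, params)
assert first[0].shape == second[0].shape == (10, 8)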

tensorrt_llm/_torch/models/modeling_qwen2vl.py

Lines changed: 8 additions & 7 deletions
@@ -19,8 +19,11 @@
 from ..attention_backend import AttentionMetadata
 from ..model_config import ModelConfig
 from .modeling_auto import AutoModelForCausalLM
-from .modeling_multimodal_utils import (find_input_mm_embeds,
-                                        fuse_input_embeds)
+from .modeling_multimodal_utils import (
+    find_input_mm_embeds,
+    fuse_input_embeds,
+    get_multimodal_embeddings
+)
 from .modeling_utils import register_auto_model
 
 DISAGG = os.getenv('TLLM_MULTIMODAL_DISAGGREGATED', '0') == '1'
@@ -613,11 +616,9 @@ def forward(
 
         if len(multimodal_params) > 0:
             if not DISAGG:
-                #mm_embeds = self.mm_encoder.forward(
-                #    multimodal_params[:num_context_requests])
-                # Get the full mm embeds (from cache or compute)
-                mm_embeds = self._get_or_compute_mm_embeds(
-                    multimodal_params[:num_context_requests]
+                mm_embeds = get_multimodal_embeddings(
+                    encoder_forward_fn=self.mm_encoder.forward,
+                    multimodal_params=multimodal_params[:num_context_requests]
                 )
             else:
                 # TODO: this is a dead path for now

tensorrt_llm/inputs/multimodal.py

Lines changed: 4 additions & 0 deletions
@@ -96,6 +96,7 @@ class MultimodalRuntimeData:
         chunk_end_pos: End position of the current chunk for chunked prefill
         num_unseen_mm_tokens: Number of multimodal tokens that are cached (computed)
         num_mm_tokens: Number of multimodal tokens in the current chunk (computed)
+        total_mm_tokens: Total number of multimodal tokens in the request sequence (computed)
     """
     past_seen_token_num: int  # == num_cached_tokens
     mm_token_lengths: List[int]
@@ -104,10 +105,13 @@
 
     num_unseen_mm_tokens: Optional[int] = None
     num_mm_tokens: Optional[int] = None
+    total_mm_tokens: Optional[int] = None
     # TODO: fine-grained control of encoder runner/cache to each mm_item
 
     def __post_init__(self):
         # Validate input data
+        if self.total_mm_tokens is None:
+            self.total_mm_tokens = sum(self.mm_token_lengths)
         if len(self.mm_token_positions) != len(self.mm_token_lengths):
             raise ValueError(
                 f"mm_token_positions ({len(self.mm_token_positions)}) and mm_token_lengths ({len(self.mm_token_lengths)}) must have the same length"

Comments (0)