
Commit a6befdc

Address comments
Signed-off-by: Chang Liu (Enterprise Products) <[email protected]>
1 parent 3caea5f commit a6befdc

File tree: 7 files changed, +80 −67 lines


tensorrt_llm/_torch/models/modeling_multimodal_utils.py

Lines changed: 25 additions & 18 deletions
@@ -30,10 +30,10 @@
 from tensorrt_llm.logger import logger
 
 
-def _get_active_multimodal_params(
+def _get_uncached_multimodal_params(
         multimodal_params: List[MultimodalParams], ) -> List[MultimodalParams]:
     """
-    Get active multimodal params that need encoder processing for chunk prefill.
+    Get uncached multimodal params that need encoder processing for chunk prefill.
     """
     params_to_run = []
 
@@ -63,7 +63,8 @@ def _cache_multimodal_embeddings(
 ) -> None:
     """
     Cache computed multimodal embeddings back to multimodal_data to avoid recomputation.
-    Uses torch.split for efficient tensor splitting without manual indexing.
+    Note this function only caches multimodal embeddings within the current request context,
+    mostly for chunked prefill. It does not persist embeddings across different requests or sessions.
     """
     # TODO: support multiple multimodal modalities per request
     assert len(
@@ -73,7 +74,8 @@ def _cache_multimodal_embeddings(
 
     # Collect embedding lengths for each parameter
     embed_lengths = [
-        param.multimodal_runtime.total_mm_tokens for param in multimodal_params
+        param.multimodal_runtime.total_mm_tokens_in_request
+        for param in multimodal_params
     ]
 
     # Validate total length matches
@@ -117,29 +119,31 @@ def get_multimodal_embeddings(
     if not multimodal_params:
         return []
 
-    # Step 1: Find active multimodal params that need encoder processing
-    active_multimodal_params = _get_active_multimodal_params(multimodal_params)
+    # Step 1: Find uncached multimodal params that need encoder processing
+    uncached_multimodal_params = _get_uncached_multimodal_params(
+        multimodal_params)
 
     # Step 2: Run encoder forward only on uncached parameters
-    if active_multimodal_params:
-        encoder_outputs = encoder_forward_fn(active_multimodal_params)
+    if uncached_multimodal_params:
+        encoder_outputs = encoder_forward_fn(uncached_multimodal_params)
 
         # TODO: support multiple multimodal modalities per request
         if len(encoder_outputs) > 1:
             return encoder_outputs
 
         # Validate that multimodal_runtime has required attributes for caching
-        if (not hasattr(active_multimodal_params[0], 'multimodal_runtime')
-                or active_multimodal_params[0].multimodal_runtime is None or
-                active_multimodal_params[0].multimodal_runtime.total_mm_tokens
-                is None):
+        if (not hasattr(uncached_multimodal_params[0], 'multimodal_runtime')
+                or uncached_multimodal_params[0].multimodal_runtime is None
+                or uncached_multimodal_params[0].multimodal_runtime.
+                total_mm_tokens_in_request is None):
             logger.warning(
                 "Multimodal runtime data missing or incomplete - recomputed all embeddings"
             )
             return encoder_outputs
 
         # Step 3: Cache the computed embeddings to multimodal_data["multimodal_embedding"]
-        _cache_multimodal_embeddings(active_multimodal_params, encoder_outputs)
+        _cache_multimodal_embeddings(uncached_multimodal_params,
+                                     encoder_outputs)
 
     # Step 4: Gather all embeddings for the batch
     all_embeddings = torch.cat([
@@ -186,8 +190,10 @@ def find_input_mm_embeds(
         return mm_embeds
 
     # Calculate total tokens that need processing (both cached and current chunk)
-    total_mm_tokens = sum(
-        [param.multimodal_runtime.num_mm_tokens for param in multimodal_params])
+    total_mm_tokens = sum([
+        param.multimodal_runtime.num_mm_tokens_in_chunk
+        for param in multimodal_params
+    ])
 
     if total_mm_tokens == 0:
         # No tokens need processing, return empty list
@@ -203,11 +209,12 @@ def find_input_mm_embeds(
     slices = []
     for param in multimodal_params:
         runtime = param.multimodal_runtime
-        slices.append((current_pos + runtime.num_unseen_mm_tokens, current_pos +
-                       runtime.num_unseen_mm_tokens + runtime.num_mm_tokens))
+        slices.append(
+            (current_pos + runtime.num_unseen_mm_tokens, current_pos +
+             runtime.num_unseen_mm_tokens + runtime.num_mm_tokens_in_chunk))
         if len(mm_embeds
                ) == 1:  # pre-concatenated mm_embeds, need global offset
-            current_pos += runtime.total_mm_tokens
+            current_pos += runtime.total_mm_tokens_in_request
 
     sliced_mm_embeds = []
     if len(mm_embeds) == 1:
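For readers following the renamed counters, here is a small self-contained sketch of the slicing scheme that find_input_mm_embeds applies to a pre-concatenated embedding tensor. It is not the library code; the tensor shape and counter values below are invented for illustration, and only the counter names mirror MultimodalRuntimeData.

```python
# Hypothetical illustration of the per-chunk slicing; values are made up.
import torch

# One pre-concatenated embedding tensor covering two requests:
# request 0 has 6 mm tokens in total, request 1 has 4.
mm_embeds = torch.arange(10).unsqueeze(-1).float()  # shape [10, hidden=1]

runtimes = [
    # (num_unseen_mm_tokens, num_mm_tokens_in_chunk, total_mm_tokens_in_request)
    (2, 3, 6),  # request 0: 2 tokens already cached, 3 fall in this chunk
    (0, 4, 4),  # request 1: all 4 tokens fall in this chunk
]

slices, current_pos = [], 0
for unseen, in_chunk, total in runtimes:
    # Skip the already-seen prefix, keep only the tokens in the current chunk.
    slices.append((current_pos + unseen, current_pos + unseen + in_chunk))
    current_pos += total  # global offset into the concatenated tensor

sliced = torch.cat([mm_embeds[s:e] for s, e in slices], dim=0)
print(slices)        # [(2, 5), (6, 10)]
print(sliced.shape)  # torch.Size([7, 1]) -> 3 + 4 tokens for this chunk
```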

tensorrt_llm/_torch/models/modeling_qwen2vl.py

Lines changed: 1 addition & 2 deletions
@@ -309,7 +309,7 @@ def __call__(
             mm_processor_kwargs)
         if not mm_data:
             fused_input_ids = processed_inputs['input_ids']
-            return fused_input_ids.to(torch.int32).tolist(), {}
+            return fused_input_ids.flatten().to(torch.int32).tolist(), {}
 
         pixel_values = processed_inputs.get('pixel_values', None)
         pixel_values_videos = processed_inputs.get('pixel_values_videos', None)
@@ -619,7 +619,6 @@ def forward(
                 encoder_forward_fn=self.mm_encoder.forward,
                 multimodal_params=multimodal_params[:num_context_requests])
         else:
-            # TODO: this is a dead path for now
            mm_embeds = [
                multimodal_param.multimodal_data["multimodal_embedding"]
                for multimodal_param in multimodal_params
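The flatten() change in __call__ is easy to illustrate: tokenizer/processor outputs commonly carry a leading batch dimension, so calling tolist() directly produces a nested list. A minimal sketch (the tensor shape and token ids are assumptions for illustration, not taken from the actual processor output):

```python
import torch

# Assumed shape: processors commonly return input_ids as [1, seq_len].
fused_input_ids = torch.tensor([[101, 2023, 2003, 102]])

print(fused_input_ids.to(torch.int32).tolist())            # [[101, 2023, 2003, 102]] (nested)
print(fused_input_ids.flatten().to(torch.int32).tolist())  # [101, 2023, 2003, 102]  (flat)
```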

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 0 additions & 1 deletion
@@ -1241,7 +1241,6 @@ def _prepare_tp_inputs(
             num_cached_tokens_per_seq.append(past_seen_token_num)
 
             # Multimodal
-            # TODO: enable chunk prefill for multimodal (maybe need to pass prompt_tokens to MultimodalRuntimeData)
             py_multimodal_runtime = MultimodalRuntimeData(
                 mm_token_lengths=request.multimodal_lengths,
                 mm_token_positions=request.multimodal_positions,

tensorrt_llm/inputs/multimodal.py

Lines changed: 18 additions & 16 deletions
@@ -95,24 +95,24 @@ class MultimodalRuntimeData:
         mm_token_positions: Starting positions of each multimodal token chunk
         chunk_end_pos: End position of the current chunk for chunked prefill
         num_unseen_mm_tokens: Number of multimodal tokens that are cached (computed)
-        num_mm_tokens: Number of multimodal tokens in the current chunk (computed)
-        total_mm_tokens: Total number of multimodal tokens in the request sequence (computed)
+        num_mm_tokens_in_chunk: Number of multimodal tokens in the current chunk (computed)
+        total_mm_tokens_in_request: Total number of multimodal tokens in the request sequence (computed)
     """
     past_seen_token_num: int
     mm_token_lengths: List[int]
     mm_token_positions: List[int]
     chunk_end_pos: int
 
     num_unseen_mm_tokens: Optional[int] = None
-    num_mm_tokens: Optional[int] = None
-    total_mm_tokens: Optional[int] = None
+    num_mm_tokens_in_chunk: Optional[int] = None
+    total_mm_tokens_in_request: Optional[int] = None
 
     # TODO: fine-grained control of encoder runner/cache to each mm_item
 
     def __post_init__(self):
         # Validate input data
-        if self.total_mm_tokens is None:
-            self.total_mm_tokens = sum(self.mm_token_lengths)
+        if self.total_mm_tokens_in_request is None:
+            self.total_mm_tokens_in_request = sum(self.mm_token_lengths)
         if len(self.mm_token_positions) != len(self.mm_token_lengths):
             raise ValueError(
                 f"mm_token_positions ({len(self.mm_token_positions)}) and mm_token_lengths ({len(self.mm_token_lengths)}) must have the same length"
@@ -133,34 +133,36 @@ def __post_init__(self):
                 f"All mm_token_positions must be non-negative, got {self.mm_token_positions}"
             )
 
-        if self.num_unseen_mm_tokens is None or self.num_mm_tokens is None:
+        if self.num_unseen_mm_tokens is None or self.num_mm_tokens_in_chunk is None:
             # Compute cached multimodal tokens based on positions and cached tokens
             self.num_unseen_mm_tokens = 0
-            self.num_mm_tokens = 0
+            self.num_mm_tokens_in_chunk = 0
+            remainder = 0
             for pos, length in zip(self.mm_token_positions,
                                    self.mm_token_lengths):
                 if pos + length <= self.past_seen_token_num:
                     self.num_unseen_mm_tokens += length
                 elif pos < self.past_seen_token_num:
                     # Partial overlap - only count the cached portion
                     self.num_unseen_mm_tokens += self.past_seen_token_num - pos
-                    if pos + length > self.chunk_end_pos:
-                        self.num_mm_tokens += self.chunk_end_pos - self.past_seen_token_num
-                    else:
-                        self.num_mm_tokens += pos + length - self.past_seen_token_num
+                    self.num_mm_tokens_in_chunk += min(
+                        self.chunk_end_pos,
+                        pos + length) - self.past_seen_token_num
                 else:
                     if pos + length > self.chunk_end_pos:
                         # Partial overlap - only count the cached portion
                         if pos < self.chunk_end_pos:
-                            self.num_mm_tokens += self.chunk_end_pos - pos
+                            self.num_mm_tokens_in_chunk += self.chunk_end_pos - pos
+                        else:
+                            remainder += length
                     else:
                         # Full overlap - count the entire mm item chunk
-                        self.num_mm_tokens += length
+                        self.num_mm_tokens_in_chunk += length
 
-            if self.num_unseen_mm_tokens + self.num_mm_tokens > sum(
+            if self.num_unseen_mm_tokens + self.num_mm_tokens_in_chunk + remainder > sum(
                     self.mm_token_lengths):
                 raise ValueError(
-                    f"num_unseen_mm_tokens ({self.num_unseen_mm_tokens}) + num_mm_tokens ({self.num_mm_tokens}) must be less than or equal to sum of mm_token_lengths ({sum(self.mm_token_lengths)})"
+                    f"num_unseen_mm_tokens ({self.num_unseen_mm_tokens}) + num_mm_tokens_in_chunk ({self.num_mm_tokens_in_chunk}) + remainder ({remainder}) must be less than or equal to sum of mm_token_lengths ({sum(self.mm_token_lengths)})"
                )
 
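The reworked bookkeeping in __post_init__ is easiest to follow with concrete numbers. Below is a standalone sketch that mirrors the counter arithmetic from the diff; it is not the MultimodalRuntimeData class itself, and the token positions, lengths, and chunk boundaries in the example are invented.

```python
def split_mm_tokens(past_seen_token_num, chunk_end_pos,
                    mm_token_positions, mm_token_lengths):
    """Mirror of the counter logic in MultimodalRuntimeData.__post_init__ (sketch)."""
    num_unseen = 0    # mm tokens already covered by previous chunks
    num_in_chunk = 0  # mm tokens that fall inside the current chunk
    remainder = 0     # mm items that start beyond the current chunk
    for pos, length in zip(mm_token_positions, mm_token_lengths):
        if pos + length <= past_seen_token_num:
            num_unseen += length
        elif pos < past_seen_token_num:
            # Item straddles the already-seen boundary.
            num_unseen += past_seen_token_num - pos
            num_in_chunk += min(chunk_end_pos, pos + length) - past_seen_token_num
        else:
            if pos + length > chunk_end_pos:
                if pos < chunk_end_pos:
                    num_in_chunk += chunk_end_pos - pos
                else:
                    remainder += length  # entirely in a future chunk
            else:
                num_in_chunk += length
    return num_unseen, num_in_chunk, remainder

# Invented example: a 10-token image at position 5, current chunk covers tokens [8, 12).
print(split_mm_tokens(past_seen_token_num=8, chunk_end_pos=12,
                      mm_token_positions=[5], mm_token_lengths=[10]))
# (3, 4, 0): 3 image tokens were seen in earlier chunks, 4 fall in this chunk,
# and the remaining 3 will be handled by later chunks.
```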
tests/integration/test_lists/test-db/l0_h100.yml

Lines changed: 3 additions & 3 deletions
@@ -92,7 +92,7 @@ l0_h100:
   - test_e2e.py::test_trtllm_bench_request_rate_and_concurrency[enable_concurrency-]
   - test_e2e.py::test_trtllm_bench_request_rate_and_concurrency[enable_concurrency-enable_request_rate] # negative test
   - test_e2e.py::test_trtllm_bench_help_sanity[meta-llama/Llama-3.1-8B]
-  - test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-True]
+  - test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-False-image-True]
 - condition:
     ranges:
       system_gpu_count:
@@ -217,8 +217,8 @@ l0_h100:
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[llguidance]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_ngram[xgrammar]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_ngram[llguidance]
-  - test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-True]
-  - test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-mixture_text_image-True]
+  - test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-False-image-True]
+  - test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-False-mixture_text_image-True]
 - condition:
     ranges:
       system_gpu_count:

tests/integration/test_lists/test-db/l0_l40s.yml

Lines changed: 9 additions & 9 deletions
@@ -19,15 +19,15 @@ l0_l40s:
   - unittest/_torch/modeling -k "modeling_vila"
   - unittest/_torch/modeling -k "modeling_siglip"
   - test_e2e.py::test_ptp_scaffolding[DeepSeek-R1-Distill-Qwen-7B-DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B]
-  - test_e2e.py::test_ptp_quickstart_multimodal[NVILA-8B-FP16-vila/NVILA-8B-image-False]
-  - test_e2e.py::test_ptp_quickstart_multimodal[NVILA-8B-FP16-vila/NVILA-8B-video-False]
-  - test_e2e.py::test_ptp_quickstart_multimodal[llava-v1.6-mistral-7b-llava-v1.6-mistral-7b-hf-image-False]
-  - test_e2e.py::test_ptp_quickstart_multimodal[qwen2-vl-7b-instruct-Qwen2-VL-7B-Instruct-image-False]
-  - test_e2e.py::test_ptp_quickstart_multimodal[qwen2-vl-7b-instruct-Qwen2-VL-7B-Instruct-video-False]
-  - test_e2e.py::test_ptp_quickstart_multimodal[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B-Instruct-image-False]
-  - test_e2e.py::test_ptp_quickstart_multimodal[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B-Instruct-image-True]
-  - test_e2e.py::test_ptp_quickstart_multimodal[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B-Instruct-video-False]
-  - test_e2e.py::test_ptp_quickstart_multimodal[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B-Instruct-video-True]
+  - test_e2e.py::test_ptp_quickstart_multimodal[NVILA-8B-FP16-vila/NVILA-8B-False-image-False]
+  - test_e2e.py::test_ptp_quickstart_multimodal[NVILA-8B-FP16-vila/NVILA-8B-False-video-False]
+  - test_e2e.py::test_ptp_quickstart_multimodal[llava-v1.6-mistral-7b-llava-v1.6-mistral-7b-hf-False-image-False]
+  - test_e2e.py::test_ptp_quickstart_multimodal[qwen2-vl-7b-instruct-Qwen2-VL-7B-Instruct-True-image-True]
+  - test_e2e.py::test_ptp_quickstart_multimodal[qwen2-vl-7b-instruct-Qwen2-VL-7B-Instruct-False-video-False]
+  - test_e2e.py::test_ptp_quickstart_multimodal[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B-Instruct-True-image-False]
+  - test_e2e.py::test_ptp_quickstart_multimodal[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B-Instruct-False-image-True]
+  - test_e2e.py::test_ptp_quickstart_multimodal[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B-Instruct-False-video-False]
+  - test_e2e.py::test_ptp_quickstart_multimodal[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B-Instruct-False-video-True]
   - test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[audio]
   - test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[image]
   - test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[image_audio]
