
Commit 183d3f0

Format and add e2e tests
Signed-off-by: Chang Liu (Enterprise Products) <[email protected]>
1 parent 00d6692 commit 183d3f0

5 files changed: +231 −150 lines

tensorrt_llm/_torch/models/modeling_multimodal_utils.py

Lines changed: 39 additions & 30 deletions
@@ -17,7 +17,7 @@
 # and s2wrapper: https://github.com/bfshi/scaling_on_scales

 import math
-from typing import List, Optional, Tuple, Callable, Dict, Any, Union
+from typing import List, Optional, Tuple

 import torch
 import torch.nn.functional as F
@@ -31,8 +31,7 @@


 def _get_active_multimodal_params(
-        multimodal_params: List[MultimodalParams],
-) -> List[MultimodalParams]:
+        multimodal_params: List[MultimodalParams], ) -> List[MultimodalParams]:
     """
     Get active multimodal params that need encoder processing for chunk prefill.
     """
@@ -44,10 +43,12 @@ def _get_active_multimodal_params(
             continue

         # Check if embeddings are already cached
-        if (param.multimodal_data and
-                "multimodal_embedding" in param.multimodal_data and
-                param.multimodal_data["multimodal_embedding"] is not None):
-            logger.debug(f"Skipping encoder forward for param with cached multimodal_embedding")
+        if (param.multimodal_data
+                and "multimodal_embedding" in param.multimodal_data
+                and param.multimodal_data["multimodal_embedding"] is not None):
+            logger.debug(
+                f"Skipping encoder forward for param with cached multimodal_embedding"
+            )
             continue

         # This param needs encoder processing
@@ -65,11 +66,15 @@ def _cache_multimodal_embeddings(
     Uses torch.split for efficient tensor splitting without manual indexing.
     """
     # TODO: support multiple multimodal modalities per request
-    assert len(embeddings) == 1, "Currently only support single mm_embeds (single modality) per request"
+    assert len(
+        embeddings
+    ) == 1, "Currently only support single mm_embeds (single modality) per request"
     mm_embed = embeddings[0]

     # Collect embedding lengths for each parameter
-    embed_lengths = [param.multimodal_runtime.total_mm_tokens for param in multimodal_params]
+    embed_lengths = [
+        param.multimodal_runtime.total_mm_tokens for param in multimodal_params
+    ]

     # Validate total length matches
     total_expected = sum(embed_lengths)
@@ -83,7 +88,9 @@ def _cache_multimodal_embeddings(
     for param, embed_chunk in zip(multimodal_params, split_embeddings):
         param.multimodal_data["multimodal_embedding"] = embed_chunk

-    logger.debug(f"Cached {len(split_embeddings)} multimodal embedding chunks in this iteration")
+    logger.debug(
+        f"Cached {len(split_embeddings)} multimodal embedding chunks in this iteration"
+    )


 def get_multimodal_embeddings(
@@ -111,9 +118,7 @@ def get_multimodal_embeddings(
         return []

     # Step 1: Find active multimodal params that need encoder processing
-    active_multimodal_params = _get_active_multimodal_params(
-        multimodal_params
-    )
+    active_multimodal_params = _get_active_multimodal_params(multimodal_params)

     # Step 2: Run encoder forward only on uncached parameters
     if active_multimodal_params:
@@ -124,19 +129,24 @@ def get_multimodal_embeddings(
         return encoder_outputs

     # Validate that multimodal_runtime has required attributes for caching
-    if (not hasattr(active_multimodal_params[0], 'multimodal_runtime') or
-            active_multimodal_params[0].multimodal_runtime is None or
-            active_multimodal_params[0].multimodal_runtime.total_mm_tokens is None):
-        logger.warning("Multimodal runtime data missing or incomplete - recomputed all embeddings")
+    if (not hasattr(active_multimodal_params[0], 'multimodal_runtime')
+            or active_multimodal_params[0].multimodal_runtime is None or
+            active_multimodal_params[0].multimodal_runtime.total_mm_tokens
+            is None):
+        logger.warning(
+            "Multimodal runtime data missing or incomplete - recomputed all embeddings"
+        )
        return encoder_outputs

     # Step 3: Cache the computed embeddings to multimodal_data["multimodal_embedding"]
-    _cache_multimodal_embeddings(
-        active_multimodal_params, encoder_outputs
-    )
+    _cache_multimodal_embeddings(active_multimodal_params, encoder_outputs)

     # Step 4: Gather all embeddings for the batch
-    all_embeddings = torch.cat([param.multimodal_data["multimodal_embedding"] for param in multimodal_params], dim=0)
+    all_embeddings = torch.cat([
+        param.multimodal_data["multimodal_embedding"]
+        for param in multimodal_params
+    ],
+                               dim=0)
     return [all_embeddings]


@@ -176,28 +186,27 @@ def find_input_mm_embeds(
         return mm_embeds

     # Calculate total tokens that need processing (both cached and current chunk)
-    total_mm_tokens = sum([
-        param.multimodal_runtime.num_mm_tokens
-        for param in multimodal_params
-    ])
+    total_mm_tokens = sum(
+        [param.multimodal_runtime.num_mm_tokens for param in multimodal_params])

     if total_mm_tokens == 0:
         # No tokens need processing, return empty list
         logger.debug(
-            "All multimodal tokens are cached or beyond current chunk, skipping vision encoder forward")
+            "All multimodal tokens are cached or beyond current chunk, skipping vision encoder forward"
+        )
         return []

     if total_mm_tokens == sum(mm_embed.shape[0] for mm_embed in mm_embeds):
         return mm_embeds

-
     current_pos = 0
     slices = []
     for param in multimodal_params:
         runtime = param.multimodal_runtime
-        slices.append((current_pos + runtime.num_unseen_mm_tokens,
-                       current_pos + runtime.num_unseen_mm_tokens + runtime.num_mm_tokens))
-        if len(mm_embeds) == 1:  # pre-concatenated mm_embeds, need global offset
+        slices.append((current_pos + runtime.num_unseen_mm_tokens, current_pos +
+                       runtime.num_unseen_mm_tokens + runtime.num_mm_tokens))
+        if len(mm_embeds
+               ) == 1:  # pre-concatenated mm_embeds, need global offset
             current_pos += runtime.total_mm_tokens

     sliced_mm_embeds = []
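
For readers skimming the caching change above: _cache_multimodal_embeddings leans on torch.split to hand each request its own slice of the batched encoder output. A minimal, self-contained sketch of that pattern follows; the lengths and hidden size are made up for illustration, not taken from the code.

import torch

# Per-request multimodal token counts, playing the role of
# param.multimodal_runtime.total_mm_tokens for a 3-request batch (illustrative).
embed_lengths = [3, 5, 2]

# Concatenated encoder output for the whole batch (hidden size is arbitrary here).
mm_embed = torch.randn(sum(embed_lengths), 4096)

# torch.split slices along dim 0 without manual index arithmetic; each chunk
# can then be cached per request under "multimodal_embedding".
split_embeddings = torch.split(mm_embed, embed_lengths, dim=0)

for length, chunk in zip(embed_lengths, split_embeddings):
    assert chunk.shape[0] == length  # every request gets exactly its own tokens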

tensorrt_llm/_torch/models/modeling_qwen2vl.py

Lines changed: 5 additions & 8 deletions
@@ -19,11 +19,8 @@
 from ..attention_backend import AttentionMetadata
 from ..model_config import ModelConfig
 from .modeling_auto import AutoModelForCausalLM
-from .modeling_multimodal_utils import (
-    find_input_mm_embeds,
-    fuse_input_embeds,
-    get_multimodal_embeddings
-)
+from .modeling_multimodal_utils import (find_input_mm_embeds, fuse_input_embeds,
+                                        get_multimodal_embeddings)
 from .modeling_utils import register_auto_model

 DISAGG = os.getenv('TLLM_MULTIMODAL_DISAGGREGATED', '0') == '1'
@@ -618,8 +615,7 @@ def forward(
         if not DISAGG:
             mm_embeds = get_multimodal_embeddings(
                 encoder_forward_fn=self.mm_encoder.forward,
-                multimodal_params=multimodal_params[:num_context_requests]
-            )
+                multimodal_params=multimodal_params[:num_context_requests])
         else:
             # TODO: this is a dead path for now
             mm_embeds = [
@@ -630,7 +626,8 @@ def forward(
                 multimodal_params, num_context_requests,
                 num_generation_requests)

-        mm_embeds = find_input_mm_embeds(mm_embeds, multimodal_params[:num_context_requests])
+        mm_embeds = find_input_mm_embeds(
+            mm_embeds, multimodal_params[:num_context_requests])

         if 'mrope_position_deltas' in kwargs:
             mrope_config['mrope_position_deltas'] = kwargs[
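
The qwen2vl hunks only reflow code, but the reflow makes the context-phase call order easy to miss, so here is a hedged sketch of that order as it appears in the diff. The argument names come from the hunks; the wrapper function and its name are illustrative only, not part of the model code.

# Illustrative wrapper; not part of the diff. It restates the call order
# visible above for context requests when DISAGG is disabled.
def _context_phase_mm_embeds(self, multimodal_params, num_context_requests):
    # 1) Run (or skip, if already cached) the vision encoder for context requests only.
    mm_embeds = get_multimodal_embeddings(
        encoder_forward_fn=self.mm_encoder.forward,
        multimodal_params=multimodal_params[:num_context_requests])
    # 2) Keep only the embedding slices that belong to the current prefill chunk.
    return find_input_mm_embeds(mm_embeds,
                                multimodal_params[:num_context_requests])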

tensorrt_llm/inputs/multimodal.py

Lines changed: 5 additions & 3 deletions
@@ -98,14 +98,15 @@ class MultimodalRuntimeData:
         num_mm_tokens: Number of multimodal tokens in the current chunk (computed)
         total_mm_tokens: Total number of multimodal tokens in the request sequence (computed)
     """
-    past_seen_token_num: int  # == num_cached_tokens
+    past_seen_token_num: int
     mm_token_lengths: List[int]
     mm_token_positions: List[int]
-    chunk_end_pos: int  # == end_pos
+    chunk_end_pos: int

     num_unseen_mm_tokens: Optional[int] = None
     num_mm_tokens: Optional[int] = None
     total_mm_tokens: Optional[int] = None
+
     # TODO: fine-grained control of encoder runner/cache to each mm_item

     def __post_init__(self):
@@ -156,7 +157,8 @@ def __post_init__(self):
                 # Full overlap - count the entire mm item chunk
                 self.num_mm_tokens += length

-        if self.num_unseen_mm_tokens + self.num_mm_tokens > sum(self.mm_token_lengths):
+        if self.num_unseen_mm_tokens + self.num_mm_tokens > sum(
+                self.mm_token_lengths):
             raise ValueError(
                 f"num_unseen_mm_tokens ({self.num_unseen_mm_tokens}) + num_mm_tokens ({self.num_mm_tokens}) must be less than or equal to sum of mm_token_lengths ({sum(self.mm_token_lengths)})"
             )
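
The reflowed validation above guards the per-item bookkeeping that __post_init__ performs for chunked prefill. As a toy illustration of that arithmetic, a hypothetical helper is sketched below: it only approximates the idea that one mm item's tokens fall before the current chunk, inside it, or beyond it, and that the first two counts can never exceed the item's length (the invariant the ValueError enforces).

def partition_mm_item(position: int, length: int, past_seen_token_num: int,
                      chunk_end_pos: int):
    """Hypothetical helper: split one mm item's tokens around a prefill chunk."""
    start, end = position, position + length
    before_chunk = max(0, min(end, past_seen_token_num) - start)
    in_chunk = max(0, min(end, chunk_end_pos) - max(start, past_seen_token_num))
    beyond_chunk = length - before_chunk - in_chunk
    # Mirrors the invariant the ValueError above enforces across all items.
    assert before_chunk + in_chunk <= length
    return before_chunk, in_chunk, beyond_chunk


# Example: a 10-token image at position 8, with 12 prompt tokens already cached
# and the current chunk ending at token 20 -> 4 cached, 6 in this chunk, 0 deferred.
print(partition_mm_item(position=8, length=10, past_seen_token_num=12,
                        chunk_end_pos=20))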

tests/integration/defs/test_e2e.py

Lines changed: 14 additions & 1 deletion
@@ -2121,6 +2121,7 @@ def test_ptp_quickstart_advanced_mixed_precision(llm_root, llm_venv):

 @pytest.mark.parametrize("use_cuda_graph", [False, True])
 @pytest.mark.parametrize("modality", ["image", "video", "mixture_text_image"])
+@pytest.mark.parametrize("enable_chunked_prefill", [False, True])
 @pytest.mark.parametrize("model_name,model_path", [
     ("NVILA-8B-FP16", "vila/NVILA-8B"),
     ("NVILA-15B-FP16", "NVILA-15B"),
@@ -2135,9 +2136,16 @@ def test_ptp_quickstart_advanced_mixed_precision(llm_root, llm_venv):
          marks=pytest.mark.skip_less_device_memory(80000)),
 ])
 def test_ptp_quickstart_multimodal(llm_root, llm_venv, model_name, model_path,
-                                   modality, use_cuda_graph):
+                                   modality, use_cuda_graph,
+                                   enable_chunked_prefill):
     # NOTE: individual tests need to be enabled in
     # tests/integration/test_lists/qa/examples_test_list.txt
+    if model_name not in ["qwen2-vl-7b-instruct", "qwen2.5-vl-7b-instruct"
+                          ] and enable_chunked_prefill:
+        pytest.skip(
+            "Only Qwen2-VL and Qwen2-5-VL support chunked prefill for now")
+    if modality != "image" and enable_chunked_prefill:
+        pytest.skip("Chunked prefill is only supported for image modality")

     example_root = Path(os.path.join(llm_root, "examples", "llm-api"))
     test_data_root = Path(
@@ -2262,6 +2270,11 @@ def test_ptp_quickstart_multimodal(llm_root, llm_venv, model_name, model_path,
     if model_name in ["qwen2-vl-7b-instruct", "qwen2.5-vl-7b-instruct"
                       ] and modality == "video":
         cmd.append("--max_num_tokens=16384")
+    else:
+        if enable_chunked_prefill:
+            cmd.append("--enable_chunked_prefill")
+            cmd.append("--max_num_tokens=256")
+
     if use_cuda_graph:
         cmd.append("--use_cuda_graph")
     # Gemma3 VLM needs a custom mask which is only supported by flashinfer backend currently.
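
To summarize which new parametrizations actually run: the two pytest.skip guards above restrict chunked prefill to the Qwen2-VL family with image inputs. A hedged, standalone restatement of that gating (helper name and asserts are illustrative, not test code from the diff):

def runs_with_chunked_prefill(model_name: str, modality: str) -> bool:
    # Mirrors the skip conditions added in the test above.
    qwen_vl_models = {"qwen2-vl-7b-instruct", "qwen2.5-vl-7b-instruct"}
    return model_name in qwen_vl_models and modality == "image"


assert runs_with_chunked_prefill("qwen2.5-vl-7b-instruct", "image")
assert not runs_with_chunked_prefill("NVILA-8B-FP16", "image")         # skipped: model not supported
assert not runs_with_chunked_prefill("qwen2-vl-7b-instruct", "video")  # skipped: non-image modality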
