# and s2wrapper: https://github.com/bfshi/scaling_on_scales

import math
- from typing import List, Optional, Tuple
+ from typing import List, Optional, Tuple, Callable, Dict, Any, Union

import torch
import torch.nn.functional as F
from tensorrt_llm.logger import logger


+ def _get_active_multimodal_params(
+     multimodal_params: List[MultimodalParams],
+ ) -> List[MultimodalParams]:
+     """
+     Get active multimodal params that need encoder processing for chunked prefill.
+     """
+     params_to_run = []
+
+     for param in multimodal_params:
+         # Skip if no multimodal content
+         if not param.has_content():
+             continue
+
+         # Check if embeddings are already cached
+         if (param.multimodal_data and
+                 "multimodal_embedding" in param.multimodal_data and
+                 param.multimodal_data["multimodal_embedding"] is not None):
+             logger.debug("Skipping encoder forward for param with cached multimodal_embedding")
+             continue
+
+         # This param needs encoder processing
+         params_to_run.append(param)
+
+     return params_to_run
+
+
+ def _cache_multimodal_embeddings(
+     multimodal_params: List[MultimodalParams],
+     embeddings: List[torch.Tensor],
+ ) -> None:
+     """
+     Cache computed multimodal embeddings back to multimodal_data to avoid recomputation.
+     Uses torch.split for efficient tensor splitting without manual indexing.
+     """
+     # TODO: support multiple multimodal modalities per request
+     assert len(embeddings) == 1, "Currently only supports a single mm_embeds tensor (single modality) per request"
+     mm_embed = embeddings[0]
+
+     # Collect embedding lengths for each parameter
+     embed_lengths = [param.multimodal_runtime.total_mm_tokens for param in multimodal_params]
+
+     # Validate that the total length matches
+     total_expected = sum(embed_lengths)
+     assert len(mm_embed) == total_expected, \
+         f"Number of mm_embeds ({len(mm_embed)}) does not match expected total ({total_expected})"
+
+     # Use torch.split for efficient tensor splitting
+     split_embeddings = torch.split(mm_embed, embed_lengths, dim=0)
+
+     # Cache split embeddings to each parameter
+     for param, embed_chunk in zip(multimodal_params, split_embeddings):
+         param.multimodal_data["multimodal_embedding"] = embed_chunk
+
+     logger.debug(f"Cached {len(split_embeddings)} multimodal embedding chunks in this iteration")
+
+
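# A minimal, self-contained sketch (not part of this diff) of the torch.split
# pattern used by _cache_multimodal_embeddings above; the shapes, per-request
# token counts, and the helper name are made up for illustration.
def _example_split_cached_embeddings():
    mm_embed = torch.randn(12, 16)  # 12 concatenated mm tokens, hidden size 16
    embed_lengths = [5, 7]  # per-request multimodal_runtime.total_mm_tokens
    # One chunk per request, in batch order, without manual index bookkeeping
    chunks = torch.split(mm_embed, embed_lengths, dim=0)
    assert [chunk.shape[0] for chunk in chunks] == embed_lengths
    return chunks
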
+ def get_multimodal_embeddings(
+     encoder_forward_fn: Callable[[List[MultimodalParams]], List[torch.Tensor]],
+     multimodal_params: List[MultimodalParams],
+ ) -> List[torch.Tensor]:
+     """
+     High-level utility to get multimodal embeddings from the encoder or from cached embeddings.
+
+     This function will:
+     1. Identify which parameters need encoder processing
+     2. Run the encoder forward pass only on uncached parameters
+     3. Cache newly computed embeddings (if enabled)
+     4. Gather all embeddings for the batch
+
+     Args:
+         encoder_forward_fn: Callable that performs the encoder forward pass.
+             Should accept List[MultimodalParams] and return List[torch.Tensor].
+         multimodal_params: All multimodal parameters in the batch.
+
+     Returns:
+         List of multimodal embeddings for all multimodal params in the batch.
+     """
+     if not multimodal_params:
+         return []
+
+     # Step 1: Find active multimodal params that need encoder processing
+     active_multimodal_params = _get_active_multimodal_params(multimodal_params)
+
+     # Step 2: Run the encoder forward pass only on uncached parameters
+     if active_multimodal_params:
+         encoder_outputs = encoder_forward_fn(active_multimodal_params)
+
+         # TODO: support multiple multimodal modalities per request
+         if len(encoder_outputs) > 1:
+             return encoder_outputs
+
+         # Validate that multimodal_runtime has the attributes required for caching
+         if (not hasattr(active_multimodal_params[0], 'multimodal_runtime') or
+                 active_multimodal_params[0].multimodal_runtime is None or
+                 active_multimodal_params[0].multimodal_runtime.total_mm_tokens is None):
+             logger.warning("Multimodal runtime data missing or incomplete - skipping embedding caching")
+             return encoder_outputs
+
+         # Step 3: Cache the computed embeddings to multimodal_data["multimodal_embedding"]
+         _cache_multimodal_embeddings(active_multimodal_params, encoder_outputs)
+
+     # Step 4: Gather all embeddings for the batch
+     all_embeddings = torch.cat([param.multimodal_data["multimodal_embedding"] for param in multimodal_params], dim=0)
+     return [all_embeddings]
+
+
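# A hedged usage sketch (not part of this diff) of how a model forward pass might
# call get_multimodal_embeddings; `vision_encoder` and `batch_multimodal_params`
# are hypothetical stand-ins for the model's vision tower and the current batch.
def _example_get_multimodal_embeddings(vision_encoder, batch_multimodal_params):

    def encoder_forward_fn(params: List[MultimodalParams]) -> List[torch.Tensor]:
        # Only params without a cached "multimodal_embedding" reach this call.
        return vision_encoder(params)

    # A single-element list holding the concatenated embeddings for every
    # multimodal request in the batch (cached chunks plus newly computed ones).
    mm_embeds = get_multimodal_embeddings(encoder_forward_fn, batch_multimodal_params)
    return mm_embeds
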
def find_input_mm_embeds(
        mm_embeds: List[torch.Tensor],
        multimodal_params: List[MultimodalParams]) -> List[torch.Tensor]:
@@ -66,10 +176,6 @@ def find_input_mm_embeds(
        return mm_embeds

    # Calculate total tokens that need processing (both cached and current chunk)
-     total_unseen_mm_tokens = sum([
-         param.multimodal_runtime.num_unseen_mm_tokens
-         for param in multimodal_params
-     ])
    total_mm_tokens = sum([
        param.multimodal_runtime.num_mm_tokens
        for param in multimodal_params
@@ -92,12 +198,11 @@ def find_input_mm_embeds(
        slices.append((current_pos + runtime.num_unseen_mm_tokens,
                       current_pos + runtime.num_unseen_mm_tokens + runtime.num_mm_tokens))
        if len(mm_embeds) == 1:  # pre-concatenated mm_embeds, need global offset
-             current_pos += sum(runtime.mm_token_lengths)
+             current_pos += runtime.total_mm_tokens

    sliced_mm_embeds = []
    if len(mm_embeds) == 1:
-         for start, end in slices:
-             sliced_mm_embeds.append(mm_embeds[0][start:end])
+         sliced_mm_embeds = [mm_embeds[0][start:end] for start, end in slices]
    else:  # slice each mm_embeds individually
        for i, (start, end) in enumerate(slices):
            sliced_mm_embeds.append(mm_embeds[i][start:end])
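
# A small, self-contained sketch (not part of this diff) of the slice arithmetic
# above for a single pre-concatenated mm_embeds tensor; the per-request token
# counts and the helper name are made up for illustration.
def _example_find_input_mm_embeds_slices():
    # (total_mm_tokens, num_unseen_mm_tokens, num_mm_tokens) per request
    requests = [(10, 4, 3), (8, 0, 8)]
    current_pos, slices = 0, []
    for total_mm_tokens, num_unseen, num_mm in requests:
        # Keep only the current chunk's mm tokens, skipping already-processed ones
        slices.append((current_pos + num_unseen, current_pos + num_unseen + num_mm))
        current_pos += total_mm_tokens  # global offset into the concatenated tensor
    assert slices == [(4, 7), (10, 18)]
    return slices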