diff --git a/tensorrt_llm/_torch/models/modeling_gemma3vl.py b/tensorrt_llm/_torch/models/modeling_gemma3vl.py index 229c30e40c9..ceb6b01a9ce 100644 --- a/tensorrt_llm/_torch/models/modeling_gemma3vl.py +++ b/tensorrt_llm/_torch/models/modeling_gemma3vl.py @@ -100,8 +100,10 @@ def __call__( "pixel_values"] input_ids = preprocess_outputs[0]["mm_processor_kwargs"]["input_ids"] mm_features = self._process(pixel_values) + multimodal_data = {} + multimodal_data["multimodal_embedding"] = mm_features return input_ids[0].to(torch.int32).tolist(), { - "mm_embedding": mm_features + "multimodal_data": multimodal_data } @@ -161,7 +163,11 @@ def forward( f"[Gemma3Model::forward]{num_context_requests=}, {num_generation_requests=}" ) - mm_embed = kwargs.get("multi_modal_data", []) + multimodal_params = kwargs.get("multimodal_params", []) + mm_embed = [ + multimodal_param.multimodal_data["multimodal_embedding"] + for multimodal_param in multimodal_params + ] assert mm_embed == [] or len( mm_embed ) == num_context_requests, "Number of multimodal features (if provided) should be equal to number of context requests" diff --git a/tensorrt_llm/_torch/models/modeling_hyperclovax.py b/tensorrt_llm/_torch/models/modeling_hyperclovax.py index 7134ceeecb7..9f37759ba03 100644 --- a/tensorrt_llm/_torch/models/modeling_hyperclovax.py +++ b/tensorrt_llm/_torch/models/modeling_hyperclovax.py @@ -1,16 +1,19 @@ import copy import math import os +from functools import partial from itertools import chain from typing import Any, Dict, List, Optional, Tuple, Union -import numpy as np import torch import torch.nn as nn -import transformers -from PIL import Image -from transformers import (AutoProcessor, AutoTokenizer, PretrainedConfig, - PreTrainedModel) +from einops import rearrange +from transformers import (AutoConfig, AutoModel, AutoProcessor, AutoTokenizer, + PretrainedConfig, PreTrainedModel) +from transformers.modeling_utils import load_sharded_checkpoint +from transformers.models.auto import CONFIG_MAPPING + +from tensorrt_llm.inputs.multimodal import MultimodalParams from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt, register_input_processor) @@ -20,6 +23,7 @@ from ..model_config import ModelConfig from .modeling_auto import AutoModelForCausalLM from .modeling_multimodal_utils import fuse_input_embeds +from .modeling_siglip import SiglipVisionModel from .modeling_utils import register_auto_model DISAGG = os.getenv('TLLM_MULTIMODAL_DISAGGREGATED', '0') == '1' @@ -403,6 +407,169 @@ def determine_non_vision_query_lengths(input_ids: torch.LongTensor, pad_id: int, return non_vision_query_lengths +# Copied from HyperCLOVAX-SEED-Vision-Instruct-3B/modeling_hyperclovax.py +class HCXVisionCAbstractor(nn.Module): + """ + This module is based on C-Abstractor, whose license is under apache-2.0. + You can check the original code at https://github.com/khanrc/honeybee/blob/main/honeybee/projectors/projectors.py + and we made necessary modifications. 
+ """ + + def __init__( + self, + num_queries: int, + num_input_tokens: int, + encoder_hidden_size: int, + hidden_size: int, + output_hidden_size: int, + pos_emb: bool = True, + prenorm: bool = False, + ): + super().__init__() + self.num_input_tokens = num_input_tokens + self.output_hidden_size = output_hidden_size + + # Positional embedding + if pos_emb: + self.pos_emb = torch.nn.Parameter( + torch.zeros(1, num_input_tokens, encoder_hidden_size)) + self.pos_emb.data.normal_(mean=0.0, std=0.02) + else: + self.pos_emb = None + + # (Optional) Pre-normalization layer + from timm.layers import LayerNorm + if prenorm: + self.prenorm = LayerNorm(encoder_hidden_size) + else: + self.prenorm = None + + self.build_net(num_queries, encoder_hidden_size, hidden_size, + output_hidden_size) + self.dtype = next(self.parameters()).dtype + + def forward( + self, + x: torch.Tensor, + num_queries_vis_abstractors: Optional[List[List[int]]] = None, + num_grids: Optional[List[int]] = None, + ) -> torch.Tensor: + """ + Args: + x: (B, L, encoder_hidden_size) tensor from the visual backbone (e.g. CLIP visual encoder), including cls token. + """ + if self.prenorm is not None: + x = self.prenorm(x) + + if self.pos_emb is not None: + x = x + self.pos_emb + + x = self._forward( + x, + num_queries_vis_abstractors=num_queries_vis_abstractors, + num_grids=num_grids, + ) # (B, L, output_hidden_size) + + return x + + def _forward( + self, + x: torch.Tensor, + num_queries_vis_abstractors: Optional[List[List[int]]] = None, + num_grids: Optional[List[int]] = None, + ) -> torch.Tensor: + + # x: [B, L, dim] + B, L, dim = x.shape + hw = int(L**0.5) + x = rearrange(x, "b (h w) d -> b d h w", h=hw, w=hw) + + if num_queries_vis_abstractors is not None: + assert num_grids is not None + return self._forward_adaptive_num_query( + x, num_queries_vis_abstractors, num_grids) + + x = self.net(x) + x = rearrange(x, "b d h w -> b (h w) d") + x = self.readout(x) + return x + + def _forward_adaptive_num_query( + self, + x: torch.Tensor, + num_queries_vis_abstractors: Optional[List[List[int]]] = None, + num_grids: Optional[List[int]] = None, + ) -> List[torch.Tensor]: + # self.net is consisted by 3 layers (s1, sampler, s2) + assert len(self.net) == 3 + + x = self.net[0](x) # s1 + new_x = [] + for i, num_queries in enumerate(num_queries_vis_abstractors): + hw = int(num_queries**0.5) + sampler = nn.AdaptiveAvgPool2d((hw, hw)) + out = sampler(x[num_grids[i]:num_grids[i + 1], :]) + out = self.net[2](out) # s2 + + out = rearrange(out, "b d h w -> b (h w) d") + out = self.readout(out) + + new_x.append(out) + return new_x + + def build_net( + self, + n_queries: int, + encoder_hidden_size: int, + hidden_size: int, + output_hidden_size: int, + depth: int = 3, + mlp_depth: int = 2, + ): + assert (n_queries**0.5).is_integer( + ), f"n_queries must be square number. 
n_queries: {n_queries}" + hw = int(n_queries**0.5) + from timm.layers import LayerNorm2d + from timm.models.regnet import RegStage + + # RegBlock = ResBlock + SE + RegBlock = partial( + RegStage, + stride=1, + dilation=1, + act_layer=nn.SiLU, + norm_layer=LayerNorm2d, + ) + + s1 = RegBlock( + depth, + encoder_hidden_size, + hidden_size, + ) + sampler = nn.AdaptiveAvgPool2d((hw, hw)) + s2 = RegBlock( + depth, + hidden_size, + hidden_size, + ) + + self.net = nn.Sequential(s1, sampler, s2) + self.readout = self.build_mlp(mlp_depth, hidden_size, + output_hidden_size) + + def build_mlp( + self, + depth: int, + hidden_size: int, + output_hidden_size: int, + ): + layers = [nn.Linear(hidden_size, output_hidden_size)] + for _ in range(1, depth): + layers.append(nn.SiLU()) + layers.append(nn.Linear(output_hidden_size, output_hidden_size)) + return nn.Sequential(*layers) + + class HCXVisionInputProcessor(InputProcessor): def __init__(self, @@ -423,11 +590,8 @@ def __init__(self, model_path, trust_remote_code=trust_remote_code, use_fast=self.use_fast) - self.tllm_image_token_id = self.pretrained_config.language_config[ + self.tllm_multimodal_token_id = self.pretrained_config.language_config[ "vocab_size"] + 1 - if DISAGG: - self.mm_encoder = HCXVisionModel(self.pretrained_config, - skip_processor=True) def _post_process(self, input_ids: torch.Tensor, @@ -476,7 +640,7 @@ def _post_process(self, batch_idx, input_start + token_len:input_start + token_len + vision_query_lengths[batch_idx][multi_img_idx], - ] = self.tllm_image_token_id + ] = self.tllm_multimodal_token_id input_start += token_len + vision_query_lengths[batch_idx][ multi_img_idx] @@ -531,55 +695,85 @@ def __call__( if not preprocessed_image: return fused_input_ids.to(torch.int32).tolist(), {} - if DISAGG: - mm_embeds = self.mm_encoder.forward(preprocessed_image) - mm_embeds = torch.cat(mm_embeds, dim=0) - else: - # NOTE: For now, I am using "mm_embeding" in tensor format to send the image data to the model. - # CASE 1: Sending raw image data - if isinstance(images[0], Image.Image): - images = [torch.from_numpy(np.array(image)) for image in images] - mm_embeds = torch.stack(images, dim=0) - - # NOTE: After refactoring the llmRequest, we can use preprocessed_image['pixel_values'] to send the image data to the model. 
- # CASE 2: Sending preprocessed image data - # mm_embeds = torch.cat(preprocessed_image['pixel_values'][0], - # dim=0) - + multimodal_data = {} + multimodal_data["image"] = { + "pixel_values": + torch.stack(preprocessed_image['pixel_values'][0], dim=0).to( + torch.bfloat16 + ), #TODO change the pixel_values into the Shared Tensor + "image_sizes": + preprocessed_image.get('image_sizes', None), + "is_videos": + preprocessed_image.get('is_videos', None), + "num_queries_vis_abstractors": + preprocessed_image.get('num_queries_vis_abstractors', None), + "num_queries_vis_abstractors_slow": + preprocessed_image.get('num_queries_vis_abstractors_slow', None), + "first_last_frames_slows": + preprocessed_image.get('first_last_frames_slows', None), + } return fused_input_ids.to(torch.int32).tolist(), { - "mm_embedding": mm_embeds, + "multimodal_data": multimodal_data } class HCXVisionModel: - def __init__(self, - pretrained_config: PretrainedConfig, - skip_processor: bool = False): + def __init__(self, model_config: ModelConfig[PretrainedConfig]): - self.pretrained_config = pretrained_config + self.pretrained_config = model_config.pretrained_config self.vision_config = self.pretrained_config.vision_config model_path = self.pretrained_config._name_or_path - - # TODO: Remove this when we refactor LlmRequest - # NOTE: trust_remote_code can be removed once we refactor LlmRequest - self.skip_processor = skip_processor - if not self.skip_processor: - self.processor = AutoProcessor.from_pretrained( - model_path, trust_remote_code=True, use_fast=True) - - # NOTE: There is no way of importing mm_projector, HCXVisionCAbstractor from HF. So, can not do the sharded_loading. - # NOTE: trust_rmemote_code can be removed once we change the model into TRT-LLM's format - model = transformers.AutoModelForCausalLM.from_pretrained( - model_path, trust_remote_code=True) - model.eval() - self.device = 'cuda' - - # TODO: Convert to TRT-LLM's SIGLIP - self.vision_model = model.vision_model.to(self.device) - self.mm_projector = model.mm_projector.to(self.device) - self.image_newline = model.image_newline.to(self.device) + self.device = f"cuda:{model_config.mapping.rank}" + + hf_model_config = AutoConfig.from_pretrained(model_path, + trust_remote_code=True) + vision_model_type = hf_model_config.vision_config["model_type"] + vision_config = CONFIG_MAPPING[vision_model_type]( + **hf_model_config.vision_config) + self.dtype = vision_config.torch_dtype + module_dict = nn.ModuleDict({ + "vision_model": + AutoModel.from_config(vision_config, trust_remote_code=True), + "mm_projector": + HCXVisionCAbstractor( + num_queries=hf_model_config.num_queries_vis_abstractor, + num_input_tokens=(vision_config.image_size // + vision_config.patch_size)**2, + encoder_hidden_size=vision_config.hidden_size, + hidden_size=vision_config.hidden_size, + output_hidden_size=hf_model_config. 
+ language_config["hidden_size"], + pos_emb=hf_model_config.proj_pos_emb, + prenorm=hf_model_config.proj_prenorm, + ), + }) + + module_dict.register_parameter( + "image_newline", + nn.Parameter( + torch.empty(hf_model_config.language_config["hidden_size"]))) + + missing_keys, _ = load_sharded_checkpoint(module_dict, + model_path, + strict=False) + assert len(missing_keys) == 0, f"Missing keys: {missing_keys}" + hf_vision_model = module_dict["vision_model"].to(self.dtype) + hf_mm_projector = module_dict["mm_projector"].to(self.dtype).to( + self.device) + hf_image_newline = module_dict.image_newline.to(self.dtype).to( + self.device) + + vision_model_config = ModelConfig(pretrained_config=vision_config, + attn_backend="TRTLLM") + + # Model related lines + self.vision_model = SiglipVisionModel(vision_model_config).to( + self.device).to(self.dtype) + self.vision_model.load_weights(hf_vision_model.state_dict()) + self.mm_projector = hf_mm_projector.eval() + self.image_newline = hf_image_newline self.unpad = self.pretrained_config.unpad self.use_nth_layer = self.pretrained_config.use_nth_layer @@ -605,56 +799,49 @@ def _init_possible_resolutions(self, config: PretrainedConfig): return possible_resolutions def _to_device( - self, input_tensor: Union[torch.Tensor, - List]) -> Union[torch.Tensor, List]: - if isinstance(input_tensor, list): + self, input_tensor: Union[torch.Tensor, List, None] + ) -> Union[torch.Tensor, List, None]: + if input_tensor is None: + return None + elif isinstance(input_tensor, list): return [self._to_device(item) for item in input_tensor] elif isinstance(input_tensor, torch.Tensor): return input_tensor.to(self.device) - # TODO: Remove this when we refactor LlmRequuest - def _preprocess(self, mm_data: List[Any]) -> Dict[str, List[Any]]: - preprocessed_image_list = [] - - for images in mm_data: - images = torch.unbind(images, dim=0) - preprocessed_image = self.processor( - images=images, - is_video_list=[False] * len(images), - ) - - # NOTE: The HCXVisionInputProcessor makes pixel_vlues to CPU values even though use_fast = True. - # So, we need to transfer them to GPU. - preprocessed_image["pixel_values"] = self._to_device( - preprocessed_image["pixel_values"]) - - preprocessed_image_list.append(preprocessed_image) - - return { - key: [d[key][0] for d in preprocessed_image_list] - for key in preprocessed_image_list[0].keys() + def _parse_and_batch_multimodal_data( + self, multimodal_params: List[MultimodalParams] + ) -> Tuple[List[torch.Tensor], Dict[str, List[Any]]]: + """Parse and batch multimodal data from MultimodalParams objects.""" + pixel_values = [ + list( + torch.unbind( + multimodal_param.multimodal_data["image"]["pixel_values"], + dim=0)) for multimodal_param in multimodal_params + ] + mm_extra_data = { + key: [ + multimodal_param.multimodal_data["image"][key][0] + for multimodal_param in multimodal_params + ] + for key in multimodal_params[0].multimodal_data["image"].keys() } + return pixel_values, mm_extra_data - def forward(self, mm_data: Union[List[Any], Dict[str, Any]]): - if not self.skip_processor: - # NOTE: This should be done in the input processor and got the preprocessed_image metadata from request level. - # But before refactoring the llmRequest, we are re-doing inputprocessor here. - preprocessed_image = self._preprocess(mm_data) - else: - # NOTE: When we refactor the llmRequest, we will get the extra_mm_data from mm_data, and need to make it as preprocessed_image. 
- preprocessed_image = mm_data - preprocessed_image["pixel_values"] = self._to_device( - preprocessed_image["pixel_values"]) - - pixel_values = preprocessed_image.get("pixel_values", None) - image_sizes = preprocessed_image.get("image_sizes", None) - is_videos = preprocessed_image.get("is_videos", None) - num_queries_vis_abstractors = preprocessed_image.get( + @torch.inference_mode() + def forward(self, multimodal_params: List[MultimodalParams]): + + pixel_values, mm_extra_data = self._parse_and_batch_multimodal_data( + multimodal_params) + pixel_values = self._to_device( + pixel_values) # TODO: remove this once we have the shared tensor + image_sizes = mm_extra_data.get("image_sizes", None) + is_videos = mm_extra_data.get("is_videos", None) + num_queries_vis_abstractors = mm_extra_data.get( "num_queries_vis_abstractors", None) - num_queries_vis_abstractors_slow = preprocessed_image.get( + num_queries_vis_abstractors_slow = mm_extra_data.get( "num_queries_vis_abstractors_slow", None) - first_last_frames_slows = preprocessed_image.get( - "first_last_frames_slows", None) + first_last_frames_slows = mm_extra_data.get("first_last_frames_slows", + None) len_pixel_values = [len(pixel_value) for pixel_value in pixel_values] concat_pixel_values = torch.cat(list(chain(*pixel_values)), @@ -681,15 +868,15 @@ def forward(self, mm_data: Union[List[Any], Dict[str, Any]]): device=concat_pixel_values.device, ) chunk = torch.cat([chunk, dummy], dim=0) - + attn_metadata = self.vision_model.prepare_attn_metadata( + chunk.shape[0]) if self.use_nth_layer == -1: self.vision_model.vision_model.post_layernorm = nn.Identity() - outs = self.vision_model(chunk) - outs = outs.last_hidden_state[:, visual_token_idx:] + outs = self.vision_model(chunk, attn_metadata=attn_metadata) + outs = outs[:, visual_token_idx:] else: - outs = self.vision_model(chunk, output_hidden_states=True) - outs = outs.hidden_states[self.use_nth_layer][:, - visual_token_idx:] + outs = self.vision_model(chunk, attn_metadata=attn_metadata) + outs = outs[self.use_nth_layer][:, visual_token_idx:] image_forward_outs_chunks.append(outs) image_forward_outs = torch.cat(image_forward_outs_chunks, dim=0).to( @@ -701,8 +888,6 @@ def forward(self, mm_data: Union[List[Any], Dict[str, Any]]): if is_videos is not None: is_videos = list(chain(*is_videos)) group_ids = None - image_forward_outs = image_forward_outs.to( - dtype=self.mm_projector.dtype) image_forward_outs = self.mm_projector(image_forward_outs) else: ( @@ -719,9 +904,6 @@ def forward(self, mm_data: Union[List[Any], Dict[str, Any]]): is_videos, first_last_frames_slows, ) - - image_forward_outs = image_forward_outs.to( - dtype=self.mm_projector.dtype) image_forward_outs = self.mm_projector( image_forward_outs, num_queries_vis_abstractors=num_queries_vis_abstractors, @@ -790,16 +972,14 @@ def __init__(self, model_config: ModelConfig): self.model_config = model_config if hasattr(self, "llm"): return - if not DISAGG: - self.mm_encoder = HCXVisionModel(model_config.pretrained_config) - + self.mm_encoder = HCXVisionModel(model_config) llm_model_config = copy.deepcopy(model_config) llm_model_config.pretrained_config = PretrainedConfig.from_dict( llm_model_config.pretrained_config.language_config) self.llm = AutoModelForCausalLM.from_config(llm_model_config) - self.model_dtype = getattr(config, "torch_dtype", torch.float16) + self.model_dtype = getattr(config, "torch_dtype", torch.bfloat16) logger.info(f"{self.dtype=} {self.model_dtype=}") self.post_config() self.is_loaded = True @@ -843,17 +1023,19 @@ def 
forward( f"num_context_requests: {num_context_requests}, num_generation_requests: {num_generation_requests}" ) - mm_data = kwargs.get("multi_modal_data", []) + multimodal_params = kwargs.get("multimodal_params", []) mm_embeds = [] - if len(mm_data) > 0: - assert len( - mm_data - ) == num_context_requests, f"Number of multimodal tensors ({len(mm_data)}) should be equal to number of context requests ({num_context_requests}) in the batch." - if DISAGG: - # NOTE: In the DISAGG, we are assuming we get the mm_embeds from the llmRequest. - mm_embeds = mm_data + if len(multimodal_params) > 0: + assert len(multimodal_params) == num_context_requests == len( + multimodal_params + ), f"Number of multimodal tensors ({len(multimodal_params)}) should be equal to number of context requests ({num_context_requests}) in the batch." + if not DISAGG: + mm_embeds = self.mm_encoder.forward(multimodal_params) else: - mm_embeds = self.mm_encoder.forward(mm_data) + mm_embeds = [ + multimodal_param.multimodal_data["multimodal_embedding"] + for multimodal_param in multimodal_params + ] input_ids, input_embeds = fuse_input_embeds(self.llm.model.embed_tokens, input_ids, mm_embeds) diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py index 8bd0cd80e6c..8f2e2696609 100644 --- a/tensorrt_llm/_torch/models/modeling_llama.py +++ b/tensorrt_llm/_torch/models/modeling_llama.py @@ -851,7 +851,10 @@ def __call__( mm_embeds = self.encoder.multi_modal_projector(mm_embeds) # for fuse_input_embeds token_ids[token_ids == self.image_token_index] = self.vocab_size + 1 - return token_ids.tolist(), {"mm_embedding": mm_embeds} + + multimodal_data = {} + multimodal_data["multimodal_embedding"] = mm_embeds + return token_ids.tolist(), {"multimodal_data": multimodal_data} else: return processed["input_ids"].squeeze().tolist(), {} @@ -882,8 +885,12 @@ def forward( spec_metadata: Optional[SpecMetadata] = None, **kwargs, ) -> torch.Tensor: - mm_embed = kwargs.get("multi_modal_data", []) - if mm_embed: + multimodal_params = kwargs.get("multimodal_params", []) + if multimodal_params: + mm_embed = [ + multimodal_param.multimodal_data["multimodal_embedding"] + for multimodal_param in multimodal_params + ] _, inputs_embeds = fuse_input_embeds(self.model.embed_tokens, input_ids, mm_embed) return super().forward(attn_metadata, diff --git a/tensorrt_llm/_torch/models/modeling_llava_next.py b/tensorrt_llm/_torch/models/modeling_llava_next.py index 967555dcb38..851b350b363 100644 --- a/tensorrt_llm/_torch/models/modeling_llava_next.py +++ b/tensorrt_llm/_torch/models/modeling_llava_next.py @@ -210,8 +210,10 @@ def __call__( mm_features = torch.stack( [self._process(tensor) for tensor in mm_tensor]) fused_input_ids, mm_features = self._postprocess(input_ids, mm_features) + multimodal_data = {} + multimodal_data["multimodal_embedding"] = mm_features return fused_input_ids.to(torch.int32).tolist(), { - "mm_embedding": mm_features + "multimodal_data": multimodal_data } @@ -271,7 +273,11 @@ def forward( num_context_requests, num_generation_requests = attn_metadata.num_contexts, attn_metadata.num_generations logger.debug(f"{num_context_requests=}, {num_generation_requests=}") - mm_embed = kwargs.get("multi_modal_data", []) + multimodal_params = kwargs.get("multimodal_params", []) + mm_embed = [ + multimodal_param.multimodal_data["multimodal_embedding"] + for multimodal_param in multimodal_params + ] assert mm_embed == [] or len( mm_embed ) == num_context_requests, "Number of multimodal features (if provided) 
should be equal to number of context requests" diff --git a/tensorrt_llm/_torch/models/modeling_qwen2vl.py b/tensorrt_llm/_torch/models/modeling_qwen2vl.py index c52b260e4cc..1f9bfb80325 100644 --- a/tensorrt_llm/_torch/models/modeling_qwen2vl.py +++ b/tensorrt_llm/_torch/models/modeling_qwen2vl.py @@ -1,5 +1,6 @@ import copy -from typing import Any, Dict, List, Optional, Tuple +import os +from typing import Any, Dict, List, Optional, Tuple, Union import torch from transformers import (AutoProcessor, AutoTokenizer, PretrainedConfig, @@ -7,6 +8,8 @@ Qwen2VLForConditionalGeneration) from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize +from tensorrt_llm.inputs.multimodal import MultimodalParams + from ...functional import RopeEmbeddingUtils, RotaryScalingType from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt, register_input_processor) @@ -18,6 +21,8 @@ from .modeling_multimodal_utils import fuse_input_embeds from .modeling_utils import register_auto_model +DISAGG = os.getenv('TLLM_MULTIMODAL_DISAGGREGATED', '0') == '1' + class Qwen2VLInputProcessorBase(InputProcessor): @@ -28,25 +33,17 @@ def __init__(self, trust_remote_code: bool = True): self.model_config = model_config self.tokenizer = tokenizer + # TODO: change to True and also change the according test result self.use_fast = False + self.device = 'cuda' self.processor = AutoProcessor.from_pretrained( model_path, use_fast=self.use_fast, trust_remote_code=trust_remote_code) - # NOTE: Using attn_implementation='flash_attention_2' to avoid the issue of vision model's GPU OOM. - model = self.get_model_class().from_pretrained( - model_path, - torch_dtype=model_config.torch_dtype, - attn_implementation='flash_attention_2') - self.device = 'cuda' - self.visual = model.visual.to(self.device) + self.tllm_multimodal_token_id = self.model_config.vocab_size + 1 self._post_init_() - @classmethod - def get_model_class(cls) -> type[PreTrainedModel]: - raise NotImplementedError() - @classmethod def get_rope_index( cls, @@ -284,34 +281,13 @@ def _preprocess(self, text: dict[str, any], mm_data: dict[str, any], return_tensors='pt', **mm_processor_kwargs) - def _process(self, pixel_values: torch.Tensor, - pixel_values_videos: torch.Tensor, - image_grid_thw: torch.Tensor, - video_grid_thw: torch.Tensor) -> torch.Tensor: - embeds = [] - - if pixel_values is not None: - pixel_values = pixel_values.to(self.visual.dtype) - embeds.append(self.visual(pixel_values, grid_thw=image_grid_thw)) - - if pixel_values_videos is not None: - pixel_values_videos = pixel_values_videos.to(self.visual.dtype) - embeds.append( - self.visual(pixel_values_videos, grid_thw=video_grid_thw)) - - if embeds: - return torch.cat(embeds, dim=1) - return None - def _postprocess(self, input_ids: torch.IntTensor) -> torch.IntTensor: - # NOTE: Qwen2-VL's input processor is doing all the work for fusing input_ids with mm_tokens. So, we just replace mm_tokens with expanded out-of-vocab ids - + # NOTE: Qwen2-VL's input processor is doing all the work for fusing input_ids with mm_tokens. 
+ # So, we just replace mm_tokens with expanded out-of-vocab ids masks = (input_ids == self.model_config.image_token_id) | ( input_ids == self.model_config.vision_token_id) | ( input_ids == self.model_config.video_token_id) - cumulative_counts = masks.cumsum(dim=-1) - values = (self.model_config.vocab_size - 1) + cumulative_counts - input_ids[masks] = values[masks] + input_ids[masks] = self.tllm_multimodal_token_id return input_ids def get_mrope_config( @@ -349,7 +325,8 @@ def get_mrope_config( concat_cos_sin = concat_cos_sin.reshape(concat_cos_sin.shape[0], -1) mrope_config = {} mrope_config['mrope_rotary_cos_sin'] = concat_cos_sin.to('cpu') - mrope_config['mrope_position_deltas'] = mrope_position_deltas.to('cpu') + mrope_config['mrope_position_deltas'] = mrope_position_deltas.to( + 'cpu').to(torch.int32) return mrope_config @torch.inference_mode() @@ -363,43 +340,147 @@ def __call__( processed_inputs = self._preprocess(text_prompt, mm_data, mm_processor_kwargs).to(self.device) - if mm_data: - mm_features = self._process( - processed_inputs.get('pixel_values', None), - processed_inputs.get('pixel_values_videos', None), - processed_inputs.get('image_grid_thw', None), - processed_inputs.get('video_grid_thw', None)) - else: - mm_features = None - input_ids = processed_inputs['input_ids'] + if not mm_data: + fused_input_ids = processed_inputs['input_ids'] + return fused_input_ids.to(torch.int32).tolist(), {} + + pixel_values = processed_inputs.get('pixel_values', None) + pixel_values_videos = processed_inputs.get('pixel_values_videos', None) + assert pixel_values is not None or pixel_values_videos is not None, "No multimodal data found" + multimodal_data = {} + if pixel_values is not None: + multimodal_data["image"] = { + "pixel_values": pixel_values, + "image_grid_thw": processed_inputs.get('image_grid_thw') + } + if pixel_values_videos is not None: + multimodal_data["video"] = { + "pixel_values_videos": pixel_values_videos, + "video_grid_thw": processed_inputs.get('video_grid_thw') + } + + input_ids = processed_inputs['input_ids'] + # TODO: We can move this to the LLM-side. mrope_config = self.get_mrope_config( input_ids, processed_inputs.get('image_grid_thw', None), processed_inputs.get('video_grid_thw', None), processed_inputs.get('attention_mask', None), processed_inputs.get('second_per_grid_ts', None)) + multimodal_data["mrope_config"] = mrope_config fused_input_ids = self._postprocess(input_ids[0]) return fused_input_ids.to(torch.int32).tolist(), { - "mm_embedding": mm_features, - "mrope_config": mrope_config + "multimodal_data": multimodal_data, } -class Qwen2VLInputProcessor(Qwen2VLInputProcessorBase): +class Qwen2VisionModelBase: - @classmethod - def get_model_class(cls): - return Qwen2VLForConditionalGeneration + def __init__(self, model_config: ModelConfig[PretrainedConfig], + model_class: type[PreTrainedModel]): + self.pretrained_config = model_config.pretrained_config + self.device = f"cuda:{model_config.mapping.rank}" + model_path = self.pretrained_config._name_or_path + # TODO: Change the model class to TRT-LLM's Qwen2VisionModel + # Currently, copying vision encoder on all devices. + # NOTE: Using attn_implementation='flash_attention_2' to avoid the issue of vision model's GPU OOM. 
+ model = model_class.from_pretrained( + model_path, + torch_dtype=self.pretrained_config.torch_dtype, + attn_implementation='flash_attention_2').eval() + self.visual = model.visual.to(self.device) -class Qwen2_5_VLInputProcessor(Qwen2VLInputProcessorBase): + def _to_device( + self, input_tensor: Union[torch.Tensor, List, None] + ) -> Union[torch.Tensor, List, None]: + if input_tensor is None: + return None + elif isinstance(input_tensor, list): + return [self._to_device(item) for item in input_tensor] + elif isinstance(input_tensor, torch.Tensor): + return input_tensor.to(self.device) + + def _parse_and_batch_multimodal_data( + self, multimodal_params: List[MultimodalParams] + ) -> Tuple[Dict[str, Any], Dict[str, List[Any]]]: + + pixel_values_list = [] + pixel_values_videos_list = [] + image_grid_thw_list = [] + video_grid_thw_list = [] + + for multimodal_param in multimodal_params: + # Process images if present + if multimodal_param.multimodal_data.get("image") is not None: + pixel_values_list.append( + multimodal_param.multimodal_data["image"]["pixel_values"]) + image_grid_thw_list.append( + multimodal_param.multimodal_data["image"]["image_grid_thw"]) + + # Process videos if present + if multimodal_param.multimodal_data.get("video") is not None: + pixel_values_videos_list.append( + multimodal_param.multimodal_data["video"] + ["pixel_values_videos"]) + video_grid_thw_list.append( + multimodal_param.multimodal_data["video"]["video_grid_thw"]) + + # Concatenate tensors + mm_content_dict = {} + if pixel_values_list: + mm_content_dict["pixel_values"] = torch.cat( + pixel_values_list, + dim=0) if len(pixel_values_list) > 1 else pixel_values_list[0] + if pixel_values_videos_list: + mm_content_dict["pixel_values_videos"] = torch.cat( + pixel_values_videos_list, + dim=0) if len(pixel_values_videos_list + ) > 1 else pixel_values_videos_list[0] + + # Prepare extra data + mm_extra_data = {} + if image_grid_thw_list: + mm_extra_data["image_grid_thw"] = torch.cat( + image_grid_thw_list, dim=0) if len( + image_grid_thw_list) > 1 else image_grid_thw_list[0] + if video_grid_thw_list: + mm_extra_data["video_grid_thw"] = torch.cat( + video_grid_thw_list, dim=0) if len( + video_grid_thw_list) > 1 else video_grid_thw_list[0] + + return mm_content_dict, mm_extra_data - @classmethod - def get_model_class(cls): - return Qwen2_5_VLForConditionalGeneration + @torch.inference_mode() + def forward(self, multimodal_params: List[MultimodalParams]): + + mm_content_data, mm_extra_data = self._parse_and_batch_multimodal_data( + multimodal_params) + pixel_values = mm_content_data.get("pixel_values", None) + pixel_values_videos = mm_content_data.get("pixel_values_videos", None) + + image_grid_thw = mm_extra_data.get("image_grid_thw", None) + video_grid_thw = mm_extra_data.get("video_grid_thw", None) + + embeds = [] + if pixel_values is not None: + pixel_values = self._to_device( + pixel_values + ) # TODO: remove this once we have the shared tensor + image_grid_thw = self._to_device(image_grid_thw) + pixel_values = pixel_values.to(self.visual.dtype) + embeds.append(self.visual(pixel_values, grid_thw=image_grid_thw)) + + if pixel_values_videos is not None: + pixel_values_videos = self._to_device(pixel_values_videos) + video_grid_thw = self._to_device(video_grid_thw) + pixel_values_videos = pixel_values_videos.to(self.visual.dtype) + embeds.append( + self.visual(pixel_values_videos, grid_thw=video_grid_thw)) + return embeds class Qwen2VLModelBase(PreTrainedModel): @@ -413,7 +494,7 @@ def __init__( 
model_config.pretrained_config.rope_scaling['type'] = 'mrope' config = model_config.pretrained_config - assert model_config.attn_backend == 'TRTLLM', "Qwen2VL only supports TRTLLM backend now" + assert model_config.attn_backend == 'TRTLLM', "Qwen2/2.5-VL only supports TRTLLM backend now" super().__init__(config) self.model_config = model_config @@ -440,6 +521,48 @@ def post_config(self): self.config = self.llm.config self.model_config.pretrained_config = self.llm.config + def _parse_and_concat_mrope_config( + self, multimodal_params: List[MultimodalParams], + num_context_requests: int, + num_generation_requests: int) -> dict[str, torch.Tensor]: + """ + Parse and concatenate mrope configuration from multimodal parameters. + """ + + mrope_configs = [ + param.multimodal_data.get('mrope_config') + for param in multimodal_params if param.multimodal_data + and param.multimodal_data.get('mrope_config') + ] + if not mrope_configs: + return {} + + batched_mrope_config = {} + if num_context_requests > 0: + cos_sin_tensors = [ + config['mrope_rotary_cos_sin'] + for config in mrope_configs[:num_context_requests] + if config.get('mrope_rotary_cos_sin') is not None + ] + if cos_sin_tensors: + batched_mrope_config['mrope_rotary_cos_sin'] = torch.cat( + cos_sin_tensors, dim=0) + + if num_generation_requests > 0: + generation_mrope_configs = mrope_configs[ + -num_generation_requests:] if len( + mrope_configs) >= num_generation_requests else mrope_configs + position_delta_tensors = [ + config['mrope_position_deltas'] + for config in generation_mrope_configs + if config.get('mrope_position_deltas') is not None + ] + if position_delta_tensors: + batched_mrope_config['mrope_position_deltas'] = torch.cat( + position_delta_tensors, dim=0) + + return batched_mrope_config + @torch.inference_mode() def forward( self, @@ -458,26 +581,25 @@ def forward( f"num_context_requests: {num_context_requests}, num_generation_requests: {num_generation_requests}" ) - mm_embed = kwargs.get("multi_modal_data", []) + multimodal_params = kwargs.get("multimodal_params", []) + mm_embeds = [] + mrope_config = {} - error_msg = "Number of multimodal features (if provided) should be equal to number of context requests" - assert mm_embed == [] or len( - mm_embed) == num_context_requests, error_msg + if len(multimodal_params) > 0: + if not DISAGG: + mm_embeds = self.mm_encoder.forward( + multimodal_params[:num_context_requests]) + else: + mm_embeds = [ + multimodal_param.multimodal_data["multimodal_embedding"] + for multimodal_param in multimodal_params + ] + mrope_config = self._parse_and_concat_mrope_config( + multimodal_params, num_context_requests, + num_generation_requests) input_ids, input_embeds = fuse_input_embeds(self.llm.model.embed_tokens, - input_ids, mm_embed) - - mrope_config = kwargs.get("mrope_config", {}) - if mrope_config: - if mrope_rotary_cos_sin := mrope_config.get('mrope_rotary_cos_sin'): - mrope_config['mrope_rotary_cos_sin'] = torch.cat( - mrope_rotary_cos_sin, dim=0) - - if mrope_position_deltas := mrope_config.get( - 'mrope_position_deltas'): - mrope_config['mrope_position_deltas'] = torch.cat( - mrope_position_deltas, dim=0) - + input_ids, mm_embeds) output_prob = self.llm.forward( attn_metadata=attn_metadata, input_ids=input_ids, @@ -485,17 +607,30 @@ def forward( inputs_embeds=input_embeds, return_context_logits=return_context_logits, mrope_config=mrope_config) + logger.debug(f'output shape: {output_prob.shape}') return output_prob @register_auto_model("Qwen2VLForConditionalGeneration") 
-@register_input_processor(Qwen2VLInputProcessor, model_type="qwen2_vl") +@register_input_processor(Qwen2VLInputProcessorBase, model_type="qwen2_vl") class Qwen2VLModel(Qwen2VLModelBase): - pass + + def __init__(self, model_config: ModelConfig[PretrainedConfig], *args, + **kwargs): + if not DISAGG: + self.mm_encoder = Qwen2VisionModelBase( + model_config, Qwen2VLForConditionalGeneration) + super().__init__(model_config, *args, **kwargs) @register_auto_model("Qwen2_5_VLForConditionalGeneration") -@register_input_processor(Qwen2_5_VLInputProcessor, model_type="qwen2_5_vl") +@register_input_processor(Qwen2VLInputProcessorBase, model_type="qwen2_5_vl") class Qwen2_5_VLModel(Qwen2VLModelBase): - pass + + def __init__(self, model_config: ModelConfig[PretrainedConfig], *args, + **kwargs): + if not DISAGG: + self.mm_encoder = Qwen2VisionModelBase( + model_config, Qwen2_5_VLForConditionalGeneration) + super().__init__(model_config, *args, **kwargs) diff --git a/tensorrt_llm/_torch/models/modeling_vila.py b/tensorrt_llm/_torch/models/modeling_vila.py index 217ec9c388a..c27a88abf5f 100644 --- a/tensorrt_llm/_torch/models/modeling_vila.py +++ b/tensorrt_llm/_torch/models/modeling_vila.py @@ -1107,8 +1107,10 @@ def __call__( ) # use_fast uses Pytorch GPU preprocessing, otherwise uses PIL CPU preprocessing mm_features = self._process(mm_tensor, block_sizes) fused_input_ids, mm_features = self._postprocess(input_ids, mm_features) + multimodal_data = {} + multimodal_data["multimodal_embedding"] = mm_features return fused_input_ids.to(torch.int32).tolist(), { - "mm_embedding": mm_features + "multimodal_data": multimodal_data } @@ -1161,7 +1163,11 @@ def forward( """ num_context_requests, num_generation_requests = attn_metadata.num_contexts, attn_metadata.num_generations - mm_embed = kwargs.get("multi_modal_data", []) + multimodal_params = kwargs.get("multimodal_params", []) + mm_embed = [ + multimodal_param.multimodal_data["multimodal_embedding"] + for multimodal_param in multimodal_params + ] assert mm_embed == [] or len( mm_embed diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index 8e33920f6ef..461c5de941e 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -281,6 +281,8 @@ def __init__( **kwargs): self.py_logits_post_processors = kwargs.pop("py_logits_post_processors", None) + # Multimodal data + self.py_multimodal_data = kwargs.pop("py_multimodal_data", None) super().__init__( *args, client_id=client_id, @@ -400,6 +402,22 @@ def executor_request_to_llm_request( stop_words_list = convert_wordlist( executor_request.stop_words) if executor_request.stop_words else None + # Extract multimodal fields from executor request + multimodal_hashes = None + multimodal_positions = None + multimodal_lengths = None + if executor_request.multimodal_input is not None: + multimodal_hashes = executor_request.multimodal_input.multimodal_hashes + multimodal_positions = executor_request.multimodal_input.multimodal_positions + multimodal_lengths = executor_request.multimodal_input.multimodal_lengths + + # Extract mrope fields + mrope_rotary_cos_sin = None + mrope_position_deltas = None + if executor_request.mrope_config is not None: + mrope_rotary_cos_sin = executor_request.mrope_config.mrope_rotary_cos_sin + mrope_position_deltas = executor_request.mrope_config.mrope_position_deltas + llm_request = LlmRequest( request_id=req_id, max_new_tokens=executor_request.max_tokens, @@ -419,24 +437,18 @@ def 
executor_request_to_llm_request( is None else executor_request.prompt_tuning_config.embedding_table, prompt_vocab_size=None if executor_request.prompt_tuning_config is None else executor_request.prompt_tuning_config.embedding_table.shape[0], - multimodal_hashes=None if executor_request.multimodal_input is None else - executor_request.multimodal_input.multimodal_hashes, - multimodal_positions=None if executor_request.multimodal_input is None - else executor_request.multimodal_input.multimodal_positions, - multimodal_lengths=None if executor_request.multimodal_input is None - else executor_request.multimodal_input.multimodal_lengths, - multimodal_embedding=None if executor_request.multimodal_embedding - is None else executor_request.multimodal_embedding, + multimodal_hashes=multimodal_hashes, + multimodal_positions=multimodal_positions, + multimodal_lengths=multimodal_lengths, + multimodal_embedding=executor_request.multimodal_embedding, lora_task_id=executor_request.lora_config.task_id if executor_request.lora_config is not None else None, lora_weights=executor_request.lora_config.weights if executor_request.lora_config is not None else None, lora_config=executor_request.lora_config.config if executor_request.lora_config is not None else None, - mrope_rotary_cos_sin=None if executor_request.mrope_config is None else - executor_request.mrope_config.mrope_rotary_cos_sin, - mrope_position_deltas=None if executor_request.mrope_config is None else - executor_request.mrope_config.mrope_position_deltas, + mrope_rotary_cos_sin=mrope_rotary_cos_sin, + mrope_position_deltas=mrope_position_deltas, lookahead_config=None, return_log_probs=executor_request.output_config.return_log_probs, return_context_logits=executor_request.output_config. @@ -460,6 +472,7 @@ def executor_request_to_llm_request( if executor_request.client_id is not None else req_id, priority=0.5, llm_request_type=llm_request_type, - context_phase_params=executor_request.context_phase_params) - + context_phase_params=executor_request.context_phase_params, + py_multimodal_data=getattr(executor_request, "py_multimodal_data", + None)) return llm_request diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 7603e724c70..673ac0c6a2c 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -10,7 +10,6 @@ import traceback import weakref from abc import ABC, abstractmethod -from collections import defaultdict from contextlib import contextmanager from typing import Any, Dict, List, Optional, Tuple @@ -27,6 +26,7 @@ local_mpi_size, nvtx_range, release_gc, torch_dtype_to_str, trace_func) from tensorrt_llm.bindings.executor import GuidedDecodingConfig +from tensorrt_llm.inputs.multimodal import MultimodalParams from tensorrt_llm.logger import logger from tensorrt_llm.lora_manager import LoraConfig, LoraModelConfig from tensorrt_llm.mapping import Mapping @@ -1176,10 +1176,9 @@ def _prepare_tp_inputs( gather_ids = [] position_ids = [] # per sequence num_cached_tokens_per_seq = [] # per sequence - multi_modal_data = [] draft_tokens = [] draft_lens = [] - mrope_config = defaultdict(list) + multimodal_params_list = [] gen_request_seq_slots = [] # per generation request for request in scheduled_requests.context_requests: @@ -1197,19 +1196,17 @@ def _prepare_tp_inputs( prompt_lengths.append(len(prompt_tokens)) past_seen_token_num = begin_compute num_cached_tokens_per_seq.append(past_seen_token_num) - multimodal_embedding = 
request.multimodal_embedding - if multimodal_embedding is not None: - multimodal_embedding = multimodal_embedding.pin_memory( - ) if multimodal_embedding.device == 'cpu' else multimodal_embedding - multi_modal_data.append( - multimodal_embedding.to('cuda', non_blocking=True)) - - mrope_rotary_cos_sin = request.mrope_rotary_cos_sin - if mrope_rotary_cos_sin is not None: - mrope_rotary_cos_sin = mrope_rotary_cos_sin.pin_memory( - ) if mrope_rotary_cos_sin.device == 'cpu' else mrope_rotary_cos_sin - mrope_config['mrope_rotary_cos_sin'].append( - mrope_rotary_cos_sin.to('cuda', non_blocking=True)) + + # Multimodal + multimodal_params = MultimodalParams( + multimodal_data=request.py_multimodal_data) + multimodal_params.to_device("multimodal_data", + "cuda", + pin_memory=True) + + if multimodal_params.has_content(): + multimodal_params_list.append(multimodal_params) + request.py_batch_idx = request.seq_slot num_ctx_requests = len(scheduled_requests.context_requests) @@ -1237,14 +1234,15 @@ def _prepare_tp_inputs( extend_requests.append(request) else: generation_requests.append(request) - - mrope_position_deltas = request.mrope_position_deltas - if mrope_position_deltas is not None: - mrope_position_deltas = torch.tensor([mrope_position_deltas], - dtype=torch.int32, - pin_memory=True) - mrope_config['mrope_position_deltas'].append( - mrope_position_deltas.to('cuda', non_blocking=True)) + # Multimodal + multimodal_params = MultimodalParams( + multimodal_data=request.py_multimodal_data) + multimodal_params.strip_for_generation() + multimodal_params.to_device("multimodal_data", + "cuda", + pin_memory=True) + if multimodal_params.has_content(): + multimodal_params_list.append(multimodal_params) extend_requests += extend_dummy_requests if not self._disable_overlap_scheduler and self.is_spec_decode: @@ -1494,8 +1492,7 @@ def previous_seq_slots_device(): 'position_ids': self.position_ids_cuda[:total_num_tokens].unsqueeze(0), 'inputs_embeds': None, - 'multi_modal_data': multi_modal_data, - 'mrope_config': mrope_config + "multimodal_params": multimodal_params_list, } if bool(lora_params): diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index ab501c85be5..bceb97bb55e 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -1158,8 +1158,8 @@ def _update_new_active_requests_queue_latency( def _broadcast_new_requests( self, new_requests: List[RequestQueueItem], - py_request_objects: Optional[tuple[str, dict]] = None, - ) -> tuple[List[RequestQueueItem], Optional[tuple[str, dict]]]: + py_request_objects: Optional[dict[str, tuple[str, dict]]] = None, + ) -> tuple[List[RequestQueueItem], Optional[dict[str, tuple[str, dict]]]]: """Broadcasts new_requests and optional Python-only metadata (`py_request_objects`) across pipeline stages. `py_request_objects` is a tuple of (attribute_name, {request_id: object}). 
""" @@ -1206,8 +1206,12 @@ def _fetch_new_requests(self) -> List[RequestQueueItem]: total_max_num_active_requests - total_num_active_requests) if self.dist.rank == 0: - py_request_objects = self._collect_py_objects_from_requests( + py_logits_post_processors = self._collect_py_objects_from_requests( new_requests, "py_logits_post_processors") + py_multimodal_data = self._collect_py_objects_from_requests( + new_requests, "py_multimodal_data") + py_request_objects = tuple( + filter(None, [py_logits_post_processors, py_multimodal_data])) else: py_request_objects = None @@ -1234,9 +1238,9 @@ def _fetch_new_requests(self) -> List[RequestQueueItem]: if py_request_objects and (self.dist.tp_size > 1 or self.dist.has_pp) and self.dist.rank > 0: - attr_name, req_obj_dict = py_request_objects - self._attach_py_objects_to_requests(new_requests, attr_name, - req_obj_dict) + for attr_name, req_obj_dict in py_request_objects: + self._attach_py_objects_to_requests(new_requests, attr_name, + req_obj_dict) if not self.enable_attention_dp: self._update_new_active_requests_queue_latency(new_requests) diff --git a/tensorrt_llm/executor/executor.py b/tensorrt_llm/executor/executor.py index fe218ecdcd5..ec4fc414cd2 100644 --- a/tensorrt_llm/executor/executor.py +++ b/tensorrt_llm/executor/executor.py @@ -13,7 +13,7 @@ import numpy as np import torch -from tensorrt_llm.inputs.multimodal import MultimodalInput +from tensorrt_llm.inputs.multimodal import MultimodalParams from tensorrt_llm.logger import logger, set_level from tensorrt_llm.lora_manager import LoraConfig @@ -109,20 +109,17 @@ def abort_request(self, request_id: int) -> None: pass def generate_async( - self, - prompt_token_ids: List[int], - sampling_params: SamplingParams, - query_token_ids: Optional[Union[torch.Tensor, np.ndarray, - list]] = None, - lora_request: Optional[LoRARequest] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - streaming: bool = False, - multimodal_input: Optional[MultimodalInput] = None, - multimodal_embedding: Optional[list] = None, - mrope_config: Optional[dict] = None, - kv_cache_retention_config: Optional[KvCacheRetentionConfig] = None, - disaggregated_params: Optional[DisaggregatedParams] = None, - postproc_params: Optional[PostprocParams] = None + self, + prompt_token_ids: List[int], + sampling_params: SamplingParams, + query_token_ids: Optional[Union[torch.Tensor, np.ndarray, list]] = None, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + streaming: bool = False, + kv_cache_retention_config: Optional[KvCacheRetentionConfig] = None, + disaggregated_params: Optional[DisaggregatedParams] = None, + postproc_params: Optional[PostprocParams] = None, + multimodal_params: Optional[MultimodalParams] = None, ) -> GenerationResult: """Generate output for the given prompt token ids in the asynchronous mode. Asynchronous generation accepts single prompt only. 
@@ -144,11 +141,9 @@ def generate_async( lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, streaming=streaming, - multimodal_input=multimodal_input, - multimodal_embedding=multimodal_embedding, - mrope_config=mrope_config, kv_cache_retention_config=kv_cache_retention_config, - disaggregated_params=disaggregated_params)) + disaggregated_params=disaggregated_params, + multimodal_params=multimodal_params)) return result def generate( diff --git a/tensorrt_llm/executor/request.py b/tensorrt_llm/executor/request.py index 933c9f43510..655af261042 100644 --- a/tensorrt_llm/executor/request.py +++ b/tensorrt_llm/executor/request.py @@ -5,7 +5,7 @@ import numpy as np import torch -from tensorrt_llm.inputs.multimodal import MultimodalInput +from tensorrt_llm.inputs.multimodal import MultimodalParams from ..disaggregated_params import DisaggregatedParams from ..llmapi.llm_utils import KvCacheRetentionConfig @@ -82,12 +82,10 @@ def __init__( lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, streaming: bool = False, - multimodal_input: Optional[MultimodalInput] = None, - multimodal_embedding: Optional[list] = None, - mrope_config: Optional[dict] = None, kv_cache_retention_config: Optional[KvCacheRetentionConfig] = None, disaggregated_params: Optional[DisaggregatedParams] = None, postproc_params: Optional[PostprocParams] = None, + multimodal_params: Optional[MultimodalParams] = None, ): if isinstance(prompt_token_ids, list): self.prompt_token_ids = prompt_token_ids @@ -106,9 +104,7 @@ def __init__( self.lora_request = lora_request self.prompt_adapter_request = prompt_adapter_request self.streaming = streaming - self.multimodal_input = multimodal_input - self.multimodal_embedding = multimodal_embedding - self.mrope_config = mrope_config + self.multimodal_params = multimodal_params self.kv_cache_retention_config = kv_cache_retention_config self.id: Optional[int] = None self.disaggregated_params = disaggregated_params diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index eeed86283d7..da90fc8fe93 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -376,11 +376,6 @@ def _enqueue_request(self, request: GenerationRequest) -> int: prompt_token_ids = copy.deepcopy(request.prompt_token_ids) prompt_tuning_config = None - multimodal_embedding = None - mrope_config = None - multimodal_input = None - if request.multimodal_embedding is not None: - multimodal_embedding = request.multimodal_embedding if request.prompt_adapter_request is not None: self._load_prompt_adapter(request.prompt_adapter_request) uid = str(request.prompt_adapter_request.adapter_id) @@ -391,15 +386,22 @@ def _enqueue_request(self, request: GenerationRequest) -> int: prompt_token_ids = list(range( vocab_size, vocab_size + pa_length)) + prompt_token_ids - if request.mrope_config is not None: - mrope_config = tllm.MropeConfig(**request.mrope_config) - - if request.multimodal_input is not None: - multimodal_input = tllm.MultimodalInput( - multimodal_hashes=request.multimodal_input.multimodal_hashes, - multimodal_positions=request.multimodal_input. - multimodal_positions, - multimodal_lengths=request.multimodal_input.multimodal_lengths) + # MULTIMODAL + # NOTE: Since, we only support PyTorch backend for multimodal, we will send multimodal_data through the 'py_multimodal_data' field + # except `multimodal_input` as it needs to go through the C++ runtime. 
+ multimodal_input = None + if request.multimodal_params is not None and request.multimodal_params.has_content( + ): + if request.multimodal_params.multimodal_input is not None: + multimodal_input = tllm.MultimodalInput( + multimodal_hashes=request.multimodal_params. + multimodal_input.multimodal_hashes, + multimodal_positions=request.multimodal_params. + multimodal_input.multimodal_positions, + multimodal_lengths=request.multimodal_params. + multimodal_input.multimodal_lengths) + # NOTE: Setting to None here to avoid sending multimodal_input again through the 'py_multimodal_data' field + request.multimodal_params.multimodal_input = None context_phase_params = None request_type = tllm.RequestType.REQUEST_TYPE_CONTEXT_AND_GENERATION @@ -467,8 +469,9 @@ def _deduce_max_tokens(request: GenerationRequest, lora_config=lora_config, prompt_tuning_config=prompt_tuning_config, multimodal_input=multimodal_input, - multimodal_embedding=multimodal_embedding, - mrope_config=mrope_config, + #NOTE: `multimodal_embedding` and `mrope_config` will be in MultimodalParams.multimodal_data. And this will be handled below by `py_multimodal_data`. + multimodal_embedding=None, + mrope_config=None, logits_post_processor_name=( tllm.Request.BATCHED_POST_PROCESSOR_NAME if request.sampling_params.apply_batched_logits_processor @@ -479,6 +482,10 @@ def _deduce_max_tokens(request: GenerationRequest, context_phase_params=context_phase_params, type=request_type) + if self._is_pytorch_backend and request.multimodal_params is not None: + if request.multimodal_params.multimodal_data is not None: + executor_request.py_multimodal_data = request.multimodal_params.multimodal_data + if self._is_pytorch_backend and request.sampling_params.logits_processor: # For PyTorch backend, we attach logits processors as a dynamic Python attribute # instead of using the C++ binding, since the latter will cause PyCapsule pickling issues. diff --git a/tensorrt_llm/inputs/multimodal.py b/tensorrt_llm/inputs/multimodal.py index 33d435ef1a7..a6b29a9f018 100644 --- a/tensorrt_llm/inputs/multimodal.py +++ b/tensorrt_llm/inputs/multimodal.py @@ -1,7 +1,7 @@ """Multimodal utilities for handling images and other media types in TensorRT-LLM.""" -from dataclasses import dataclass -from typing import Any, Dict, List, Tuple, Union +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import PIL @@ -82,6 +82,138 @@ def to_tensor(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: torch.tensor(self.multimodal_lengths, dtype=torch.int32)) +@dataclass +class MultimodalParams: + """Unified container for multimodal parameters. + + This class encapsulates all multimodal-related data that flows through the system, + providing a clean interface for handling multimodal inputs across different models. + + Attributes: + multimodal_input: Multimodal input data with hashing information. + multimodal_data: Processed multimodal data containing embeddings, configurations, + and modality-specific data organized by type. 
+ + Structure of multimodal_data: + { + "mrope_config": { + "mrope_rotary_cos_sin": torch.Tensor, # Rotary embeddings (Qwen2/2.5-VL) + "mrope_position_deltas": torch.Tensor, # Position deltas (Qwen2/2.5-VL) + }, + "multimodal_embedding": torch.Tensor, # Pre-computed vision embeddings + "image": { + "pixel_values": torch.Tensor, + "image_height": torch.Tensor | List[int], + "image_width": torch.Tensor | List[int], + }, + "video": { + "pixel_values": torch.Tensor, + "video_height": torch.Tensor | List[int], + "video_width": torch.Tensor | List[int], + }, + # ... other modalities + } + """ + + multimodal_input: Optional[MultimodalInput] = None + multimodal_data: Optional[Dict[str, Any]] = field(default_factory=dict) + + def __post_init__(self): + """Ensure default values are properly set.""" + if self.multimodal_data is None: + self.multimodal_data = {} + + def to_device(self, + element: str, + device: str, + pin_memory: bool = False) -> None: + """Move specified multimodal data element to target device. + + Args: + element: Element to move ("multimodal_data" or "multimodal_input") + device: Target device (e.g., "cuda", "cpu") + pin_memory: Whether to pin memory for faster transfers + """ + + def _to_device( + input_tensor: Union[torch.Tensor, List, dict, None], + pin_memory: bool = False, + ) -> Union[torch.Tensor, List, dict, None]: + if input_tensor is None: + return None + elif isinstance(input_tensor, list): + return [_to_device(item, pin_memory) for item in input_tensor] + elif isinstance(input_tensor, dict): + return { + key: _to_device(value, pin_memory) + for key, value in input_tensor.items() + } + elif isinstance(input_tensor, torch.Tensor): + if pin_memory and input_tensor.device.type == 'cpu': + return input_tensor.pin_memory().to(device, + non_blocking=True) + else: + return input_tensor.to(device, non_blocking=True) + else: + return input_tensor + + if element == "multimodal_data": + self.multimodal_data = _to_device(self.multimodal_data, pin_memory) + elif element == "multimodal_input": + self.multimodal_input = _to_device(self.multimodal_input, + pin_memory) + else: + print( + f"MultimodalParams: Unsupported element '{element}' to move to device. " + f"Supported elements: 'multimodal_data', 'multimodal_input'") + + def strip_for_context(self) -> None: + """Strip multimodal data for context processing. + + Removes only mrope_position_deltas while keeping all other multimodal data + (embeddings, images, etc.) needed for context phase processing. + """ + if not (self.multimodal_data + and 'mrope_config' in self.multimodal_data): + return + + mrope_config = self.multimodal_data['mrope_config'] + if 'mrope_position_deltas' in mrope_config: + del mrope_config['mrope_position_deltas'] + + # Clean up empty mrope_config + if not mrope_config: + del self.multimodal_data['mrope_config'] + + def strip_for_generation(self) -> None: + """Strip multimodal data for generation processing. + + Keeps only mrope_position_deltas and removes all other multimodal data + (embeddings, images, etc.) as they're not needed during generation. 
+ """ + if not self.multimodal_data: + return + + # Extract mrope_position_deltas before clearing + mrope_position_deltas = None + if 'mrope_config' in self.multimodal_data: + mrope_config = self.multimodal_data['mrope_config'] + if isinstance(mrope_config, + dict) and 'mrope_position_deltas' in mrope_config: + mrope_position_deltas = mrope_config['mrope_position_deltas'] + + # Clear all data and restore only position deltas if they exist + self.multimodal_data = {} + if mrope_position_deltas is not None: + self.multimodal_data['mrope_config'] = { + 'mrope_position_deltas': mrope_position_deltas + } + + def has_content(self) -> bool: + """Check if this object contains any multimodal data.""" + return bool(self.multimodal_input or self.multimodal_data) + + # adopt from vllm : https://github.com/vllm-project/vllm/blob/main/vllm/vllm/multimodal/hash.py def serialize_item(obj: object) -> bytes: # Simple cases diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index 8a32a580af5..6f3adcbda2f 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -13,6 +13,7 @@ from transformers import PreTrainedTokenizerBase from tensorrt_llm.inputs.data import TextPrompt +from tensorrt_llm.inputs.multimodal import MultimodalParams from tensorrt_llm.inputs.registry import DefaultInputProcessor from .._utils import nvtx_range_debug @@ -357,9 +358,8 @@ def generate_async( sampling_params.add_special_tokens = False query_token_ids = None - multimodal_input = None - multimodal_embedding = None - mrope_config = None + multimodal_params = None + if "prompt_token_ids" in inputs: # TODO: if specify prompt_token_ids, the mm hashing is not supported yet prompt_token_ids = inputs['prompt_token_ids'] @@ -384,11 +384,15 @@ def generate_async( prompt = inputs['prompt'] if extra_processed_inputs is not None: query_token_ids = extra_processed_inputs.get('query_token_ids') - multimodal_embedding = extra_processed_inputs.get( - 'mm_embedding') - mrope_config = extra_processed_inputs.get('mrope_config') - multimodal_input = extra_processed_inputs.get( - 'multimodal_input') + # Create unified MultimodalParams + multimodal_params = MultimodalParams( + multimodal_input=extra_processed_inputs.get( + 'multimodal_input'), + multimodal_data=extra_processed_inputs.get( + 'multimodal_data')) + # Only pass it if it has content + if not multimodal_params.has_content(): + multimodal_params = None else: raise TypeError( f"The inputs must be type str or list of int, but got {type(inputs)}" @@ -408,12 +412,10 @@ def generate_async( lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, streaming=streaming, - multimodal_input=multimodal_input, - multimodal_embedding=multimodal_embedding, - mrope_config=mrope_config, kv_cache_retention_config=kv_cache_retention_config, disaggregated_params=disaggregated_params, postproc_params=_postproc_params, + multimodal_params=multimodal_params, ) return RequestOutput._from_generation_result(result, prompt, diff --git a/tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py b/tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py index 66ff8d59e17..171d6d5bb3f 100644 --- a/tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py +++ b/tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py @@ -188,7 +188,7 @@ def test_single_chat_session_video(client: openai.OpenAI, model_name: str): "type": "text", "text": content_text }, { - "type": "image_url", + "type": "video_url", "video_url": { "url": video_url }
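Taken together, these hunks converge every multimodal model on the same consumption pattern in `forward`: read `multimodal_params` from kwargs, run the in-process vision encoder in the aggregated case or reuse precomputed embeddings in the disaggregated case, then fuse the result into the input embeddings. A condensed sketch of that pattern, assembled from the model hunks above; `self.mm_encoder` and `self.llm` follow the models touched here, the import path is inferred from the package layout, and everything else is illustrative:

```python
import os

from tensorrt_llm._torch.models.modeling_multimodal_utils import fuse_input_embeds

DISAGG = os.getenv('TLLM_MULTIMODAL_DISAGGREGATED', '0') == '1'


def forward(self, attn_metadata, input_ids=None, position_ids=None, **kwargs):
    multimodal_params = kwargs.get("multimodal_params", [])
    mm_embeds = []
    if len(multimodal_params) > 0:
        if not DISAGG:
            # Aggregated serving: the vision encoder owned by the model runs here.
            mm_embeds = self.mm_encoder.forward(multimodal_params)
        else:
            # Disaggregated serving: embeddings were computed upstream and travel
            # inside multimodal_data under the "multimodal_embedding" key.
            mm_embeds = [
                p.multimodal_data["multimodal_embedding"]
                for p in multimodal_params
            ]
    # Replace the out-of-vocab placeholder token ids with the vision embeddings.
    input_ids, inputs_embeds = fuse_input_embeds(self.llm.model.embed_tokens,
                                                 input_ids, mm_embeds)
    return self.llm.forward(attn_metadata=attn_metadata,
                            input_ids=input_ids,
                            position_ids=position_ids,
                            inputs_embeds=inputs_embeds)
```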