3 changes: 2 additions & 1 deletion tensorrt_llm/__init__.py
@@ -51,7 +51,7 @@ def _add_trt_llm_dll_directory():
from .builder import BuildConfig, Builder, BuilderConfig, build
from .disaggregated_params import DisaggregatedParams
from .functional import Tensor, constant
from .llmapi import LLM, LlmArgs
from .llmapi import LLM, MultimodalEncoder
from .llmapi.llm_args import LlmArgs, TorchLlmArgs, TrtLlmArgs
from .logger import logger
from .mapping import Mapping
@@ -105,6 +105,7 @@ def _add_trt_llm_dll_directory():
'quantization',
'tools',
'LLM',
'MultimodalEncoder',
'LlmArgs',
'TorchLlmArgs',
'TrtLlmArgs',
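For quick reference, a minimal sketch of what the new top-level export enables. The diff above only shows the re-export; the constructor arguments below are an assumption that mirrors how `LLM` is built from a Hugging Face model path, not a confirmed signature.

```python
# Hedged sketch: MultimodalEncoder is now importable from the package root.
from tensorrt_llm import MultimodalEncoder

# Hypothetical construction, assuming the encoder accepts a model path/repo id
# the same way LLM does; check tensorrt_llm.llmapi for the actual signature.
encoder = MultimodalEncoder(model="llava-hf/llava-v1.6-mistral-7b-hf")
```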
11 changes: 9 additions & 2 deletions tensorrt_llm/_torch/model_config.py
@@ -103,6 +103,9 @@ class ModelConfig(Generic[TConfig]):

_frozen: bool = field(default=False, init=False, repr=False)

# If true, ONLY the vision encoder part of the full model is loaded/executed.
mm_encoder_only: bool = False

def __setattr__(self, key, value):
"""
Prevent modification of frozen instance attributes.
@@ -122,7 +125,8 @@ def __post_init__(self):
if self.pretrained_config and hasattr(self.pretrained_config,
"architectures"):
self.is_generation = self.is_generation_model(
self.pretrained_config.architectures)
self.pretrained_config.architectures,
mm_encoder_only=self.mm_encoder_only)

def get_all_reduce_strategy(strategy: str = "AUTO"):
maps = {
@@ -181,12 +185,15 @@ def get_quant_config(self, name: Optional[str] = None) -> QuantConfig:
raise ValueError(f'quant config of {name} is not found')

@staticmethod
def is_generation_model(model_architectures: Optional[List[str]]) -> bool:
def is_generation_model(model_architectures: Optional[List[str]],
mm_encoder_only: bool = False) -> bool:
if model_architectures is None:
logger.warning(
"Model architectures is None, default to is_generation_model=True"
)
return True
if mm_encoder_only:
return False
return model_architectures[0] not in [
"BertForSequenceClassification", "Qwen2ForProcessRewardModel",
"Qwen2ForRewardModel", "LlamaForTextEmbedding"
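A small illustration of the new `mm_encoder_only` flag, grounded in the static method above: when the flag is set, the architecture check is bypassed and the model is reported as non-generative, which is what routes encoder-only execution.

```python
from tensorrt_llm._torch.model_config import ModelConfig

# With mm_encoder_only=True any architecture is treated as a non-generation
# (encoder-only) model.
assert ModelConfig.is_generation_model(
    ["LlavaNextForConditionalGeneration"], mm_encoder_only=True) is False

# Default behavior is unchanged: the VLM architecture still counts as generative.
assert ModelConfig.is_generation_model(
    ["LlavaNextForConditionalGeneration"]) is True
```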
15 changes: 13 additions & 2 deletions tensorrt_llm/_torch/models/modeling_auto.py
@@ -2,8 +2,9 @@

from ..model_config import ModelConfig
from ..utils import model_extra_attrs
from .modeling_utils import (MODEL_CLASS_MAPPING, DecoderModelForCausalLM,
TConfig, TModel)
from .modeling_utils import (MODEL_CLASS_MAPPING,
MODEL_CLASS_VISION_ENCODER_MAPPING,
DecoderModelForCausalLM, TConfig, TModel)


class AutoModelForCausalLM(Generic[TModel, TConfig]):
@@ -13,6 +14,16 @@ def from_config(
config: ModelConfig[TConfig],
) -> DecoderModelForCausalLM[TModel, TConfig]:
model_arch = config.pretrained_config.architectures[0]
if config.mm_encoder_only:
vision_encoder_info = MODEL_CLASS_VISION_ENCODER_MAPPING.get(
model_arch)
if vision_encoder_info is None:
raise ValueError(
f"Unknown architecture for AutoModelForMultimodalEncoder: {model_arch}"
)
vision_encoder_cls, vlm_base_model = vision_encoder_info
return vision_encoder_cls(config, vlm_base_model)

# Hack to detect eagle3 checkpoints. TODO: should we provide
# our own checkpoints with the correct arch? It would let us
# avoid nasty stuff like this.
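Roughly, the new dispatch means a config with `mm_encoder_only=True` resolves to the registered vision-encoder class instead of the full VLM. A hedged sketch, under the assumption that importing the models package populates the class registries:

```python
from transformers import AutoConfig

import tensorrt_llm._torch.models  # assumed to populate the model class registries
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_auto import AutoModelForCausalLM

hf_config = AutoConfig.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
config = ModelConfig(pretrained_config=hf_config, mm_encoder_only=True)

# Looks up MODEL_CLASS_VISION_ENCODER_MAPPING["LlavaNextForConditionalGeneration"]
# and returns vision_encoder_cls(config, vlm_base_model), i.e. the vision encoder
# (LlavaNextVisionModel) rather than the full decoder model.
encoder = AutoModelForCausalLM.from_config(config)
```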
167 changes: 157 additions & 10 deletions tensorrt_llm/_torch/models/modeling_llava_next.py
@@ -1,6 +1,6 @@
import copy
import os
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
@@ -26,7 +26,8 @@
from .modeling_auto import AutoModelForCausalLM
from .modeling_clip import CLIPVisionModel
from .modeling_multimodal_utils import fuse_input_embeds
from .modeling_utils import ModelConfig, filter_weights, register_auto_model
from .modeling_utils import (filter_weights, register_auto_model,
register_vision_encoder)

DISAGG = os.getenv('TLLM_MULTIMODAL_DISAGGREGATED', '0') == '1'

@@ -53,6 +54,140 @@ def __init__(self,

self.image_token_index = model_config.image_token_index
self.vocab_size = model_config.vocab_size
self.config = model_config.vision_config

def get_num_tokens_per_image(
self,
*,
image_width: int,
image_height: int,
) -> int:
image_size = (image_height, image_width)
num_image_tokens = self.processor._get_num_multimodal_tokens(
[image_size])["num_image_tokens"][0]
return num_image_tokens

def _postprocess(
self, input_ids: torch.Tensor, mm_features: Union[torch.Tensor,
List[torch.Tensor]]
) -> Tuple[torch.Tensor, torch.Tensor]:
# Define model specific variables here before shared logic
mm_tokens = torch.tensor([self.model_config.image_token_index
]).to(input_ids.device)
model_hidden_size = self.model_config.text_config.hidden_size
start_len = end_len = 0  # for llava, no start/end tokens are appended around each image token
# End model specific variables

## find mm token positions in input_ids
mm_token_positions = torch.where(torch.isin(input_ids, mm_tokens))[0]
num_medias = num_mm_tokens = len(mm_token_positions)
if num_medias > 1 and isinstance(mm_features, torch.Tensor):
mm_features = list(
mm_features.split(mm_features.shape[0] // num_medias))

if isinstance(mm_features, torch.Tensor):
# 1 prompt + 1 media
# "split" means what a single mm_token in the input_ids should represent
# image: one split --> one frame
# video: one split --> N frames
num_frames, mm_feature_length, mm_hidden_dim = mm_features.shape
mm_lengths_per_split = [mm_feature_length * num_frames]
mm_lengths_per_frame = [mm_feature_length]
elif isinstance(mm_features, list):
# 1 prompt + N media
num_frames = len(mm_features) if mm_features[0].dim() == 2 else sum(
[f.shape[0] for f in mm_features])
mm_lengths_per_split = [
f.shape[0] if f.dim() == 2 else f.shape[0] * f.shape[1]
for f in mm_features
]
mm_lengths_per_frame = [
f.shape[0] if f.dim() == 2 else f.shape[1] for f in mm_features
]
mm_hidden_dim = mm_features[0].shape[-1]
mm_features = torch.cat(mm_features, dim=0)
else:
raise ValueError(
f"Invalid multimodal features type: {type(mm_features)}")
mm_total_length = sum(mm_lengths_per_split)
assert mm_hidden_dim == model_hidden_size, "Multimodal embedding_dim must match model hidden_size"

## split input_ids into segments by isolating mm tokens
mm_split_positions = torch.cat(
[mm_token_positions, mm_token_positions + 1]).unique()
input_ids_splits = list(input_ids.tensor_split(mm_split_positions.cpu(
))) # len(input_ids_splits) = num_segments after mm tokens are isolated
mm_ids_splits = list(
torch.arange(self.vocab_size,
self.vocab_size + mm_total_length,
device=input_ids.device).split(mm_lengths_per_split)
) # len(mm_ids_splits) = num_mm_segments

for i, mm_ids in enumerate(mm_ids_splits):
mm_ids = mm_ids.reshape(-1, mm_lengths_per_frame[i])
mm_ids_splits[i] = mm_ids.flatten()

## replace mm token ids with the expanded out-of-vocab ids
mm_split_idx = 0
for i, split in enumerate(input_ids_splits):
if torch.isin(split, mm_tokens).any().item():
input_ids_splits[i] = mm_ids_splits[mm_split_idx]
mm_split_idx += 1
assert mm_split_idx == len(
mm_ids_splits), "All mm_ids_splits should be consumed"

## concat text & mm input_ids, wrap mm feature in prompt tuning config
fused_input_ids = torch.cat(input_ids_splits).to(
device=input_ids.device)
fused_length = len(input_ids) + mm_total_length + num_frames * (
start_len + end_len) - num_medias
assert len(
fused_input_ids
) == fused_length, f"Fused input_ids length {len(fused_input_ids)} should match the sum of text and multimodal embedding lengths {fused_length}"

# [num_frames, feature_length, hidden_dim] -> [num_frames * feature_length, hidden_dim]
mm_features = mm_features.view(-1, mm_features.shape[-1])
return fused_input_ids, mm_features

def attach_multimodal_embeddings(
self, inputs: TextPrompt,
multimodal_embedding: Dict[str, List[torch.Tensor]],
sampling_params: SamplingParams
) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
"""
Attach pre-processed multimodal embeddings into text token stream for LlavaNext model.
This method skips vision processing and works with externally provided embeddings.
It replaces/expands image placeholders in the text with appropriate tokens and prepares
the embeddings for model forward pass.
Args:
inputs: Text prompt containing image placeholders
multimodal_embedding: Dictionary containing pre-processed image embedding data
Returns:
Tuple of (token_ids, extra_processed_inputs) where:
- token_ids: List of processed token IDs with image placeholders
- extra_processed_inputs: Optional dictionary containing multimodal embeddings
"""
text_prompt = inputs.get("prompt")
if not text_prompt:
raise ValueError("Text prompt is required but not provided")

if not isinstance(multimodal_embedding, dict):
raise ValueError("multimodal_embedding must be a dictionary")

if 'image' not in multimodal_embedding:
raise ValueError(
"Only image modality is supported for external multimodal embedding"
)

input_ids = self.tokenizer(text_prompt,
return_tensors="pt").input_ids[0]
mm_features = multimodal_embedding['image']
fused_input_ids, mm_features = self._postprocess(input_ids, mm_features)
multimodal_data = {}
multimodal_data["multimodal_embedding"] = mm_features
return fused_input_ids.to(torch.int32).tolist(), {
"multimodal_data": multimodal_data
}

@torch.inference_mode()
def __call__(
@@ -92,6 +227,7 @@ class LlavaNextVisionModel(nn.Module):
def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
**kwargs) -> None:
super().__init__()
self.model_config = model_config
self.pretrained_config = model_config.pretrained_config
self.device = f"cuda:{model_config.mapping.rank}"
model_path = self.pretrained_config._name_or_path
@@ -134,7 +270,7 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
self.vision_tower = hf_vision_tower.to(self.device)
else:
vision_model_config = ModelConfig(
pretrained_config=model_config.pretrained_config.vision_config,
pretrained_config=self.pretrained_config.vision_config,
attn_backend="TRTLLM")
self.vision_tower = CLIPVisionModel(vision_model_config).to(
self.device).to(self.dtype)
@@ -144,8 +280,12 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
self.mm_projector = hf_mm_projector
self.image_newline = hf_image_newline
self.vision_feature_select_strategy = getattr(
model_config.pretrained_config, "vision_feature_select_strategy",
"default")
self.pretrained_config, "vision_feature_select_strategy", "default")

self.post_config()

def post_config(self):
self.config = self.pretrained_config.vision_config

# Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_next/modeling_llava_next.py#L284
def pack_image_features(self,
@@ -159,12 +299,12 @@ def pack_image_features(self,
if image_feature.shape[0] > 1:
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
height = width = self.pretrained_config.vision_config.image_size // self.pretrained_config.vision_config.patch_size
height = width = self.config.image_size // self.config.patch_size

num_patch_height, num_patch_width = get_anyres_image_grid_shape(
image_sizes[image_idx],
self.pretrained_config.image_grid_pinpoints,
self.pretrained_config.vision_config.image_size,
self.config.image_size,
)

if (np.prod(image_feature.shape) %
@@ -226,7 +366,7 @@ def forward(self, multimodal_params: List[MultimodalParams]):
image_size_to_num_patches(
image_size=imsize,
grid_pinpoints=self.pretrained_config.image_grid_pinpoints,
patch_size=self.pretrained_config.vision_config.image_size,
patch_size=self.config.image_size,
) for imsize in image_sizes
]

@@ -264,6 +404,7 @@ def forward(self, multimodal_params: List[MultimodalParams]):
return [image_features]


@register_vision_encoder(LlavaNextVisionModel)
@register_auto_model("LlavaNextForConditionalGeneration")
@register_input_processor(
LlavaNextInputProcessor,
@@ -295,7 +436,6 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
self.llm = AutoModelForCausalLM.from_config(llm_model_config)

self.model_config = model_config
self.vocab_size = config.vocab_size
self.model_dtype = getattr(config.text_config, "torch_dtype",
torch.float16)
logger.info(f"{self.dtype=} {self.model_dtype=}")
@@ -331,7 +471,14 @@ def forward(
mm_embeds = []
if len(multimodal_params) > 0:
if not DISAGG:
mm_embeds = self.mm_encoder.forward(multimodal_params)
if multimodal_params[0].multimodal_data.get(
"multimodal_embedding", None) is not None:
mm_embeds = [
multimodal_param.multimodal_data["multimodal_embedding"]
for multimodal_param in multimodal_params
]
else:
mm_embeds = self.mm_encoder.forward(multimodal_params)
else:
mm_embeds = [
multimodal_param.multimodal_data["multimodal_embedding"]
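To tie the pieces together, a hedged usage sketch of the new `attach_multimodal_embeddings` path. The `processor` instance, the `TextPrompt` import location, and the embedding shape are placeholders/assumptions; in practice the features would come from a separately run multimodal encoder.

```python
import torch

from tensorrt_llm import SamplingParams
from tensorrt_llm.inputs import TextPrompt  # assumed location of TextPrompt

# `processor` is assumed to be an already-constructed LlavaNextInputProcessor.
# One image's pre-computed features, shaped [num_image_tokens, text_hidden_size];
# random values are used here only to illustrate shapes.
image_features = torch.randn(2928, 4096, dtype=torch.float16)

prompt = TextPrompt(prompt="USER: <image>\nDescribe this picture. ASSISTANT:")
token_ids, extra = processor.attach_multimodal_embeddings(
    prompt, {"image": [image_features]}, SamplingParams())

# token_ids now contains out-of-vocab placeholder ids spanning the image region;
# extra["multimodal_data"]["multimodal_embedding"] holds the features that
# fuse_input_embeds later merges into the input embeddings.
```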