Merged
Changes from 9 commits
3 changes: 2 additions & 1 deletion tensorrt_llm/__init__.py
@@ -50,7 +50,7 @@ def _add_trt_llm_dll_directory():
from .builder import BuildConfig, Builder, BuilderConfig, build
from .disaggregated_params import DisaggregatedParams
from .functional import Tensor, constant
from .llmapi import LLM, LlmArgs
from .llmapi import LLM, LlmArgs, MultimodalEncoder
from .llmapi.llm_args import LlmArgs, TorchLlmArgs, TrtLlmArgs
from .logger import logger
from .mapping import Mapping
@@ -103,6 +103,7 @@ def _add_trt_llm_dll_directory():
'quantization',
'tools',
'LLM',
'MultimodalEncoder',
'LlmArgs',
'TorchLlmArgs',
'TrtLlmArgs',
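The change above simply re-exports MultimodalEncoder from the llmapi package at the top level. A minimal sketch of how the new export would be imported follows; only the import path is confirmed by this diff, and the constructor argument is a hypothetical placeholder, not an API documented here.

# Sketch: the import is the new top-level export added above; the constructor
# argument (an HF model id) is an assumption and not taken from this PR.
from tensorrt_llm import MultimodalEncoder

encoder = MultimodalEncoder(model="llava-hf/llava-v1.6-mistral-7b-hf")  # hypothetical args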
11 changes: 9 additions & 2 deletions tensorrt_llm/_torch/model_config.py
@@ -96,6 +96,9 @@ class ModelConfig(Generic[TConfig]):

_frozen: bool = field(default=False, init=False, repr=False)

# If true, ONLY the vision encoder part of the full model is loaded/executed.
mm_encoder_only: bool = False

def __setattr__(self, key, value):
"""
Prevent modification of frozen instance attributes.
@@ -115,7 +118,8 @@ def __post_init__(self):
if self.pretrained_config and hasattr(self.pretrained_config,
"architectures"):
self.is_generation = self.is_generation_model(
self.pretrained_config.architectures)
self.pretrained_config.architectures,
mm_encoder_only=self.mm_encoder_only)

def get_all_reduce_strategy(strategy: str = "AUTO"):
maps = {
@@ -164,12 +168,15 @@ def get_quant_config(self, name: Optional[str] = None) -> QuantConfig:
raise ValueError(f'quant config of {name} is not found')

@staticmethod
def is_generation_model(model_architectures: Optional[List[str]]) -> bool:
def is_generation_model(model_architectures: Optional[List[str]],
mm_encoder_only: bool = False) -> bool:
if model_architectures is None:
logger.warning(
"Model architectures is None, default to is_generation_model=True"
)
return True
if mm_encoder_only:
return False
return model_architectures[0] not in [
"BertForSequenceClassification", "Qwen2ForProcessRewardModel",
"Qwen2ForRewardModel", "LlamaForTextEmbedding"
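The new mm_encoder_only flag short-circuits is_generation_model before the architecture list is consulted. A small illustration of the resulting behaviour, using only the staticmethod as defined in this diff:

# Behaviour of ModelConfig.is_generation_model with the new flag.
from tensorrt_llm._torch.model_config import ModelConfig

# A plain decoder architecture is still reported as a generation model.
assert ModelConfig.is_generation_model(["LlamaForCausalLM"]) is True

# With mm_encoder_only=True the check returns False regardless of the
# architecture list, marking the encoder-only model as non-generation.
assert ModelConfig.is_generation_model(
    ["LlavaNextForConditionalGeneration"], mm_encoder_only=True) is False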
15 changes: 13 additions & 2 deletions tensorrt_llm/_torch/models/modeling_auto.py
@@ -2,8 +2,9 @@

from ..model_config import ModelConfig
from ..utils import model_extra_attrs
from .modeling_utils import (MODEL_CLASS_MAPPING, DecoderModelForCausalLM,
TConfig, TModel)
from .modeling_utils import (MODEL_CLASS_MAPPING,
MODEL_CLASS_VISION_ENCODER_MAPPING,
DecoderModelForCausalLM, TConfig, TModel)


class AutoModelForCausalLM(Generic[TModel, TConfig]):
@@ -13,6 +14,16 @@ def from_config(
config: ModelConfig[TConfig],
) -> DecoderModelForCausalLM[TModel, TConfig]:
model_arch = config.pretrained_config.architectures[0]
if config.mm_encoder_only:
vision_encoder_info = MODEL_CLASS_VISION_ENCODER_MAPPING.get(
model_arch)
if vision_encoder_info is None:
raise ValueError(
f"Unknown architecture for AutoModelForMultimodalEncoder: {model_arch}"
)
vision_encoder_cls, vlm_base_model = vision_encoder_info
return vision_encoder_cls(config, vlm_base_model)

# Hack to detect eagle3 checkpoints. TODO: should we provide
# our own checkpoints with the correct arch? It would let us
# avoid nasty stuff like this.
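When mm_encoder_only is set, from_config resolves the architecture through MODEL_CLASS_VISION_ENCODER_MAPPING and instantiates only the registered vision encoder. A hedged sketch of that path; the HF model id is an assumption used only to obtain a LlavaNext config, and actually constructing the encoder will load vision weights onto the local GPU.

# Sketch of the encoder-only dispatch path; the HF model id is an assumption.
from transformers import AutoConfig

from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_auto import AutoModelForCausalLM

hf_config = AutoConfig.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
config = ModelConfig(pretrained_config=hf_config, mm_encoder_only=True)

# "LlavaNextForConditionalGeneration" is looked up in
# MODEL_CLASS_VISION_ENCODER_MAPPING, so the registered vision encoder class
# (LlavaNextVisionModel, registered later in this PR) is returned instead of
# the full VLM.
vision_encoder = AutoModelForCausalLM.from_config(config)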
167 changes: 157 additions & 10 deletions tensorrt_llm/_torch/models/modeling_llava_next.py
@@ -1,6 +1,6 @@
import copy
import os
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
@@ -24,7 +24,8 @@
from .modeling_auto import AutoModelForCausalLM
from .modeling_clip import CLIPVisionModel
from .modeling_multimodal_utils import fuse_input_embeds
from .modeling_utils import ModelConfig, filter_weights, register_auto_model
from .modeling_utils import (filter_weights, register_auto_model,
register_vision_encoder)

DISAGG = os.getenv('TLLM_MULTIMODAL_DISAGGREGATED', '0') == '1'

@@ -51,6 +52,140 @@ def __init__(self,

self.image_token_index = model_config.image_token_index
self.vocab_size = model_config.vocab_size
self.config = model_config.vision_config

def get_num_tokens_per_image(
self,
*,
image_width: int,
image_height: int,
) -> int:
image_size = (image_height, image_width)
num_image_tokens = self.processor._get_num_multimodal_tokens(
[image_size])["num_image_tokens"][0]
return num_image_tokens

def _postprocess(
self, input_ids: torch.Tensor, mm_features: Union[torch.Tensor,
List[torch.Tensor]]
) -> Tuple[torch.Tensor, torch.Tensor]:
# Define model specific variables here before shared logic
mm_tokens = torch.tensor([self.model_config.image_token_index
]).to(input_ids.device)
model_hidden_size = self.model_config.text_config.hidden_size
start_len = end_len = 0 # for llava, need not append start/end token around each image token
# End model specific variables

## find mm token positions in input_ids
mm_token_positions = torch.where(torch.isin(input_ids, mm_tokens))[0]
num_medias = num_mm_tokens = len(mm_token_positions)
if num_medias > 1 and isinstance(mm_features, torch.Tensor):
mm_features = list(
mm_features.split(mm_features.shape[0] // num_medias))

if isinstance(mm_features, torch.Tensor):
# 1 prompt + 1 media
# "split" means what a single mm_token in the input_ids should represent
# image: one split --> one frame
# video: one split --> N frames
num_frames, mm_feature_length, mm_hidden_dim = mm_features.shape
mm_lengths_per_split = [mm_feature_length * num_frames]
mm_lengths_per_frame = [mm_feature_length]
elif isinstance(mm_features, list):
# 1 prompt + N media
num_frames = len(mm_features) if mm_features[0].dim() == 2 else sum(
[f.shape[0] for f in mm_features])
mm_lengths_per_split = [
f.shape[0] if f.dim() == 2 else f.shape[0] * f.shape[1]
for f in mm_features
]
mm_lengths_per_frame = [
f.shape[0] if f.dim() == 2 else f.shape[1] for f in mm_features
]
mm_hidden_dim = mm_features[0].shape[-1]
mm_features = torch.cat(mm_features, dim=0)
else:
raise ValueError(
f"Invalid multimodal features type: {type(mm_features)}")
mm_total_length = sum(mm_lengths_per_split)
assert mm_hidden_dim == model_hidden_size, "Multimodal embedding_dim must match model hidden_size"

## split input_ids into segments by isolating mm tokens
mm_split_positions = torch.cat(
[mm_token_positions, mm_token_positions + 1]).unique()
input_ids_splits = list(input_ids.tensor_split(mm_split_positions.cpu(
))) # len(input_ids_splits) = num_segments after mm tokens are isolated
mm_ids_splits = list(
torch.arange(self.vocab_size,
self.vocab_size + mm_total_length,
device=input_ids.device).split(mm_lengths_per_split)
) # len(mm_ids_splits) = num_mm_segments

for i, mm_ids in enumerate(mm_ids_splits):
mm_ids = mm_ids.reshape(-1, mm_lengths_per_frame[i])
mm_ids_splits[i] = mm_ids.flatten()

## replace mm token ids with the expanded out-of-vocab ids
mm_split_idx = 0
for i, split in enumerate(input_ids_splits):
if torch.isin(split, mm_tokens).any().item():
input_ids_splits[i] = mm_ids_splits[mm_split_idx]
mm_split_idx += 1
assert mm_split_idx == len(
mm_ids_splits), "All mm_ids_splits should be consumed"

## concat text & mm input_ids, wrap mm feature in prompt tuning config
fused_input_ids = torch.cat(input_ids_splits).to(
device=input_ids.device)
fused_length = len(input_ids) + mm_total_length + num_frames * (
start_len + end_len) - num_medias
assert len(
fused_input_ids
) == fused_length, f"Fused input_ids length {len(fused_input_ids)} should match the sum of text and multimodal embedding lengths {fused_length}"

# [num_frames, feature_length, hidden_dim] -> [num_frames * feature_length, hidden_dim]
mm_features = mm_features.view(-1, mm_features.shape[-1])
return fused_input_ids, mm_features

def attach_multimodal_embeddings(
self, inputs: TextPrompt,
multimodal_embedding: Dict[str, List[torch.Tensor]],
sampling_params: SamplingParams
) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
"""
Attach pre-processed multimodal embeddings into text token stream for LlavaNext model.
This method skips vision processing and works with externally provided embeddings.
It replaces/expands image placeholders in the text with appropriate tokens and prepares
the embeddings for model forward pass.
Args:
inputs: Text prompt containing image placeholders
multimodal_embedding: Dictionary containing pre-processed image embedding data
Returns:
Tuple of (token_ids, extra_processed_inputs) where:
- token_ids: List of processed token IDs with image placeholders
- extra_processed_inputs: Optional dictionary containing multimodal embeddings
"""
text_prompt = inputs.get("prompt")
if not text_prompt:
raise ValueError("Text prompt is required but not provided")

if not isinstance(multimodal_embedding, dict):
raise ValueError("multimodal_embedding must be a dictionary")

if 'image' not in multimodal_embedding:
raise ValueError(
"Only image modality is supported for external multimodal embedding"
)

input_ids = self.tokenizer(text_prompt,
return_tensors="pt").input_ids[0]
mm_features = multimodal_embedding['image']
fused_input_ids, mm_features = self._postprocess(input_ids, mm_features)
multimodal_data = {}
multimodal_data["multimodal_embedding"] = mm_features
return fused_input_ids.to(torch.int32).tolist(), {
"multimodal_data": multimodal_data
}

@torch.inference_mode()
def __call__(
@@ -90,6 +225,7 @@ class LlavaNextVisionModel(nn.Module):
def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
**kwargs) -> None:
super().__init__()
self.model_config = model_config
self.pretrained_config = model_config.pretrained_config
self.device = f"cuda:{model_config.mapping.rank}"
model_path = self.pretrained_config._name_or_path
@@ -132,7 +268,7 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
self.vision_tower = hf_vision_tower.to(self.device)
else:
vision_model_config = ModelConfig(
pretrained_config=model_config.pretrained_config.vision_config,
pretrained_config=self.pretrained_config.vision_config,
attn_backend="TRTLLM")
self.vision_tower = CLIPVisionModel(vision_model_config).to(
self.device).to(self.dtype)
@@ -142,8 +278,12 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
self.mm_projector = hf_mm_projector
self.image_newline = hf_image_newline
self.vision_feature_select_strategy = getattr(
model_config.pretrained_config, "vision_feature_select_strategy",
"default")
self.pretrained_config, "vision_feature_select_strategy", "default")

self.post_config()

def post_config(self):
self.config = self.pretrained_config.vision_config

# Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_next/modeling_llava_next.py#L284
def pack_image_features(self,
Expand All @@ -157,12 +297,12 @@ def pack_image_features(self,
if image_feature.shape[0] > 1:
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
height = width = self.pretrained_config.vision_config.image_size // self.pretrained_config.vision_config.patch_size
height = width = self.config.image_size // self.config.patch_size

num_patch_height, num_patch_width = get_anyres_image_grid_shape(
image_sizes[image_idx],
self.pretrained_config.image_grid_pinpoints,
self.pretrained_config.vision_config.image_size,
self.config.image_size,
)

if (np.prod(image_feature.shape) %
@@ -224,7 +364,7 @@ def forward(self, multimodal_params: List[MultimodalParams]):
image_size_to_num_patches(
image_size=imsize,
grid_pinpoints=self.pretrained_config.image_grid_pinpoints,
patch_size=self.pretrained_config.vision_config.image_size,
patch_size=self.config.image_size,
) for imsize in image_sizes
]

@@ -262,6 +402,7 @@ def forward(self, multimodal_params: List[MultimodalParams]):
return [image_features]


@register_vision_encoder(LlavaNextVisionModel)
@register_auto_model("LlavaNextForConditionalGeneration")
@register_input_processor(LlavaNextInputProcessor, model_type="llava_next")
class LlavaNextModel(PreTrainedModel):
@@ -287,7 +428,6 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
self.llm = AutoModelForCausalLM.from_config(llm_model_config)

self.model_config = model_config
self.vocab_size = config.vocab_size
self.model_dtype = getattr(config.text_config, "torch_dtype",
torch.float16)
logger.info(f"{self.dtype=} {self.model_dtype=}")
@@ -323,7 +463,14 @@ def forward(
mm_embeds = []
if len(multimodal_params) > 0:
if not DISAGG:
mm_embeds = self.mm_encoder.forward(multimodal_params)
if multimodal_params[0].multimodal_data.get(
"multimodal_embedding", None) is not None:
mm_embeds = [
multimodal_param.multimodal_data["multimodal_embedding"]
for multimodal_param in multimodal_params
]
else:
mm_embeds = self.mm_encoder.forward(multimodal_params)
else:
mm_embeds = [
multimodal_param.multimodal_data["multimodal_embedding"]
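The core of the new external-embedding path is the out-of-vocab expansion in _postprocess: each image placeholder in input_ids is replaced by a contiguous block of ids starting at vocab_size, and fuse_input_embeds later swaps those positions for the precomputed features. A standalone toy illustration of that expansion (not library code; the token ids, vocab size, and feature shape are made up):

# Toy reproduction of the expansion step in LlavaNextInputProcessor._postprocess.
import torch

vocab_size = 32064           # assumed text vocab size
image_token_index = 32000    # assumed <image> placeholder id
input_ids = torch.tensor([1, 3148, 32000, 29871, 2])  # toy prompt with one <image>
mm_feature = torch.zeros(576, 4096)                   # one image -> 576 feature tokens

pos = torch.where(input_ids == image_token_index)[0]
# Isolate the placeholder into its own segment, mirroring input_ids.tensor_split(...).
segments = list(input_ids.tensor_split(torch.cat([pos, pos + 1]).unique()))
# Replace that segment with ids vocab_size .. vocab_size + 575.
segments[1] = torch.arange(vocab_size, vocab_size + mm_feature.shape[0])
fused_input_ids = torch.cat(segments)

# One placeholder token became 576 out-of-vocab ids.
assert len(fused_input_ids) == len(input_ids) - 1 + mm_feature.shape[0]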