diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 07b1ced5ca42..6c3f5b5998e9 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -699,6 +699,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `KeyeVL1_5ForConditionalGeneration` | Keye-VL-1_5-8B | T + IE+ + VE+ | `Kwai-Keye/Keye-VL-1_5-8B` | ✅︎ | ✅︎ |
| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I+ | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | ✅︎ |
| `LightOnOCRForConditionalGeneration` | LightOnOCR-1B | T + I+ | `lightonai/LightOnOCR-1B`, etc | ✅︎ | ✅︎ |
+| `Lfm2VlForConditionalGeneration` | LFM2-VL | T + I+ | `LiquidAI/LFM2-VL-450M`, `LiquidAI/LFM2-VL-3B`, `LiquidAI/LFM2-VL-8B-A1B`, etc. | ✅︎ | ✅︎ |
| `Llama4ForConditionalGeneration` | Llama 4 | T + I+ | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | ✅︎ | ✅︎ |
| `Llama_Nemotron_Nano_VL` | Llama Nemotron Nano VL | T + IE+ | `nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1` | ✅︎ | ✅︎ |
| `LlavaForConditionalGeneration` | LLaVA-1.5, Pixtral (HF Transformers) | T + IE+ | `llava-hf/llava-1.5-7b-hf`, `TIGER-Lab/Mantis-8B-siglip-llama3` (see note), `mistral-community/pixtral-12b`, etc. | ✅︎ | ✅︎ |
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index dfca7d5c9c9a..df205a67d9f0 100755
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -15,7 +15,7 @@
from typing import NamedTuple
from huggingface_hub import snapshot_download
-from transformers import AutoTokenizer
+from transformers import AutoProcessor, AutoTokenizer
from vllm import LLM, EngineArgs, SamplingParams
from vllm.assets.image import ImageAsset
@@ -875,6 +875,37 @@ def run_lightonocr(questions: list[str], modality: str) -> ModelRequestData:
)
+def run_lfm2_vl(questions: list[str], modality: str) -> ModelRequestData:
+ assert modality == "image"
+
+ model_name = "LiquidAI/LFM2-VL-450M"
+
+ engine_args = EngineArgs(
+ model=model_name,
+ max_model_len=4096,
+ limit_mm_per_prompt={modality: 1},
+ )
+
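+    # Build prompts with the HF chat template so the image placeholder tokens
+    # match what the LFM2-VL processor expects.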
+ processor = AutoProcessor.from_pretrained(model_name)
+ messages = [
+ [
+ {
+ "role": "user",
+ "content": [{"type": "image"}, {"type": "text", "text": question}],
+ }
+ ]
+ for question in questions
+ ]
+ prompts = processor.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+ )
+
+ return ModelRequestData(
+ engine_args=engine_args,
+ prompts=prompts,
+ )
+
+
def run_llama4(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
@@ -1849,6 +1880,7 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData:
"keye_vl1_5": run_keye_vl1_5,
"kimi_vl": run_kimi_vl,
"lightonocr": run_lightonocr,
+ "lfm2_vl": run_lfm2_vl,
"llama4": run_llama4,
"llava": run_llava,
"llava-next": run_llava_next,
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 570bcc734146..ff91d2391fe8 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -732,6 +732,10 @@ def check_available_online(
"LightOnOCRForConditionalGeneration": _HfExamplesInfo(
"lightonai/LightOnOCR-1B-1025"
),
+ "Lfm2VlForConditionalGeneration": _HfExamplesInfo(
+ "LiquidAI/LFM2-VL-450M",
+ min_transformers_version="5.0.0",
+ ),
"Llama4ForConditionalGeneration": _HfExamplesInfo(
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
max_model_len=10240,
diff --git a/vllm/model_executor/models/lfm2_vl.py b/vllm/model_executor/models/lfm2_vl.py
new file mode 100644
index 000000000000..d87b23d00cba
--- /dev/null
+++ b/vllm/model_executor/models/lfm2_vl.py
@@ -0,0 +1,732 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import itertools
+import math
+from collections.abc import Iterable, Mapping, Sequence
+from typing import Annotated, Literal
+
+import torch
+import torch.nn as nn
+from transformers import BatchFeature
+from transformers.activations import ACT2FN
+from transformers.models.lfm2_vl import Lfm2VlProcessor
+from transformers.models.lfm2_vl.configuration_lfm2_vl import Lfm2VlConfig
+from transformers.models.lfm2_vl.image_processing_lfm2_vl_fast import (
+ Lfm2VlImageProcessorFast,
+ find_closest_aspect_ratio,
+ round_by_factor,
+)
+
+from vllm.config import VllmConfig
+from vllm.config.multimodal import BaseDummyOptions
+from vllm.forward_context import set_forward_context
+from vllm.model_executor.layers.mamba.mamba_utils import (
+ MambaStateDtypeCalculator,
+ MambaStateShapeCalculator,
+)
+from vllm.model_executor.models.module_mapping import MultiModelKeys
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (
+ MultiModalDataDict,
+ MultiModalFieldConfig,
+ MultiModalKwargsItems,
+)
+from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
+from vllm.multimodal.processing import (
+ BaseMultiModalProcessor,
+ BaseProcessingInfo,
+ PromptReplacement,
+ PromptUpdateDetails,
+)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
+
+from .interfaces import (
+ IsHybrid,
+ MultiModalEmbeddings,
+ SupportsLoRA,
+ SupportsMultiModal,
+ SupportsPP,
+)
+from .siglip2 import Siglip2Model
+from .utils import (
+ AutoWeightsLoader,
+ WeightsMapper,
+ init_vllm_registered_model,
+ maybe_prefix,
+)
+
+
+class Lfm2VLImagePixelInputs(TensorSchema):
+ """
+ Dimensions:
+ - b: Number of images in the prompt
+        - bn: Total number of image tiles across the batch (incl. thumbnails)
+        - d: Maximum number of encoder patches per tile (padded)
+        - fd: Flattened patch feature size (num_channels * patch_size**2)
+ """
+
+ type: Literal["pixel_values"] = "pixel_values"
+ pixel_values: Annotated[torch.Tensor, TensorShape("bn", "d", "fd")]
+ spatial_shapes: Annotated[torch.Tensor, TensorShape("bn", 2)]
+ num_patches: Annotated[torch.Tensor, TensorShape("b")]
+
+
+LFM2VLImageInputs = Lfm2VLImagePixelInputs
+
+
+class Lfm2VLProcessingInfo(BaseProcessingInfo):
+ def get_hf_config(self):
+ return self.ctx.get_hf_config(Lfm2VlConfig)
+
+ def get_hf_processor(self, **kwargs):
+ return self.ctx.get_hf_processor(Lfm2VlProcessor, **kwargs)
+
+ def get_image_processor(self, **kwargs: object) -> Lfm2VlImageProcessorFast:
+ return self.get_hf_processor(**kwargs).image_processor
+
+ def get_supported_mm_limits(self) -> Mapping[str, int | None]:
+ return {"image": None}
+
+ def get_image_size_with_most_features(self) -> ImageSize:
+ processor = self.get_image_processor()
+ max_image_tokens = processor.max_image_tokens
+ encoder_patch_size = processor.encoder_patch_size
+ downsample_factor = processor.downsample_factor
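+        # Each output image token covers an
+        # (encoder_patch_size * downsample_factor)^2 pixel block, so the largest
+        # single-tile input is a square holding max_image_tokens such blocks.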
+ max_pixels = max_image_tokens * (encoder_patch_size**2) * (downsample_factor**2)
+ side = int(math.sqrt(max_pixels))
+ return ImageSize(width=side, height=side)
+
+ def _is_image_too_large(
+ self,
+ height: int,
+ width: int,
+ max_image_tokens: int,
+ encoder_patch_size: int,
+ downsample_factor: int,
+ max_pixels_tolerance: float,
+ ) -> bool:
+ """Check if the image is too large to be processed as one tile."""
+ total_factor = encoder_patch_size * downsample_factor
+
+ h_bar = max(encoder_patch_size, round_by_factor(height, total_factor))
+ w_bar = max(encoder_patch_size, round_by_factor(width, total_factor))
+ return (
+ h_bar * w_bar
+ > max_image_tokens
+ * encoder_patch_size**2
+ * downsample_factor**2
+ * max_pixels_tolerance
+ )
+
+ def smart_resize(
+ self,
+ height: int,
+ width: int,
+ downsample_factor: int,
+ min_image_tokens: int,
+ max_image_tokens: int,
+ encoder_patch_size: int,
+ ) -> tuple[int, int]:
+ total_factor = encoder_patch_size * downsample_factor
+ smart_resize_min_pixels = (
+ min_image_tokens * encoder_patch_size**2 * downsample_factor**2
+ )
+ smart_resize_max_pixels = (
+ max_image_tokens * encoder_patch_size**2 * downsample_factor**2
+ )
+
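+        # Round both sides to multiples of total_factor so the resized image
+        # maps to a whole number of downsampled patches, then rescale if the
+        # pixel count falls outside the [min, max] token budget.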
+ h_bar = max(total_factor, round_by_factor(height, total_factor))
+ w_bar = max(total_factor, round_by_factor(width, total_factor))
+
+ if h_bar * w_bar > smart_resize_max_pixels:
+ beta = math.sqrt((height * width) / smart_resize_max_pixels)
+ h_bar = max(
+ total_factor, math.floor(height / beta / total_factor) * total_factor
+ )
+ w_bar = max(
+ total_factor, math.floor(width / beta / total_factor) * total_factor
+ )
+ elif h_bar * w_bar < smart_resize_min_pixels:
+ beta = math.sqrt(smart_resize_min_pixels / (height * width))
+ h_bar = math.ceil(height * beta / total_factor) * total_factor
+ w_bar = math.ceil(width * beta / total_factor) * total_factor
+
+ return w_bar, h_bar
+
+ def _target_ratios(self, min_tiles: int, max_tiles: int) -> list[tuple[int, int]]:
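+        """Enumerate candidate (width, height) tile grids whose tile count lies
+        in [min_tiles, max_tiles], sorted by total number of tiles."""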
+ ratios = [
+ (w, h)
+ for n in range(min_tiles, max_tiles + 1)
+ for w in range(1, n + 1)
+ for h in range(1, n + 1)
+ if min_tiles <= w * h <= max_tiles
+ ]
+ return sorted(set(ratios), key=lambda x: x[0] * x[1])
+
+ def _get_grid_layout(
+ self,
+ height: int,
+ width: int,
+ min_tiles: int,
+ max_tiles: int,
+ tile_size: int,
+    ) -> tuple[int, int, int]:
+ aspect_ratio = width / height
+ target_ratios = self._target_ratios(min_tiles, max_tiles)
+ # find best matching grid configuration
+ grid_width, grid_height = find_closest_aspect_ratio(
+ aspect_ratio, target_ratios, width, height, tile_size
+ )
+ total_patches = grid_width * grid_height
+ return grid_width, grid_height, total_patches
+
+ def _get_image_feature_grid_size(
+ self,
+ image_width: int,
+ image_height: int,
+ processor: Lfm2VlProcessor | None,
+    ) -> tuple[int, int, int]:
+ if processor is None:
+ processor = self.get_image_processor()
+
+ downsample_factor = processor.image_processor.downsample_factor
+ encoder_patch_size = processor.image_processor.encoder_patch_size
+ max_pixels_tolerance = processor.image_processor.max_pixels_tolerance
+ min_tiles = processor.image_processor.min_tiles
+ max_tiles = processor.image_processor.max_tiles
+ max_image_tokens = processor.image_processor.max_image_tokens
+ tile_size = processor.image_processor.tile_size
+
+        do_image_splitting = not (min_tiles == max_tiles == 1)
+ is_image_large = self._is_image_too_large(
+ height=image_height,
+ width=image_width,
+ max_image_tokens=max_image_tokens,
+ encoder_patch_size=encoder_patch_size,
+ downsample_factor=downsample_factor,
+ max_pixels_tolerance=max_pixels_tolerance,
+ )
+
+ # Big image will be cropped into patches and small images are just resized
+ if is_image_large and do_image_splitting:
+ grid_width, grid_height, total_patches = self._get_grid_layout(
+ image_height,
+ image_width,
+ min_tiles=min_tiles,
+ max_tiles=max_tiles,
+ tile_size=tile_size,
+ )
+ else:
+ grid_width = grid_height = total_patches = 1
+
+        if grid_width * grid_height != 1:  # split images also get a thumbnail
+ total_patches += 1
+
+ return grid_width, grid_height, total_patches
+
+ def get_num_patches(
+ self,
+ *,
+ image_width: int,
+ image_height: int,
+ processor: Lfm2VlProcessor | None,
+ ) -> int:
+ _, _, total_patches = self._get_image_feature_grid_size(
+ image_width=image_width,
+ image_height=image_height,
+ processor=processor,
+ )
+ return total_patches
+
+ def get_image_repl(
+ self,
+ image_width: int,
+ image_height: int,
+ spatial_shapes: torch.Tensor,
+ processor: Lfm2VlProcessor | None,
+ ) -> str:
+ if processor is None:
+ processor = self.get_hf_processor()
+
+ grid_placeholder = "<|img_row_{n_h}_col_{n_w}|>"
+ image_token = processor.image_token
+ image_start_token = processor.image_start_token
+ image_end_token = processor.image_end_token
+ image_thumbnail_token = processor.image_thumbnail_token
+
+ num_thumbnail_tokens, num_tokens_per_tile = self.get_num_image_tokens(
+ spatial_shapes=spatial_shapes,
+ processor=processor,
+ )
+ tile_img_placeholder = grid_placeholder + (image_token * num_tokens_per_tile)
+
+ grid_w, grid_h, _ = self._get_image_feature_grid_size(
+ image_width=image_width,
+ image_height=image_height,
+ processor=processor,
+ )
+
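+        # Split images: one run of image tokens per tile, each prefixed by its
+        # row/column grid marker, followed by an optional thumbnail run.
+        # Unsplit images get a single run of image tokens sized from the
+        # (only) spatial shape.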
+ if grid_w > 1 or grid_h > 1:
+ tiles_placeholder: list[str] = [
+ tile_img_placeholder.format(n_h=i + 1, n_w=j + 1)
+ for i in range(grid_h)
+ for j in range(grid_w)
+ ]
+
+ if num_thumbnail_tokens > 0:
+ tiles_placeholder.append(
+ image_thumbnail_token + (image_token * num_thumbnail_tokens)
+ )
+ else:
+ tiles_placeholder = [image_token * num_thumbnail_tokens]
+
+ placeholder = "".join(
+ itertools.chain([image_start_token], tiles_placeholder, [image_end_token])
+ )
+ return placeholder
+
+ def get_num_image_tokens(
+ self,
+ *,
+ spatial_shapes: torch.Tensor,
+ processor: Lfm2VlProcessor | None,
+ ) -> tuple[int, int]:
+ tile_size = processor.image_processor.tile_size
+ downsample_factor = processor.image_processor.downsample_factor
+ encoder_patch_size = processor.image_processor.encoder_patch_size
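+        # spatial_shapes lists the thumbnail (or the single resized image) last,
+        # so its token count follows from that entry; full tiles have a fixed
+        # token count derived from tile_size.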
+        num_thumbnail_tokens = int(spatial_shapes[-1].prod()) // (downsample_factor**2)
+ num_patches_tile = tile_size // encoder_patch_size
+ dwn_num_patches_tile = math.ceil(num_patches_tile / downsample_factor)
+ num_tiles_tokens = dwn_num_patches_tile * dwn_num_patches_tile
+ return num_thumbnail_tokens, num_tiles_tokens
+
+
+class Lfm2VLDummyInputsBuilder(BaseDummyInputsBuilder[Lfm2VLProcessingInfo]):
+ def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
+ num_images = mm_counts.get("image", 0)
+ processor = self.info.get_hf_processor()
+ image_token = processor.image_token
+ return image_token * num_images
+
+ def get_dummy_mm_data(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ mm_options: Mapping[str, BaseDummyOptions] | None = None,
+ ) -> MultiModalDataDict:
+ num_images = mm_counts.get("image", 0)
+
+ target_width, target_height = self.info.get_image_size_with_most_features()
+
+ image_overrides = mm_options.get("image") if mm_options else None
+
+ return {
+ "image": self._get_dummy_images(
+ width=target_width,
+ height=target_height,
+ num_images=num_images,
+ overrides=image_overrides,
+ ),
+ }
+
+
+class Lfm2VLMultiModalProcessor(BaseMultiModalProcessor[Lfm2VLProcessingInfo]):
+ def _call_hf_processor(
+ self,
+ prompt: str,
+ mm_data: Mapping[str, object],
+ mm_kwargs: Mapping[str, object],
+ tok_kwargs: Mapping[str, object],
+ ) -> BatchFeature:
+ # Text-only input not supported in composite processor
+ if not (images := mm_data.get("images", [])):
+ prompt_ids = self.info.get_tokenizer().encode(prompt)
+ prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
+ return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
+
+ processed_outputs = super()._call_hf_processor(
+ prompt,
+ mm_data,
+ mm_kwargs,
+ tok_kwargs,
+ )
+
+ parsed_images = (
+ self._get_data_parser()
+ .parse_mm_data({"image": images})
+ .get_items("image", ImageProcessorItems)
+ )
+ image_sizes = [
+ parsed_images.get_image_size(i) for i in range(len(parsed_images))
+ ]
+ hf_processor = self.info.get_hf_processor(**mm_kwargs)
+
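+        # Record how many tiles (plus thumbnail) each image contributes so the
+        # flat pixel_values / spatial_shapes tensors can be regrouped per image.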
+ num_patches = [
+ self.info.get_num_patches(
+ image_width=size.width,
+ image_height=size.height,
+ processor=hf_processor,
+ )
+ for size in image_sizes
+ ]
+ processed_outputs["num_patches"] = torch.tensor(num_patches)
+
+ return processed_outputs
+
+ def _get_mm_fields_config(
+ self,
+ hf_inputs: BatchFeature,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, MultiModalFieldConfig]:
+ num_patches = hf_inputs.get("num_patches", torch.empty(0))
+
+ return dict[str, MultiModalFieldConfig](
+ pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches),
+ spatial_shapes=MultiModalFieldConfig.flat_from_sizes(
+ "image", num_patches, keep_on_cpu=True
+ ),
+ num_patches=MultiModalFieldConfig.batched("image", keep_on_cpu=True),
+ )
+
+ def _get_prompt_updates(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ out_mm_kwargs: MultiModalKwargsItems,
+ ) -> Sequence[PromptReplacement]:
+ hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+ image_token = hf_processor.image_token
+
+ def get_image_replacement_lfm2vl(item_idx: int):
+ images = mm_items.get_items("image", ImageProcessorItems)
+ image_size = images.get_image_size(item_idx)
+ out_item = out_mm_kwargs["image"][item_idx]
+ spatial_shapes = out_item["spatial_shapes"].data
+ assert isinstance(spatial_shapes, torch.Tensor)
+ image_repl = self.info.get_image_repl(
+ image_width=image_size.width,
+ image_height=image_size.height,
+ spatial_shapes=spatial_shapes,
+ processor=hf_processor,
+ )
+ return PromptUpdateDetails.select_text(
+ image_repl,
+ embed_text=image_token,
+ )
+
+ return [
+ PromptReplacement(
+ modality="image",
+ target=image_token,
+ replacement=get_image_replacement_lfm2vl,
+ )
+ ]
+
+
+class Lfm2VLMultiModalProjector(nn.Module):
+ def __init__(
+ self, config: Lfm2VlConfig, use_data_parallel: bool = False, prefix: str = ""
+ ):
+ super().__init__()
+ self.use_data_parallel = use_data_parallel
+
+ in_channels = config.vision_config.hidden_size * (config.downsample_factor**2)
+ self.factor = config.downsample_factor
+ self.projector_use_layernorm = config.projector_use_layernorm
+ if self.projector_use_layernorm:
+ self.layer_norm = nn.LayerNorm(in_channels)
+ self.linear_1 = nn.Linear(
+ in_channels,
+ config.projector_hidden_size,
+ bias=config.projector_bias,
+ )
+ self.act = ACT2FN[config.projector_hidden_act]
+ self.linear_2 = nn.Linear(
+ config.projector_hidden_size,
+ config.text_config.hidden_size,
+ bias=config.projector_bias,
+ )
+
+ def forward(self, image_features: torch.Tensor):
+ image_features = self.pixel_unshuffle(image_features)
+ if self.projector_use_layernorm:
+ image_features = self.layer_norm(image_features)
+ hidden_states = self.linear_1(image_features)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+ return hidden_states
+
+ def pixel_unshuffle(self, hidden_states: torch.Tensor):
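+        # Space-to-depth: fold each `factor x factor` spatial block into the
+        # channel dimension, shrinking the token count by factor**2.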
+ batch_size, width, height, channels = hidden_states.size()
+ hidden_states = hidden_states.reshape(
+ batch_size, width, height // self.factor, channels * self.factor
+ )
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
+ hidden_states = hidden_states.reshape(
+ batch_size,
+ height // self.factor,
+ width // self.factor,
+ channels * self.factor**2,
+ )
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
+ return hidden_states
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+ Lfm2VLMultiModalProcessor,
+ info=Lfm2VLProcessingInfo,
+ dummy_inputs=Lfm2VLDummyInputsBuilder,
+)
+class Lfm2VLForConditionalGeneration(
+ nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, IsHybrid
+):
+ merge_by_field_config = True
+
+ hf_to_vllm_mapper = WeightsMapper(
+ orig_to_new_prefix={
+ "lm_head.": "language_model.lm_head.",
+ "model.language_model.": "language_model.model.",
+ "model.vision_tower.": "vision_tower.",
+ "model.multi_modal_projector.": "multi_modal_projector.",
+ }
+ )
+
+ @classmethod
+ def get_placeholder_str(cls, modality: str, i: int) -> str | None:
+ if modality.startswith("image"):
+ return ""
+
+ raise ValueError("Only image modality is supported")
+
+ @classmethod
+ def get_mamba_state_dtype_from_config(
+ cls,
+ vllm_config: "VllmConfig",
+ ) -> tuple[torch.dtype, ...]:
+ return MambaStateDtypeCalculator.short_conv_state_dtype(
+ vllm_config.model_config.dtype,
+ vllm_config.cache_config.mamba_cache_dtype,
+ )
+
+ @classmethod
+ def get_mamba_state_shape_from_config(
+ cls,
+ vllm_config: "VllmConfig",
+ ) -> tuple[tuple[int, int]]:
+ """Calculate shapes for LFM2's convolutional cache.
+
+ Args:
+ vllm_config: vLLM config
+
+ Returns:
+ Tuple containing:
+ - conv_state_shape: Shape for convolutional state cache
+ """
+ parallel_config = vllm_config.parallel_config
+ hf_language_config = vllm_config.model_config.hf_config.text_config
+
+ return MambaStateShapeCalculator.short_conv_state_shape(
+ tp_world_size=parallel_config.tensor_parallel_size,
+ intermediate_size=hf_language_config.hidden_size,
+ conv_kernel=hf_language_config.conv_L_cache,
+ )
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
+ super().__init__()
+ config: Lfm2VlConfig = vllm_config.model_config.hf_config
+ multimodal_config = vllm_config.model_config.multimodal_config
+ vision_config = config.vision_config
+ quant_config = vllm_config.quant_config
+
+ self.config = config
+ self.vllm_config = vllm_config
+ self.multimodal_config = multimodal_config
+ self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
+
+ if vision_config.model_type == "siglip2_vision_model":
+ self.vision_tower = Siglip2Model(
+ config=vision_config,
+ quant_config=quant_config,
+ multimodal_config=multimodal_config,
+ prefix=maybe_prefix(prefix, "vision_tower"),
+ )
+ else:
+ raise ValueError(
+ f"Unsupported visual tokenizer model_type: {vision_config.model_type}"
+ )
+
+ self.multi_modal_projector = Lfm2VLMultiModalProjector(
+ config=config,
+ use_data_parallel=self.use_data_parallel,
+ prefix=f"{prefix}.multi_modal_projector",
+ )
+
+ self.language_model = init_vllm_registered_model(
+ vllm_config=vllm_config,
+ hf_config=config.text_config,
+ prefix=maybe_prefix(prefix, "language"),
+ architectures=config.text_config.architectures,
+ )
+
+ self.make_empty_intermediate_tensors = (
+ self.language_model.make_empty_intermediate_tensors
+ )
+
+ def get_language_model(self) -> torch.nn.Module:
+ return self.language_model
+
+ def _parse_and_validate_image_input(
+ self, **kwargs: object
+ ) -> LFM2VLImageInputs | None:
+ pixel_values = kwargs.pop("pixel_values", None)
+ spatial_shapes = kwargs.pop("spatial_shapes", None)
+ num_patches = kwargs.pop("num_patches", None)
+ if pixel_values is None:
+ return None
+
+ return LFM2VLImageInputs(
+ type="pixel_values",
+ pixel_values=pixel_values,
+ spatial_shapes=spatial_shapes,
+ num_patches=num_patches,
+ )
+
+ def image_pixels_to_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ spatial_shapes: torch.Tensor,
+    ) -> list[torch.Tensor]:
+ pixel_values = pixel_values.to(
+ dtype=self.vision_tower.vision_model.embeddings.patch_embedding.weight.dtype
+ ) # fp16 compatibility
+
+ # LFM2-VL's HF processor pads patch sequences with trailing zeros.
+ # Derive the valid-patch mask from spatial_shapes instead of carrying
+ # pixel_attention_mask through the vLLM multimodal pipeline.
+ max_seq_len = pixel_values.shape[1]
+ lengths_cpu = (spatial_shapes[:, 0] * spatial_shapes[:, 1]).to(
+ dtype=torch.int32
+ )
+ max_seqlen = (
+ lengths_cpu.max().reshape(1).to(device=pixel_values.device)
+ if lengths_cpu.numel()
+ else torch.tensor([0], dtype=torch.int32, device=pixel_values.device)
+ )
+ lengths = lengths_cpu.to(device=pixel_values.device)
+ packed_mask = (
+ torch.arange(max_seq_len, device=pixel_values.device)[None, :]
+ < lengths[:, None]
+ )
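+        # cu_seqlens holds cumulative patch counts per image, delimiting each
+        # image's segment in the packed sequence consumed by the vision tower.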
+ cu_seqlens = torch.zeros(
+ lengths.shape[0] + 1,
+ dtype=torch.int32,
+ device=lengths.device,
+ )
+ cu_seqlens[1:] = torch.cumsum(lengths, dim=0)
+
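+        # Run the vision tower under a fresh forward context; the encoder
+        # does not use the decoder's attention metadata or KV-cache state.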
+ with set_forward_context(None, self.vllm_config):
+ vision_outputs = self.vision_tower(
+ pixel_values=pixel_values,
+ spatial_shapes=spatial_shapes,
+ packed_mask=packed_mask,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ )
+ image_outputs = getattr(vision_outputs, "last_hidden_state", vision_outputs)
+
+ image_features = []
+
+ # spatial_shapes is on CPU (keep_on_cpu=True), so .tolist() is instant
+ spatial_shapes_list = spatial_shapes.tolist()
+ for img_idx, (feature_org_h, feature_org_w) in enumerate(spatial_shapes_list):
+ feature_len = feature_org_h * feature_org_w
+ feature = image_outputs[img_idx, :feature_len]
+
+ # reshape to original height and width
+ feature = feature.reshape(1, feature_org_h, feature_org_w, -1)
+
+ # project the image representation
+ img_embedding = self.multi_modal_projector(feature)
+
+ # flatten here to handle variable length in naflex
+ img_embedding = img_embedding.reshape(-1, img_embedding.size(-1))
+ image_features.append(img_embedding)
+
+ return image_features
+
+ def _process_image_input(
+ self,
+ image_input: LFM2VLImageInputs,
+ ) -> torch.Tensor | list[torch.Tensor]:
+ pixel_values = image_input["pixel_values"]
+ spatial_shapes = image_input["spatial_shapes"]
+ num_patches = image_input["num_patches"]
+
+ image_features = self.image_pixels_to_features(
+ pixel_values,
+ spatial_shapes=spatial_shapes,
+ )
+
+ # Group patches by image - num_patches is on CPU (keep_on_cpu=True)
+ # so .tolist() is instant with no DtoH sync
+ num_patches_list = num_patches.tolist()
+ batched_features: list[torch.Tensor] = []
+ patch_idx = 0
+ for count in num_patches_list:
+ # Slice the list of patch tensors for this image
+ image_patches = image_features[patch_idx : patch_idx + count]
+ # Concatenate patches for this image
+ batched_features.append(torch.cat(image_patches, dim=0))
+ patch_idx += count
+
+ return batched_features
+
+ def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
+ image_input = self._parse_and_validate_image_input(**kwargs)
+ if image_input is None:
+ return []
+
+ return self._process_image_input(image_input)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: IntermediateTensors | None = None,
+ inputs_embeds: torch.Tensor | None = None,
+ **kwargs: object,
+ ) -> torch.Tensor | IntermediateTensors:
+ if intermediate_tensors is not None:
+ inputs_embeds = None
+
+ hidden_states = self.language_model(
+ input_ids=input_ids,
+ positions=positions,
+ intermediate_tensors=intermediate_tensors,
+ inputs_embeds=inputs_embeds,
+ )
+ return hidden_states
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ ) -> torch.Tensor | None:
+ logits = self.language_model.compute_logits(hidden_states)
+ return logits
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ loader = AutoWeightsLoader(self)
+ return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+
+ def get_mm_mapping(self) -> MultiModelKeys:
+ """
+ Get the module prefix in multimodal models
+ """
+ return MultiModelKeys.from_string_field(
+ language_model="language_model",
+ connector="multi_modal_projector",
+ tower_model="vision_tower",
+ )
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index a25267fc2267..2db1598cadb0 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -348,6 +348,7 @@
"lightonocr",
"LightOnOCRForConditionalGeneration",
),
+ "Lfm2VlForConditionalGeneration": ("lfm2_vl", "Lfm2VLForConditionalGeneration"),
"Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
"Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"), # noqa: E501
"LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
diff --git a/vllm/model_executor/models/siglip2.py b/vllm/model_executor/models/siglip2.py
new file mode 100644
index 000000000000..f7c91aa28dcc
--- /dev/null
+++ b/vllm/model_executor/models/siglip2.py
@@ -0,0 +1,495 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Implementation of Siglip2VisionModel intended to be only used
+within a vision language model."""
+
+from collections.abc import Iterable
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from transformers import Siglip2VisionConfig
+
+from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import MultiModalConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.linear import (
+ ColumnParallelLinear,
+ QKVParallelLinear,
+ RowParallelLinear,
+)
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from .vision import should_torch_compile_mm_vit
+
+
+class Siglip2VisionEmbeddings(nn.Module):
+ def __init__(self, config: Siglip2VisionConfig):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.patch_size = config.patch_size
+ self.patch_embedding = nn.Linear(
+ in_features=config.num_channels * self.patch_size * self.patch_size,
+ out_features=self.embed_dim,
+ )
+ self.num_patches = config.num_patches
+ self.position_embedding_size = int(self.num_patches**0.5)
+ self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)
+
+ @staticmethod
+ def resize_positional_embeddings(
+ positional_embeddings: torch.Tensor,
+ spatial_shapes: torch.LongTensor,
+ max_length: int,
+ ) -> torch.Tensor:
+ """
+ Resize positional embeddings to image-specific size and pad to a fixed size.
+
+ Args:
+ positional_embeddings (`torch.Tensor`):
+ Position embeddings of shape (height, width, embed_dim)
+ spatial_shapes (`torch.LongTensor`):
+ Spatial shapes of shape (batch_size, 2) to resize the positional
+ embeddings to
+ max_length (`int`):
+ Maximum length of the positional embeddings to pad resized
+ positional embeddings to
+
+ Returns:
+ `torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim)
+ """
+ batch_size = spatial_shapes.shape[0]
+ embed_dim = positional_embeddings.shape[-1]
+ source_dtype = positional_embeddings.dtype
+
+ resulted_positional_embeddings = torch.empty(
+ (batch_size, max_length, embed_dim),
+ device=positional_embeddings.device,
+ dtype=source_dtype,
+ )
+
+ # (height, width, embed_dim) -> (1, embed_dim, height, width) for interpolation
+ positional_embeddings = positional_embeddings.permute(2, 0, 1).unsqueeze(0)
+
+ # Upcast to float32 on CPU because antialias is not supported for
+ # bfloat16/float16 on CPU
+ if positional_embeddings.device.type == "cpu":
+ positional_embeddings = positional_embeddings.to(torch.float32)
+
+ for i in range(batch_size):
+ # (1, dim, height, width) -> (1, dim, target_height, target_width)
+ height, width = spatial_shapes[i]
+ resized_embeddings = F.interpolate(
+ positional_embeddings,
+ size=(height, width),
+ mode="bilinear",
+ align_corners=False,
+ antialias=True,
+ )
+
+ # (1, dim, target_height, target_width) ->
+ # (target_height * target_width, dim)
+ resized_embeddings = resized_embeddings.reshape(
+ embed_dim, height * width
+ ).transpose(0, 1)
+
+ # Cast to original dtype
+ resized_embeddings = resized_embeddings.to(source_dtype)
+
+ resulted_positional_embeddings[i, : height * width] = resized_embeddings
+ resulted_positional_embeddings[i, height * width :] = resized_embeddings[0]
+
+ return resulted_positional_embeddings
+
+ def forward(
+ self, pixel_values: torch.FloatTensor, spatial_shapes: torch.LongTensor
+ ) -> torch.Tensor:
+ """
+ Args:
+ pixel_values (`torch.FloatTensor`):
+ Pixel values of shape (batch_size, max_num_patches,
+ num_channels * patch_size * patch_size)
+            spatial_shapes (`torch.LongTensor`):
+ Spatial shapes of shape (batch_size, 2) to resize the positional
+ embeddings to
+ """
+
+ # Apply patch embeddings to already patchified pixel values
+ target_dtype = self.patch_embedding.weight.dtype
+ patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
+
+        # Get resized and padded positional embeddings
+ positional_embeddings = self.position_embedding.weight.reshape(
+ self.position_embedding_size, self.position_embedding_size, -1
+ )
+ resized_positional_embeddings = self.resize_positional_embeddings(
+ positional_embeddings, spatial_shapes, max_length=pixel_values.shape[1]
+ )
+
+ # Add positional embeddings to patch embeddings
+ embeddings = patch_embeds + resized_positional_embeddings
+ return embeddings
+
+
+class Siglip2Attention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(
+ self,
+ config: Siglip2VisionConfig,
+ quant_config: QuantizationConfig | None = None,
+ multimodal_config: MultiModalConfig | None = None,
+ prefix: str = "",
+ ):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads "
+ f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+ self.scale = self.head_dim**-0.5
+ self.dropout = config.attention_dropout
+
+ use_data_parallel = (
+ multimodal_config is not None
+ and multimodal_config.mm_encoder_tp_mode == "data"
+ )
+ tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size()
+ assert self.num_heads % tp_size == 0
+ self.num_heads_per_partition = self.num_heads // tp_size
+
+ self.qkv_proj = QKVParallelLinear(
+ hidden_size=self.embed_dim,
+ head_size=self.head_dim,
+ total_num_heads=self.num_heads,
+ quant_config=quant_config,
+ prefix=f"{prefix}.qkv_proj",
+ disable_tp=use_data_parallel,
+ )
+ self.out_proj = RowParallelLinear(
+ input_size=self.embed_dim,
+ output_size=self.embed_dim,
+ quant_config=quant_config,
+ prefix=f"{prefix}.out_proj",
+ disable_tp=use_data_parallel,
+ )
+ self.attn = MMEncoderAttention(
+ num_heads=self.num_heads_per_partition,
+ head_size=self.head_dim,
+ scale=self.scale,
+ prefix=f"{prefix}.attn",
+ multimodal_config=multimodal_config,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ max_seqlen: int | torch.Tensor,
+ ) -> torch.Tensor:
+ qkv, _ = self.qkv_proj(
+ hidden_states
+ ) # batch_size, q_len, 3 * num_heads_per_partition * head_dim
+ bsz, q_len, _ = qkv.shape
+ query_states, key_states, value_states = qkv.chunk(3, dim=-1)
+ query_states = query_states.view(
+ bsz, q_len, self.num_heads_per_partition, self.head_dim
+ )
+ key_states = key_states.view(
+ bsz, q_len, self.num_heads_per_partition, self.head_dim
+ )
+ value_states = value_states.view(
+ bsz, q_len, self.num_heads_per_partition, self.head_dim
+ )
+
+ # Use unified MultiHeadAttention implementation
+ out = self.attn(
+ query=query_states,
+ key=key_states,
+ value=value_states,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ )
+ out = out.reshape(bsz, q_len, -1)
+ attn_output, _ = self.out_proj(out)
+ return attn_output
+
+
+class Siglip2MLP(nn.Module):
+ def __init__(
+ self,
+ config: Siglip2VisionConfig,
+ quant_config: QuantizationConfig | None = None,
+ multimodal_config: MultiModalConfig | None = None,
+ prefix: str = "",
+ ):
+ super().__init__()
+ self.config = config
+ self.activation_fn = get_act_fn(config.hidden_act)
+ use_data_parallel = (
+ multimodal_config is not None
+ and multimodal_config.mm_encoder_tp_mode == "data"
+ )
+ self.fc1 = ColumnParallelLinear(
+ config.hidden_size,
+ config.intermediate_size,
+ quant_config=quant_config,
+ prefix=f"{prefix}.fc1",
+ disable_tp=use_data_parallel,
+ )
+ self.fc2 = RowParallelLinear(
+ config.intermediate_size,
+ config.hidden_size,
+ quant_config=quant_config,
+ prefix=f"{prefix}.fc2",
+ disable_tp=use_data_parallel,
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states, _ = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states, _ = self.fc2(hidden_states)
+ return hidden_states
+
+
+@support_torch_compile(
+ dynamic_arg_dims={"hidden_states": [0, 1], "cu_seqlens": 0},
+ enable_if=should_torch_compile_mm_vit,
+)
+class Siglip2EncoderLayer(nn.Module):
+ def __init__(
+ self,
+ config: Siglip2VisionConfig,
+ quant_config: QuantizationConfig | None = None,
+ multimodal_config: MultiModalConfig | None = None,
+ prefix: str = "",
+ ):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+ self.self_attn = Siglip2Attention(
+ config,
+ quant_config=quant_config,
+ multimodal_config=multimodal_config,
+ prefix=f"{prefix}.self_attn",
+ )
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+ self.mlp = Siglip2MLP(
+ config,
+ quant_config=quant_config,
+ multimodal_config=multimodal_config,
+ prefix=f"{prefix}.mlp",
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ max_seqlen: int | torch.Tensor,
+ ) -> torch.Tensor:
+ """
+ Args:
+ hidden_states: Input tensor of shape (batch, seq_len, embed_dim).
+ cu_seqlens: Cumulative sequence lengths tensor.
+ max_seqlen: Maximum sequence length.
+ """
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states = self.self_attn(
+ hidden_states=hidden_states,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+class Siglip2Encoder(nn.Module):
+ """
+ Transformer encoder consisting of `config.num_hidden_layers`
+ self attention layers. Each layer is a [`Siglip2EncoderLayer`].
+
+ Args:
+ config: PretrainedConfig
+ """
+
+ def __init__(
+ self,
+ config: Siglip2VisionConfig,
+ quant_config: QuantizationConfig | None = None,
+ multimodal_config: MultiModalConfig | None = None,
+ prefix: str = "",
+ ):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList(
+ [
+ Siglip2EncoderLayer(
+ config=config,
+ quant_config=quant_config,
+ multimodal_config=multimodal_config,
+ prefix=f"{prefix}.layers.{idx}",
+ )
+ for idx in range(config.num_hidden_layers)
+ ]
+ )
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ max_seqlen: int | torch.Tensor,
+ ) -> torch.Tensor:
+ hidden_states = inputs_embeds
+ for encoder_layer in self.layers:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ )
+ hidden_states = layer_outputs
+ return hidden_states
+
+
+class Siglip2VisionTransformer(nn.Module):
+ def __init__(
+ self,
+ config: Siglip2VisionConfig,
+ quant_config: QuantizationConfig | None = None,
+ multimodal_config: MultiModalConfig | None = None,
+ prefix: str = "",
+ ):
+ super().__init__()
+ embed_dim = config.hidden_size
+ self.config = config
+ self.embeddings = Siglip2VisionEmbeddings(config)
+ # Keep the import local to avoid circular dependencies during model init.
+ from vllm.compilation.backends import set_model_tag
+
+ with set_model_tag("Siglip2Encoder", is_encoder=True):
+ self.encoder = Siglip2Encoder(
+ config,
+ quant_config=quant_config,
+ multimodal_config=multimodal_config,
+ prefix=f"{prefix}.encoder",
+ )
+ num_hidden_layers = config.num_hidden_layers
+ if len(self.encoder.layers) > config.num_hidden_layers:
+ raise ValueError(
+ f"The original encoder only has {num_hidden_layers} "
+ f"layers, but you requested {len(self.encoder.layers)} layers."
+ )
+
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+ def get_input_embeddings(self):
+ return self.embeddings
+
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ spatial_shapes: torch.LongTensor,
+ packed_mask: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ max_seqlen: int | torch.Tensor,
+ ) -> torch.Tensor:
+ r"""
+ spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
+ Tensor containing the spatial dimensions (height, width)
+ of the input images.
+ """
+ hidden_states = self.embeddings(pixel_values, spatial_shapes)
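+        # Keep only the valid (unpadded) patches of every image and pack them
+        # into one flat sequence; cu_seqlens marks the per-image boundaries so
+        # attention stays within each image.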
+ flat_mask = packed_mask.view(-1)
+ packed_indices = flat_mask.nonzero(as_tuple=True)[0]
+ flat_hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+ hidden_states = flat_hidden_states.index_select(0, packed_indices).unsqueeze(0)
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ )
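+        # Scatter the packed outputs back to (batch, max_num_patches, dim),
+        # leaving padded positions as zeros.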
+ unpacked = encoder_outputs.new_zeros(
+ packed_mask.numel(), encoder_outputs.shape[-1]
+ )
+ unpacked.index_copy_(0, packed_indices, encoder_outputs.squeeze(0))
+ encoder_outputs = unpacked.view(
+ packed_mask.shape + (encoder_outputs.shape[-1],)
+ )
+ last_hidden_state = self.post_layernorm(encoder_outputs)
+ return last_hidden_state
+
+
+class Siglip2Model(torch.nn.Module):
+ def __init__(
+ self,
+ config: Siglip2VisionConfig,
+ quant_config: QuantizationConfig | None = None,
+ multimodal_config: MultiModalConfig | None = None,
+ prefix: str = "",
+ ):
+ super().__init__()
+
+ self.vision_model = Siglip2VisionTransformer(
+ config,
+ quant_config=quant_config,
+ multimodal_config=multimodal_config,
+ prefix=f"{prefix}.vision_model",
+ )
+
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ spatial_shapes: torch.LongTensor,
+ packed_mask: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ max_seqlen: int | torch.Tensor,
+ ) -> torch.Tensor:
+ return self.vision_model(
+ pixel_values=pixel_values,
+ spatial_shapes=spatial_shapes,
+ packed_mask=packed_mask,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ )
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ ("qkv_proj", "q_proj", "q"),
+ ("qkv_proj", "k_proj", "k"),
+ ("qkv_proj", "v_proj", "v"),
+ ]
+ params_dict = dict(self.named_parameters())
+ loaded_params: set[str] = set()
+
+ for name, loaded_weight in weights:
+ for param_name, weight_name, shard_id in stacked_params_mapping:
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+ return loaded_params