From 471bfb06ea53115d74f73875ac77b457858cc2a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E7=AD=96?= Date: Mon, 5 Jan 2026 04:49:29 -0500 Subject: [PATCH] [Model] Enable LoRA support for Pixtral MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 赵策 --- docs/models/supported_models.md | 2 +- vllm/model_executor/models/pixtral.py | 31 +++++++++++++++++++++++++-- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 07b1ced5ca42..832322a8e5bf 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -719,7 +719,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + IE | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | ✅︎ | ✅︎ | | `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + IE+ | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ | | `Phi4MMForCausalLM` | Phi-4-multimodal | T + I+ / T + A+ / I+ + A+ | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | -| `PixtralForConditionalGeneration` | Ministral 3 (Mistral format), Mistral 3 (Mistral format), Mistral Large 3 (Mistral format), Pixtral (Mistral format) | T + I+ | `mistralai/Ministral-3-3B-Instruct-2512`, `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, `mistralai/Mistral-Large-3-675B-Instruct-2512` `mistralai/Pixtral-12B-2409` etc. | | ✅︎ | +| `PixtralForConditionalGeneration` | Ministral 3 (Mistral format), Mistral 3 (Mistral format), Mistral Large 3 (Mistral format), Pixtral (Mistral format) | T + I+ | `mistralai/Ministral-3-3B-Instruct-2512`, `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, `mistralai/Mistral-Large-3-675B-Instruct-2512`, `mistralai/Pixtral-12B-2409`, etc. 
| ✅︎ | ✅︎ | | `QwenVLForConditionalGeneration`^ | Qwen-VL | T + IE+ | `Qwen/Qwen-VL`, `Qwen/Qwen-VL-Chat`, etc. | ✅︎ | ✅︎ | | `Qwen2AudioForConditionalGeneration` | Qwen2-Audio | T + A+ | `Qwen/Qwen2-Audio-7B-Instruct` | | ✅︎ | | `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + IE+ + VE+ | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 555e6ea4b8cb..2c5a5eedd74c 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -63,7 +63,13 @@ from vllm.tokenizers.mistral import MistralTokenizer from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, +) +from .module_mapping import MultiModelKeys from .utils import init_vllm_registered_model, maybe_prefix from .vision import ( VisionEncoderInfo, @@ -365,7 +371,9 @@ def _cached_apply_hf_processor( info=PixtralProcessingInfo, dummy_inputs=PixtralDummyInputsBuilder, ) -class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): +class PixtralForConditionalGeneration( + nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP +): @classmethod def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): @@ -581,6 +589,25 @@ def llm_weights_generator(): # Now we call the language model load with the generator self.language_model.load_weights(llm_weights_generator()) + def get_mm_mapping(self) -> MultiModelKeys: + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="vision_language_adapter", + tower_model="vision_encoder", + ) + + def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int: + if getattr(self, "patch_merger", None) is None: + return num_image_tokens + merge_size = self.vision_args.spatial_merge_size + return num_image_tokens * (merge_size**2) + + def get_num_mm_connector_tokens(self, num_vision_tokens: int) -> int: + if getattr(self, "patch_merger", None) is None: + return num_vision_tokens + merge_size = self.vision_args.spatial_merge_size + return num_vision_tokens // (merge_size**2) + # Vision encoder @dataclass