
Commit a957b79

Add SigLIP 2 (#36323)
* Docs * Inits * Auto classes * Add siglip base * Add base tests * Fix Siglip V1 for fix res version * Add image processor * Update conversion * Experimenting with vectorized embeddings * Fixup * Add modular Siglip2Processor * Add modular configuration * Rename num patches * Correct image and text features merging * Working conversion script * Refactoring conversion script * Remove unused code in conversion script * Shorten dict a bit * Refactoring conversion * Done conversion refactoring * Fixup * Modular siglip2 * Make model exportable and compilable without graph breaks * Remove position_ids from image_processor * REmove position ids from modeling file * Update modular * Type hint * Fixup * Set defaults to processor * Add integration test * Revert spatial shapes back to tensor * Change order * Fix most of the tests * Fix docstring * Remove interpolate_pos_encoding arg (not needed) * Update docs * Standardize processing * Fix attention_mask in vision head * Siglip v1: remove double transpose in FA2 * Update modular file * Update FA2 test * Update expected logits * Fix interpolation for siglip2 image processor * Skip init test * Skip dispatch on flash test * Fix modeling tests * Fixup * Add dummy objects * Fix some docstrings * Add siglip2 in index.md * Fix consistency * Add docs * Remove size and data format * Add image processor tests * Fix * Add fast image processor * Fix style * Fix * Docs * Set lowercase for tokenizer * Adjust head size for Siglip v1 * Update siglip2 for consistency with siglip1 * Update siglip2 conversion * Update pipeline * Update checkpoints in tests * Update checkpoint name * Fix pooling for image classification model * Fix FA2 test * Update processor * Fix check repo * Update docs * Fix typos * Fix docstring for fast image processor * Add siglip2 to FA2 docs * Fix fast ip tests * Fix constitency * Fix tokenizer class for siglip v1 * Fix missing header * Refactor scaling for clip, siglip, siglip2 * Remove unused imports * Make fast IP default for siglip2 * Update docs * Update checkpoints * Update modular * Update paper link * Fixup * Fix name in toctree * Fix test
1 parent 14552cb commit a957b79

33 files changed, +5570 −122 lines changed

docs/source/en/_toctree.yml

+2 lines changed

@@ -965,6 +965,8 @@
       title: Segment Anything
     - local: model_doc/siglip
       title: SigLIP
+    - local: model_doc/siglip2
+      title: SigLIP2
     - local: model_doc/smolvlm
       title: SmolVLM
     - local: model_doc/speech-encoder-decoder

docs/source/en/index.md

+1 line changed

@@ -317,6 +317,7 @@ Flax), PyTorch, and/or TensorFlow.
 | [SEW](model_doc/sew) ||||
 | [SEW-D](model_doc/sew-d) ||||
 | [SigLIP](model_doc/siglip) ||||
+| [SigLIP2](model_doc/siglip2) ||||
 | [SmolVLM](model_doc/smolvlm) ||||
 | [Speech Encoder decoder](model_doc/speech-encoder-decoder) ||||
 | [Speech2Text](model_doc/speech_to_text) ||||

docs/source/en/model_doc/siglip2.md

+276 lines changed (new file; full contents below)

<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# SigLIP2

## Overview

The SigLIP2 model was proposed in [SigLIP 2: Multilingual Vision-Language Encoders with Improved Semantic Understanding, Localization, and Dense Features](https://huggingface.co/papers/2502.14786) by Michael Tschannen, Alexey Gritsenko, Xiao Wang, Muhammad Ferjad Naeem, Ibrahim Alabdulmohsin,
Nikhil Parthasarathy, Talfan Evans, Lucas Beyer, Ye Xia, Basil Mustafa, Olivier Hénaff, Jeremiah Harmsen,
Andreas Steiner and Xiaohua Zhai.

The model comes in two variants:

1) FixRes - model works with fixed resolution images (backward compatible with SigLIP v1)
2) NaFlex - model works with variable image aspect ratios and resolutions (SigLIP2 in `transformers`)

The abstract from the paper is the following:

*We introduce SigLIP 2, a family of new multilingual vision-language encoders that build on the success
of the original SigLIP. In this second iteration, we extend the original image-text training objective with
several prior, independently developed techniques into a unified recipe—this includes decoder-based
pretraining, self-supervised losses (self-distillation, masked prediction) and online data curation. With
these changes, SigLIP 2 models outperform their SigLIP counterparts at all model scales in core capabilities,
including zero-shot classification (best SigLIP 2 ViT-g/16 achieves 85.0% ImageNet zero-shot
accuracy), image-text retrieval, and transfer performance when extracting visual representations for
Vision-Language Models (VLMs). Furthermore, the new training recipe leads to significant improvements
on localization and dense prediction tasks. We also train variants which support multiple resolutions
and preserve the input’s native aspect ratio. Finally, we train on a more diverse data-mixture that
includes de-biasing techniques, leading to much better multilingual understanding and improved fairness.
To provide users with the ability to trade-off inference cost with performance, we release model
checkpoints at four sizes (ViT-B/86M, L/303M, So400m/400M, and g/1B).*

## Usage tips

- Usage of SigLIP2 is similar to [SigLIP](siglip) and [CLIP](clip). The main difference from CLIP is the training loss, which does not require a global view of all the pairwise similarities of images and texts within a batch. One needs to apply the sigmoid activation function to the logits, rather than the softmax.
- Training is supported but does not use `torch.distributed` utilities, which may limit the scalability of the batch size. However, DDP and FSDP work in a single-node multi-GPU setup.
- When using the standalone [`GemmaTokenizerFast`], make sure to pass `padding="max_length"` and `max_length=64`, as that's how the model was trained (see the sketch after this list).
- The model was trained with *lowercased* text, so make sure to apply the same preprocessing to your text labels.
- To get the same results as the pipeline, a prompt template of "this is a photo of {label}" should be used.
- The NaFlex variant supports processing images at higher resolutions by adjusting the `max_num_patches` parameter in the `Processor`. The default value is `max_num_patches=256`. Increasing `max_num_patches` to 1024 (4x) approximately doubles the processed image height and width, while preserving the aspect ratio.

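Putting these tips together, here is a minimal sketch of standalone tokenization. It assumes the tokenizer can be loaded directly from the `google/siglip2-base-patch16-224` checkpoint used in the examples below:

```python
from transformers import GemmaTokenizerFast

# assumption: the SigLIP2 checkpoint ships its tokenizer files, so the standalone
# tokenizer can be loaded from the same repo
tokenizer = GemmaTokenizerFast.from_pretrained("google/siglip2-base-patch16-224")

# lowercase the labels and apply the training prompt template
candidate_labels = ["2 cats", "a plane", "a remote"]
texts = [f"this is a photo of {label.lower()}." for label in candidate_labels]

# pad to the fixed length of 64 tokens, matching how the model was trained
inputs = tokenizer(texts, padding="max_length", max_length=64, return_tensors="pt")
print(inputs["input_ids"].shape)  # torch.Size([3, 64])
```
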
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/siglip2_metrics_table.png"
alt="drawing" width="600"/>

This model was contributed by [qubvel](https://huggingface.co/qubvel-hf).
The original code can be found [here](https://github.com/google-research/big_vision/tree/main).

## Usage example

There are two main ways to use SigLIP2: either using the pipeline API, which abstracts away all the complexity for you, or using the `Siglip2Model` class yourself.

### FixRes variant

**Pipeline API**

The pipeline allows you to use the model in a few lines of code:

```python
>>> from transformers import pipeline
>>> from PIL import Image
>>> import requests

>>> # load pipe
>>> image_classifier = pipeline(
...     task="zero-shot-image-classification",
...     model="google/siglip2-base-patch16-224",
... )

>>> # load image
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> # inference
>>> candidate_labels = ["2 cats", "a plane", "a remote"]
>>> outputs = image_classifier(image, candidate_labels=candidate_labels)
>>> outputs = [{"score": round(output["score"], 4), "label": output["label"]} for output in outputs]
>>> print(outputs)
[{'score': 0.1499, 'label': '2 cats'}, {'score': 0.0008, 'label': 'a remote'}, {'score': 0.0, 'label': 'a plane'}]
```

**Using the model yourself**

If you want to do the pre- and postprocessing yourself, here's how to do that:

```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, AutoModel
>>> import torch

>>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224")
>>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> candidate_labels = ["2 cats", "2 dogs"]
# follows the pipeline prompt template to get the same results
>>> texts = [f"This is a photo of {label}." for label in candidate_labels]

# IMPORTANT: we pass `padding="max_length"` and `max_length=64` since that's how the model was trained
>>> inputs = processor(text=texts, images=image, padding="max_length", max_length=64, return_tensors="pt")

>>> with torch.no_grad():
...     outputs = model(**inputs)

>>> logits_per_image = outputs.logits_per_image
>>> probs = torch.sigmoid(logits_per_image)  # these are the probabilities
>>> print(f"{probs[0][0]:.1%} that image 0 is '{candidate_labels[0]}'")
15.0% that image 0 is '2 cats'
```

### NaFlex variant

NaFlex combines ideas from FlexiViT, i.e. supporting multiple, predefined sequence lengths
with a single ViT model, and NaViT, namely processing images at their native aspect ratio.
This enables processing different types of images at appropriate resolution, e.g. using a
larger resolution to process document images, while at the same time minimizing the impact
of aspect ratio distortion on certain inference tasks, e.g. on OCR.

Given a patch size and target sequence length, NaFlex preprocesses the data by first resizing
the input image such that the height and width after resizing are multiples of the patch size,
while

1. keeping the aspect ratio distortion as small as possible
2. producing a sequence length of at most the desired target sequence length (`max_num_patches`)

The resulting distortion in width and height is at most `(patch_size - 1) / width` and
`(patch_size - 1) / height`, respectively, which tends to be small for common resolutions and aspect ratios.
After resizing, the image is split into a sequence of patches, and a mask with padding information is added.

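As a rough illustration of this resizing rule, here is a minimal sketch (an approximation for intuition, not the exact implementation used by `Siglip2ImageProcessor`) that computes a target size for a given `patch_size` and `max_num_patches`:

```python
import math

def naflex_target_size(height, width, patch_size=16, max_num_patches=256):
    """Scale the image so that the resized height and width are multiples of
    `patch_size`, the patch count stays within `max_num_patches`, and the
    aspect-ratio distortion stays small."""
    # uniform scale that would yield roughly `max_num_patches` patches
    scale = math.sqrt(max_num_patches * patch_size**2 / (height * width))
    # round each side to the nearest multiple of the patch size
    target_h = max(patch_size, round(height * scale / patch_size) * patch_size)
    target_w = max(patch_size, round(width * scale / patch_size) * patch_size)
    # if rounding pushed the patch count over budget, shrink the side that
    # was stretched the most relative to its original size
    while (target_h // patch_size) * (target_w // patch_size) > max_num_patches:
        if target_h / height > target_w / width:
            target_h -= patch_size
        else:
            target_w -= patch_size
    return target_h, target_w

print(naflex_target_size(480, 640))                         # (224, 288) -> 14 x 18 = 252 patches
print(naflex_target_size(480, 640, max_num_patches=1024))   # roughly double each side
```
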
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, AutoModel
>>> import torch

>>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-naflex")
>>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-naflex")

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> candidate_labels = ["2 cats", "2 dogs"]
# follows the pipeline prompt template to get the same results
>>> texts = [f"This is a photo of {label}." for label in candidate_labels]

# the default value for `max_num_patches` is 256, but you can increase the resulting image resolution
# by providing higher values, e.g. `max_num_patches=512`
>>> inputs = processor(text=texts, images=image, max_num_patches=256, return_tensors="pt")

>>> with torch.no_grad():
...     outputs = model(**inputs)

>>> logits_per_image = outputs.logits_per_image
>>> probs = torch.sigmoid(logits_per_image)  # these are the probabilities
>>> print(f"{probs[0][0]:.1%} that image 0 is '{candidate_labels[0]}'")
21.1% that image 0 is '2 cats'
```

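To see how `max_num_patches` affects preprocessing, you can inspect the processor output directly. This is a quick sketch reusing `processor` and `image` from above; it assumes the NaFlex processor exposes the patch grid as `spatial_shapes` and the flattened patches as `pixel_values`:

```python
low_res = processor(images=image, max_num_patches=256, return_tensors="pt")
high_res = processor(images=image, max_num_patches=1024, return_tensors="pt")

# with a 4x patch budget, each side of the patch grid roughly doubles
print(low_res["spatial_shapes"], high_res["spatial_shapes"])

# the flattened patch sequences are padded up to the requested `max_num_patches`
print(low_res["pixel_values"].shape, high_res["pixel_values"].shape)
```
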
## Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with SigLIP2.

- [Zero-shot image classification task guide](../tasks/zero_shot_image_classification)
- A demo notebook for SigLIP2 can be found [here](https://github.com/qubvel/transformers-notebooks/tree/master/notebooks/SigLIP2_inference.ipynb). 🌎

If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.

## Combining SigLIP2 and Flash Attention 2

First, make sure to install the latest version of Flash Attention 2.

```bash
pip install -U flash-attn --no-build-isolation
```

Also make sure that you have hardware that is compatible with FlashAttention-2. Read more about it in the official documentation of the flash-attn repository. Also make sure to load your model in half-precision (e.g. `torch.float16`).

To load and run a model using Flash Attention 2, refer to the snippet below:

```python
>>> import torch
>>> import requests
>>> from PIL import Image
>>> from transformers import AutoProcessor, AutoModel
>>> device = "cuda"  # the device to load the model onto

>>> model = AutoModel.from_pretrained(
...     "google/siglip2-so400m-patch14-384",
...     attn_implementation="flash_attention_2",
...     torch_dtype=torch.float16,
...     device_map=device,
... )
>>> processor = AutoProcessor.from_pretrained("google/siglip2-so400m-patch14-384")

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> candidate_labels = ["2 cats", "2 dogs"]
# follows the pipeline prompt template to get the same results
>>> texts = [f'This is a photo of {label}.' for label in candidate_labels]
# important: we pass `padding="max_length"` since the model was trained with this
>>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt").to(device)

>>> with torch.no_grad():
...     with torch.autocast(device):
...         outputs = model(**inputs)

>>> logits_per_image = outputs.logits_per_image
>>> probs = torch.sigmoid(logits_per_image)  # these are the probabilities
>>> print(f"{probs[0][0]:.1%} that image 0 is '{candidate_labels[0]}'")
19.8% that image 0 is '2 cats'
```

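This commit also adds SigLIP2 to the SDPA-supported architectures in `perf_infer_gpu_one.md` (see the diff below), so if flash-attn is not installed, PyTorch's scaled dot-product attention is a reasonable fallback. A sketch with the same checkpoint and dtype as above:

```python
# assumes the same imports and `device` as in the Flash Attention 2 example
model = AutoModel.from_pretrained(
    "google/siglip2-so400m-patch14-384",
    attn_implementation="sdpa",  # use PyTorch SDPA instead of flash-attn
    torch_dtype=torch.float16,
    device_map=device,
)
```
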
## Siglip2Config

[[autodoc]] Siglip2Config

## Siglip2TextConfig

[[autodoc]] Siglip2TextConfig

## Siglip2VisionConfig

[[autodoc]] Siglip2VisionConfig

## Siglip2ImageProcessor

[[autodoc]] Siglip2ImageProcessor
    - preprocess

## Siglip2ImageProcessorFast

[[autodoc]] Siglip2ImageProcessorFast
    - preprocess

## Siglip2Processor

[[autodoc]] Siglip2Processor

## Siglip2Model

[[autodoc]] Siglip2Model
    - forward
    - get_text_features
    - get_image_features

## Siglip2TextModel

[[autodoc]] Siglip2TextModel
    - forward

## Siglip2VisionModel

[[autodoc]] Siglip2VisionModel
    - forward

## Siglip2ForImageClassification

[[autodoc]] Siglip2ForImageClassification
    - forward

docs/source/en/perf_infer_gpu_one.md

+2 lines changed

@@ -111,6 +111,7 @@ FlashAttention-2 is currently supported for the following architectures:
 * [data2vec_audio](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel)
 * [Sew](https://huggingface.co/docs/transformers/main/en/model_doc/sew#transformers.SEWModel)
 * [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip)
+* [SigLIP2](https://huggingface.co/docs/transformers/model_doc/siglip2)
 * [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel)
 * [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel)
 * [helium](https://huggingface.co/docs/transformers/main/en/model_doc/helium#transformers.HeliumModel)

@@ -310,6 +311,7 @@ For now, Transformers supports SDPA inference and training for the following architectures:
 * [RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaModel)
 * [Sew](https://huggingface.co/docs/transformers/main/en/model_doc/sew#transformers.SEWModel)
 * [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip)
+* [SigLIP2](https://huggingface.co/docs/transformers/model_doc/siglip2)
 * [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel)
 * [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model)
 * [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel)

src/transformers/__init__.py

+32 lines changed

@@ -776,6 +776,12 @@
         "SiglipTextConfig",
         "SiglipVisionConfig",
     ],
+    "models.siglip2": [
+        "Siglip2Config",
+        "Siglip2Processor",
+        "Siglip2TextConfig",
+        "Siglip2VisionConfig",
+    ],
     "models.smolvlm": ["SmolVLMConfig"],
     "models.speech_encoder_decoder": ["SpeechEncoderDecoderConfig"],
     "models.speech_to_text": [

@@ -1289,6 +1295,7 @@
     _import_structure["models.segformer"].extend(["SegformerFeatureExtractor", "SegformerImageProcessor"])
     _import_structure["models.seggpt"].extend(["SegGptImageProcessor"])
     _import_structure["models.siglip"].append("SiglipImageProcessor")
+    _import_structure["models.siglip2"].append("Siglip2ImageProcessor")
     _import_structure["models.smolvlm"].extend(["SmolVLMImageProcessor"])
     _import_structure["models.superglue"].extend(["SuperGlueImageProcessor"])
     _import_structure["models.superpoint"].extend(["SuperPointImageProcessor"])

@@ -1330,6 +1337,7 @@
     _import_structure["models.qwen2_vl"].append("Qwen2VLImageProcessorFast")
     _import_structure["models.rt_detr"].append("RTDetrImageProcessorFast")
     _import_structure["models.siglip"].append("SiglipImageProcessorFast")
+    _import_structure["models.siglip2"].append("Siglip2ImageProcessorFast")
     _import_structure["models.vit"].append("ViTImageProcessorFast")

 try:

@@ -3559,6 +3567,15 @@
             "SiglipVisionModel",
         ]
     )
+    _import_structure["models.siglip2"].extend(
+        [
+            "Siglip2ForImageClassification",
+            "Siglip2Model",
+            "Siglip2PreTrainedModel",
+            "Siglip2TextModel",
+            "Siglip2VisionModel",
+        ]
+    )
     _import_structure["models.smolvlm"].extend(
         [
             "SmolVLMForConditionalGeneration",

@@ -5942,6 +5959,12 @@
         SiglipTextConfig,
         SiglipVisionConfig,
     )
+    from .models.siglip2 import (
+        Siglip2Config,
+        Siglip2Processor,
+        Siglip2TextConfig,
+        Siglip2VisionConfig,
+    )
     from .models.smolvlm import SmolVLMConfig
     from .models.speech_encoder_decoder import SpeechEncoderDecoderConfig
     from .models.speech_to_text import (

@@ -6472,6 +6495,7 @@
     from .models.segformer import SegformerFeatureExtractor, SegformerImageProcessor
     from .models.seggpt import SegGptImageProcessor
     from .models.siglip import SiglipImageProcessor
+    from .models.siglip2 import Siglip2ImageProcessor
     from .models.smolvlm import SmolVLMImageProcessor
     from .models.superglue import SuperGlueImageProcessor
     from .models.superpoint import SuperPointImageProcessor

@@ -6509,6 +6533,7 @@
     from .models.qwen2_vl import Qwen2VLImageProcessorFast
     from .models.rt_detr import RTDetrImageProcessorFast
     from .models.siglip import SiglipImageProcessorFast
+    from .models.siglip2 import Siglip2ImageProcessorFast
     from .models.vit import ViTImageProcessorFast

 try:

@@ -8288,6 +8313,13 @@
         SiglipTextModel,
         SiglipVisionModel,
     )
+    from .models.siglip2 import (
+        Siglip2ForImageClassification,
+        Siglip2Model,
+        Siglip2PreTrainedModel,
+        Siglip2TextModel,
+        Siglip2VisionModel,
+    )
     from .models.smolvlm import (
         SmolVLMForConditionalGeneration,
         SmolVLMModel,
