Merged

Commits (68)
551e46c
feat: add colqwen2 (wip)
tonywu71 Jan 19, 2025
154d843
tests: fix test_attention_outputs
tonywu71 Apr 15, 2025
055eb5e
tests: reduce hidden size to accelerate tests
tonywu71 Apr 15, 2025
c0c7248
tests: fix `test_attention_outputs` 🥳
tonywu71 Apr 15, 2025
0a1e9f0
fix: fix wrong parent class for `ColQwen2ForRetrievalOutput`
tonywu71 Apr 15, 2025
99a5961
fix: minor typing and style changes
tonywu71 Apr 15, 2025
c731365
chore: run `make style`
tonywu71 Apr 15, 2025
5985784
feat: remove redundant `max_num_visual_tokens` attribute in `ColQwen2…
tonywu71 Apr 15, 2025
c6567d4
tests: tweak comments
tonywu71 Apr 15, 2025
0cb74d9
style: apply ruff formatter
tonywu71 Apr 15, 2025
6109920
feat: move default values for `visual_prompt_prefix` and `query_prefix`
tonywu71 Apr 15, 2025
b090847
docs: update ColQwen2 model card
tonywu71 Apr 16, 2025
607cd78
docs: tweak model cards
tonywu71 Apr 16, 2025
6c261cf
docs: add required example config checkpoint
tonywu71 Apr 16, 2025
b027a9d
tests: update expected scores in integration test
tonywu71 Apr 16, 2025
0302b12
docs: tweak quickstart snippets
tonywu71 Apr 16, 2025
5eaa32b
fix: address PR comments
tonywu71 Apr 16, 2025
ebb89b5
tests: fix colqwen2 tests + tweak comment in colpali test
tonywu71 Apr 16, 2025
bdbaa2b
tests: unskip useful tests
tonywu71 Apr 16, 2025
6fbac2a
fix: fix bug when `visual_prompt_prefix` or `query_prefix` is an empt…
tonywu71 Apr 16, 2025
6931500
fix: fix ColPali outputs when `return_dict == False`
tonywu71 Apr 16, 2025
985575c
fix: fix issue with PaliGemma output not being a dict
tonywu71 Apr 16, 2025
68ba7b8
docs: set default dtype to bfloat16 in quickstart snippets
tonywu71 Apr 16, 2025
bae3119
fix: fix error when `return_dict=False` in ColPali and ColQwen2
tonywu71 Apr 17, 2025
7dcc1e0
tests: fix special tokens not being replaced in input_ids
tonywu71 Apr 17, 2025
17882c2
style: fix lint
tonywu71 Apr 17, 2025
da93dcf
fix: `ColQwen2Processor`'s `padding_side` is now set from `processor_…
tonywu71 Apr 17, 2025
2b1ef88
fix: remove unused `padding_side` in ColQwen2 model
tonywu71 Apr 17, 2025
60d4033
docs: update ColQwen2's model doc
tonywu71 Apr 17, 2025
bb27ef9
fix: fix harcoded vlm backbone class in ColQwen2Config
tonywu71 Apr 17, 2025
a31b2f3
fix: remove `padding_side` from ColQwen2Processor as should fed from …
tonywu71 Apr 17, 2025
45fba97
docs: fix typo in model docstring
tonywu71 Apr 17, 2025
78d051d
docs: add illuin mention in model docs
tonywu71 Apr 17, 2025
ee9800b
fix: let `padding_size` be handled by `tokenizer_config.json`
tonywu71 Apr 17, 2025
4f76803
docs: add colpali reference url in colqwen2's model doc
tonywu71 Apr 17, 2025
cb924e3
docs: add Hf mention in model docs
tonywu71 Apr 17, 2025
f8f8261
docs: add late interaction mention in model docs
tonywu71 Apr 17, 2025
824b331
docs: tweak colqwen2 model doc
tonywu71 Apr 17, 2025
6dc1d22
docs: update reference checkpoint for ColPali to v1.3
tonywu71 Apr 18, 2025
d325c01
docs: simplify quickstart snippets
tonywu71 Apr 18, 2025
61a578c
docs: remove redundant `.eval()`
tonywu71 Apr 23, 2025
ff59eb2
refactor: use `can_return_tuple` decorator for ColPali and ColQwen2
tonywu71 Apr 23, 2025
f48568b
docs: fix copyright date
tonywu71 Apr 23, 2025
45d1dbe
docs: add missing copyright in tests
tonywu71 Apr 23, 2025
7b0f900
fix: raise error when `initializer_range` is not in config
tonywu71 Apr 23, 2025
f171ed6
docs: remove redundant `.eval()` in colpali doc
tonywu71 Apr 29, 2025
eaa797b
fix: fix `get_text_config` now that Qwen2VL has a proper `text_config…
tonywu71 Apr 29, 2025
c8e360f
fix: add missing `initializer_range` attribute in `ColQwen2Config`
tonywu71 Apr 29, 2025
14d7b5c
fix: use `get_text_config` in `resize_token_embeddings`
tonywu71 Apr 29, 2025
0686b2a
Merge remote-tracking branch 'upstream/main' into add-colqwen2
yonigozlan May 12, 2025
10b3ddb
update colwen2 with auto_docstring
yonigozlan May 12, 2025
bdef63f
docs: fix wrong copyright year
tonywu71 May 13, 2025
4b7f635
chore: remove `raise` as `initializer_range` has a default value in `…
tonywu71 May 13, 2025
c638c07
refactor: merge `inner_forward` into `forward`
tonywu71 May 13, 2025
30d2080
Merge remote-tracking branch 'upstream/main' into add-colqwen2
yonigozlan May 23, 2025
8277c43
Refactor colqwen2 after refactoring of qwen2VL, use modular for model…
yonigozlan May 23, 2025
86e0693
protect torch import in modular to protect in processing
yonigozlan May 23, 2025
c0a6442
protect torch import in modular to protect in processing
yonigozlan May 23, 2025
98a5338
Merge branch 'add-colqwen2' of https://github.com/tonywu71/transforme…
yonigozlan May 23, 2025
4aa5aa0
tests: fix hf model path in ColQwen2 integration test
tonywu71 May 24, 2025
34ca1e7
docs: clarify `attn_implementation` and add comments
tonywu71 May 29, 2025
43af0ad
docs: add fallback snippet for using offline PIL dummy images
tonywu71 May 29, 2025
0356f3c
docs: temporarily revert attn_implementation to `None` while sdpa is …
tonywu71 May 29, 2025
7a4218b
docs: tweaks in colpali/colqwen2 quick start snippets
tonywu71 May 29, 2025
58c7ff2
fix: add missing flags to enable SDPA/Flex Attention in ColQwen2 model
tonywu71 May 30, 2025
3852c86
fix: add missing changes in modular file
tonywu71 May 30, 2025
bd65ad3
Merge remote-tracking branch 'upstream/main' into add-colqwen2
yonigozlan Jun 2, 2025
1bc3dea
fix modeling tests
yonigozlan Jun 2, 2025
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -937,6 +937,8 @@
title: CLVP
- local: model_doc/colpali
title: ColPali
- local: model_doc/colqwen2
title: ColQwen2
- local: model_doc/data2vec
title: Data2Vec
- local: model_doc/deplot
46 changes: 27 additions & 19 deletions docs/source/en/model_doc/colpali.md
@@ -20,28 +20,33 @@ rendered properly in your Markdown viewer.

# ColPali

[ColPali](https://huggingface.co/papers/2407.01449) is a model designed to retrieve documents by analyzing their visual features. Unlike traditional systems that rely heavily on text extraction and OCR, ColPali treats each page as an image. It uses [Paligemma-3B](./paligemma) to capture not only text, but also the layout, tables, charts, and other visual elements to create detailed embeddings. This offers a more comprehensive understanding of documents and enables more efficient and accurate retrieval.
[ColPali](https://huggingface.co/papers/2407.01449) is a model designed to retrieve documents by analyzing their visual features. Unlike traditional systems that rely heavily on text extraction and OCR, ColPali treats each page as an image. It uses [Paligemma-3B](./paligemma) to capture not only text, but also the layout, tables, charts, and other visual elements to create detailed multi-vector embeddings that can be used for retrieval by computing pairwise late interaction similarity scores. This offers a more comprehensive understanding of documents and enables more efficient and accurate retrieval.

You can find all the original ColPali checkpoints under the [ColPali](https://huggingface.co/collections/vidore/hf-native-colvision-models-6755d68fc60a8553acaa96f7) collection.
This model was contributed by [@tonywu71](https://huggingface.co/tonywu71) (ILLUIN Technology) and [@yonigozlan](https://huggingface.co/yonigozlan) (HuggingFace).

You can find all the original ColPali checkpoints under Vidore's [Hf-native ColVision Models](https://huggingface.co/collections/vidore/hf-native-colvision-models-6755d68fc60a8553acaa96f7) collection.

> [!TIP]
> Click on the ColPali models in the right sidebar for more examples of how to use ColPali for image retrieval.

<hfoptions id="usage">
<hfoption id="image retrieval">

```py
```python
import requests
import torch
from PIL import Image

from transformers import ColPaliForRetrieval, ColPaliProcessor

# Load model (bfloat16 support is limited; fallback to float32 if needed)

model_name = "vidore/colpali-v1.3-hf"

model = ColPaliForRetrieval.from_pretrained(
"vidore/colpali-v1.2-hf",
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
model_name,
torch_dtype=torch.bfloat16,
device_map="auto", # "cpu", "cuda", or "mps" for Apple Silicon
).eval()
)

processor = ColPaliProcessor.from_pretrained(model_name)

@@ -54,37 +54,42 @@ images = [
]

queries = [
"Who printed the edition of Romeo and Juliet?",
"When was the United States Declaration of Independence proclaimed?",
"Who printed the edition of Romeo and Juliet?",
]

# Process the inputs
inputs_images = processor(images=images, return_tensors="pt").to(model.device)
inputs_text = processor(text=queries, return_tensors="pt").to(model.device)
inputs_images = processor(images=images).to(model.device)
inputs_text = processor(text=queries).to(model.device)

# Forward pass
with torch.no_grad():
image_embeddings = model(**inputs_images).embeddings
query_embeddings = model(**inputs_text).embeddings

# Score the queries against the images
scores = processor.score_retrieval(query_embeddings, image_embeddings)

print("Retrieval scores (query x image):")
print(scores)
```

</hfoption>
</hfoptions>

Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.

The example below uses [bitsandbytes](../quantization/bitsandbytes.md) to quantize the weights to int4.

```py
```python
import requests
import torch
from PIL import Image
from transformers import ColPaliForRetrieval, ColPaliProcessor
from transformers import BitsAndBytesConfig

from transformers import BitsAndBytesConfig, ColPaliForRetrieval, ColPaliProcessor


model_name = "vidore/colpali-v1.3-hf"

# 4-bit quantization configuration
bnb_config = BitsAndBytesConfig(
@@ -94,14 +104,11 @@
bnb_4bit_compute_dtype=torch.float16,
)

model_name = "vidore/colpali-v1.2-hf"

# Load model
model = ColPaliForRetrieval.from_pretrained(
model_name,
quantization_config=bnb_config,
device_map="cuda"
).eval()
device_map="cuda",
)

processor = ColPaliProcessor.from_pretrained(model_name)

@@ -114,8 +121,8 @@ images = [
]

queries = [
"Who printed the edition of Romeo and Juliet?",
"When was the United States Declaration of Independence proclaimed?",
"Who printed the edition of Romeo and Juliet?",
]

# Process the inputs
@@ -127,6 +134,7 @@ with torch.no_grad():
image_embeddings = model(**inputs_images).embeddings
query_embeddings = model(**inputs_text).embeddings

# Score the queries against the images
scores = processor.score_retrieval(query_embeddings, image_embeddings)

print("Retrieval scores (query x image):")
166 changes: 166 additions & 0 deletions docs/source/en/model_doc/colqwen2.md
@@ -0,0 +1,166 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

<div style="float: right;">
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>
</div>

# ColQwen2

[ColQwen2](https://doi.org/10.48550/arXiv.2407.01449) is a variant of the [ColPali](./colpali) model designed to retrieve documents by analyzing their visual features. Unlike traditional systems that rely heavily on text extraction and OCR, ColQwen2 treats each page as an image. It uses the [Qwen2-VL](./qwen2_vl) backbone to capture not only text, but also the layout, tables, charts, and other visual elements to create detailed multi-vector embeddings that can be used for retrieval by computing pairwise late interaction similarity scores. This offers a more comprehensive understanding of documents and enables more efficient and accurate retrieval.

This model was contributed by [@tonywu71](https://huggingface.co/tonywu71) (ILLUIN Technology) and [@yonigozlan](https://huggingface.co/yonigozlan) (HuggingFace).

You can find all the original ColQwen2 checkpoints under Vidore's [Hf-native ColVision Models](https://huggingface.co/collections/vidore/hf-native-colvision-models-6755d68fc60a8553acaa96f7) collection.

> [!TIP]
> Click on the ColQwen2 models in the right sidebar for more examples of how to use ColQwen2 for image retrieval.

<hfoptions id="usage">
<hfoption id="image retrieval">

```python
import requests
import torch
from PIL import Image

from transformers import ColQwen2ForRetrieval, ColQwen2Processor
from transformers.utils.import_utils import is_flash_attn_2_available


model_name = "vidore/colqwen2-v1.0-hf"

# Load model
model = ColQwen2ForRetrieval.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map="auto", # "cpu", "cuda", or "mps" for Apple Silicon
attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
)
Member:

This is a super nit, feel free to disregard if you're too annoyed by the review process 😆 But passing None is a bit misleading for an example IMO, even if it's equivalent

Suggested change
attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
)
attn_implementation="flash_attention_2" if is_flash_attn_2_available() else "sdpa",
)

Contributor Author:

Agreed! It's been addressed 👌🏼

Contributor Author (@tonywu71, May 29, 2025):

Actually it seems sdpa doesn't work out-of-the-box for ColQwen2 as I get this error when loading the model on MPS.

❌ Code:

model_name = "vidore/colqwen2-v1.0-hf"

# Load model
model = ColQwen2ForRetrieval.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # "cpu", "cuda", or "mps" for Apple Silicon
    attn_implementation="flash_attention_2" if is_flash_attn_2_available() else "sdpa",
)

Note: Leaving attn_implementation=None works.

The error:

ValueError: ColQwen2ForRetrieval does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet. Please request the support for this architecture: https://github.com/huggingface/transformers/issues/28005. If you believe this error is a bug, please open an issue in Transformers GitHub repository and load your model with the argument `attn_implementation="eager"` meanwhile. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`

✅ However, I managed to load Qwen2VL with SDPA:

model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto",  # "cpu", "cuda", or "mps" for Apple Silicon
    attn_implementation="flash_attention_2" if is_flash_attn_2_available() else "sdpa",
)

@Cyrilvallez @yonigozlan I read about the instructions for enabling SDPA on ColQwen2 but next steps are a bit unclear as ColQwen2 essentially piggybacks on Qwen2VL thanks to modular. Any ideas about the right fix? 🤗

Member:

I believe it's only because the flags are not set in the PreTrainedModel - adding

_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True

should solve it
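
For context, a minimal sketch of where such attention-support flags are typically declared in Transformers (the class name and exact placement here are assumptions for illustration, not the PR's actual diff):

```python
from transformers import PreTrainedModel

# Hypothetical sketch: the flags are class attributes on the model's
# PreTrainedModel subclass, so that from_pretrained accepts the
# corresponding attn_implementation values.
class ColQwen2PreTrainedModel(PreTrainedModel):
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_cache_class = True
```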

Contributor Author (@tonywu71, May 30, 2025):

Tsm, the fix is working like a charm! And as you expected, ColQwen2 works with attn_implementation="flex_attention" too 👌🏼


processor = ColQwen2Processor.from_pretrained(model_name)

url1 = "https://upload.wikimedia.org/wikipedia/commons/8/89/US-original-Declaration-1776.jpg"
url2 = "https://upload.wikimedia.org/wikipedia/commons/thumb/4/4c/Romeoandjuliet1597.jpg/500px-Romeoandjuliet1597.jpg"

images = [
Image.open(requests.get(url1, stream=True).raw),
Image.open(requests.get(url2, stream=True).raw),
]

queries = [
"When was the United States Declaration of Independence proclaimed?",
"Who printed the edition of Romeo and Juliet?",
]

# Process the inputs
inputs_images = processor(images=images).to(model.device)
inputs_text = processor(text=queries).to(model.device)

# Forward pass
with torch.no_grad():
image_embeddings = model(**inputs_images).embeddings
query_embeddings = model(**inputs_text).embeddings

# Score the queries against the images
scores = processor.score_retrieval(query_embeddings, image_embeddings)

print("Retrieval scores (query x image):")
print(scores)
```

</hfoption>
</hfoptions>

Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.

The example below uses [bitsandbytes](../quantization/bitsandbytes.md) to quantize the weights to int4.

```python
import requests
import torch
from PIL import Image

from transformers import BitsAndBytesConfig, ColQwen2ForRetrieval, ColQwen2Processor


model_name = "vidore/colqwen2-v1.0-hf"

# 4-bit quantization configuration
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
)

model = ColQwen2ForRetrieval.from_pretrained(
model_name,
quantization_config=bnb_config,
device_map="cuda",
).eval()

processor = ColQwen2Processor.from_pretrained(model_name)

url1 = "https://upload.wikimedia.org/wikipedia/commons/8/89/US-original-Declaration-1776.jpg"
url2 = "https://upload.wikimedia.org/wikipedia/commons/thumb/4/4c/Romeoandjuliet1597.jpg/500px-Romeoandjuliet1597.jpg"

images = [
Image.open(requests.get(url1, stream=True).raw),
Image.open(requests.get(url2, stream=True).raw),
]

queries = [
"When was the United States Declaration of Independence proclaimed?",
"Who printed the edition of Romeo and Juliet?",
]

# Process the inputs
inputs_images = processor(images=images, return_tensors="pt").to(model.device)
inputs_text = processor(text=queries, return_tensors="pt").to(model.device)

# Forward pass
with torch.no_grad():
image_embeddings = model(**inputs_images).embeddings
query_embeddings = model(**inputs_text).embeddings

# Score the queries against the images
scores = processor.score_retrieval(query_embeddings, image_embeddings)

print("Retrieval scores (query x image):")
print(scores)
```

## Notes

- [`~ColQwen2Processor.score_retrieval`] returns a 2D tensor where the first dimension is the number of queries and the second dimension is the number of images. A higher score indicates more similarity between the query and image; see the scoring sketch after this list.
- Unlike ColPali, ColQwen2 supports arbitrary image resolutions and aspect ratios, which means images are not resized into fixed-size squares. This preserves more of the original input signal.
- Larger input images generate longer multi-vector embeddings, allowing users to adjust image resolution to balance performance and memory usage.
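
For intuition, here is a minimal, self-contained sketch of the late-interaction (MaxSim) scoring that `score_retrieval` performs; the tensor shapes below are made up purely for illustration.

```python
import torch

# Dummy multi-vector embeddings:
#   queries: (num_queries, query_tokens, dim)
#   images:  (num_images, image_tokens, dim)
query_embeddings = torch.randn(2, 20, 128)
image_embeddings = torch.randn(3, 1030, 128)

# MaxSim: dot products between every query token and every image token,
# take the max over image tokens, then sum over query tokens.
sim = torch.einsum("qnd,imd->qinm", query_embeddings, image_embeddings)
scores = sim.max(dim=-1).values.sum(dim=-1)  # (num_queries, num_images)
print(scores.shape)  # torch.Size([2, 3])
```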

## ColQwen2Config

[[autodoc]] ColQwen2Config

## ColQwen2Processor

[[autodoc]] ColQwen2Processor

## ColQwen2ForRetrieval

[[autodoc]] ColQwen2ForRetrieval
- forward
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
@@ -62,6 +62,7 @@
from .cohere import *
from .cohere2 import *
from .colpali import *
from .colqwen2 import *
from .conditional_detr import *
from .convbert import *
from .convnext import *
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -75,6 +75,7 @@
("cohere", "CohereConfig"),
("cohere2", "Cohere2Config"),
("colpali", "ColPaliConfig"),
("colqwen2", "ColQwen2Config"),
("conditional_detr", "ConditionalDetrConfig"),
("convbert", "ConvBertConfig"),
("convnext", "ConvNextConfig"),
@@ -433,6 +434,7 @@
("cohere", "Cohere"),
("cohere2", "Cohere2"),
("colpali", "ColPali"),
("colqwen2", "ColQwen2"),
("conditional_detr", "Conditional DETR"),
("convbert", "ConvBERT"),
("convnext", "ConvNeXT"),
1 change: 1 addition & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -365,6 +365,7 @@
("bloom", "BloomForCausalLM"),
("camembert", "CamembertForMaskedLM"),
("colpali", "ColPaliForRetrieval"),
("colqwen2", "ColQwen2ForRetrieval"),
("ctrl", "CTRLLMHeadModel"),
("data2vec-text", "Data2VecTextForMaskedLM"),
("deberta", "DebertaForMaskedLM"),
1 change: 1 addition & 0 deletions src/transformers/models/auto/processing_auto.py
@@ -67,6 +67,7 @@
("clipseg", "CLIPSegProcessor"),
("clvp", "ClvpProcessor"),
("colpali", "ColPaliProcessor"),
("colqwen2", "ColQwen2Processor"),
("emu3", "Emu3Processor"),
("flava", "FlavaProcessor"),
("fuyu", "FuyuProcessor"),
1 change: 1 addition & 0 deletions src/transformers/models/auto/tokenization_auto.py
@@ -151,6 +151,7 @@
("cohere", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
("cohere2", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
("colpali", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("colqwen2", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)),
(
"cpm",
4 changes: 1 addition & 3 deletions src/transformers/models/colpali/configuration_colpali.py
@@ -33,8 +33,6 @@ class ColPaliConfig(PretrainedConfig):
Creating a configuration with the default settings will result in a configuration where the VLM backbone is set to the
default PaliGemma configuration, i.e the one from [vidore/colpali-v1.2](https://huggingface.co/vidore/colpali-v1.2).

The ColPali config is very similar to [`PaligemmaConfig`], but with an extra attribute defining the embedding dimension.

Note that contrarily to what the class name suggests (actually the name refers to the ColPali **methodology**), you can
use a different VLM backbone model than PaliGemma by passing the corresponding VLM configuration to the class constructor.

@@ -93,7 +91,7 @@ def __init__(
)

self.vlm_config = vlm_config
self.text_config = text_config = text_config if text_config is not None else vlm_config.text_config
self.text_config = text_config if text_config is not None else vlm_config.text_config
if isinstance(self.text_config, dict):
text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "gemma"
self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)