diff --git a/docs/source/en/model_doc/colqwen2.md b/docs/source/en/model_doc/colqwen2.md
index 6810ca529a04..7c9a9627e2c7 100644
--- a/docs/source/en/model_doc/colqwen2.md
+++ b/docs/source/en/model_doc/colqwen2.md
@@ -158,6 +158,24 @@ print("Retrieval scores (query x image):")
 print(scores)
 ```
 
+You can also load `ColQwen2.5` checkpoints that are **compatible with the ColQwen2 architecture**. This version of the model uses [Qwen2_5_VL](./qwen2_5_vl) as the backbone.
+
+```python
+import torch
+from transformers import ColQwen2ForRetrieval, ColQwen2Processor
+from transformers.utils.import_utils import is_flash_attn_2_available
+
+model_name = "Sahil-Kabir/colqwen2.5-v0.2-hf"  # An existing compatible checkpoint
+
+model = ColQwen2ForRetrieval.from_pretrained(
+    model_name,
+    dtype=torch.bfloat16,
+    device_map="auto",
+    attn_implementation="flash_attention_2" if is_flash_attn_2_available() else "sdpa"
+)
+processor = ColQwen2Processor.from_pretrained(model_name)
+```
+
 ## Notes
 
 - [`~ColQwen2Processor.score_retrieval`] returns a 2D tensor where the first dimension is the number of queries and the second dimension is the number of images. A higher score indicates more similarity between the query and image.
diff --git a/src/transformers/models/colqwen2/convert_colqwen2_weights_to_hf.py b/src/transformers/models/colqwen2/convert_colqwen2_weights_to_hf.py
index ca990a6d42d4..e8fbc502466c 100644
--- a/src/transformers/models/colqwen2/convert_colqwen2_weights_to_hf.py
+++ b/src/transformers/models/colqwen2/convert_colqwen2_weights_to_hf.py
@@ -39,9 +39,10 @@
 
 import torch
 from huggingface_hub import snapshot_download
+from peft import PeftModel
 from safetensors import safe_open
 
-from transformers import AutoConfig
+from transformers import AutoConfig, AutoModel
 from transformers.models.colqwen2 import ColQwen2ForRetrieval
 from transformers.models.colqwen2.configuration_colqwen2 import ColQwen2Config
 from transformers.utils import logging
@@ -69,7 +70,7 @@ def load_original_state_dict(model_id: str, revision: Optional[str] = None) -> d
                     original_state_dict[key] = f.get_tensor(key)
 
     # Some weights are tied, so `lm.head`` is not saved. Let's clone to load state dict.
-    if "lm_head.weight" not in original_state_dict:
+    if "lm_head.weight" not in original_state_dict and "model.embed_tokens.weight" in original_state_dict:
        original_state_dict["lm_head.weight"] = original_state_dict["model.embed_tokens.weight"].clone()
 
     return original_state_dict
@@ -124,7 +125,21 @@ def convert_colqwen2_weights_to_hf(
     config.is_composition = False
 
     # Load the untrained model
-    model = ColQwen2ForRetrieval(config=config).to("cpu").eval()
+    vlm_name_or_path = getattr(config.vlm_config, "_name_or_path", None)
+    if vlm_name_or_path and "2.5" in str(vlm_name_or_path):
+        print(
+            "Detected colqwen2.5 adapters in vlm_config; loading base model %s and merging PEFT weights."
+            % vlm_name_or_path
+        )
+        base_model = AutoModel.from_pretrained(
+            vlm_name_or_path,
+            device_map="cpu",
+            trust_remote_code=True,
+        )
+        peft_model = PeftModel.from_pretrained(base_model, model_id)
+        model = peft_model.merge_and_unload()
+    else:
+        model = ColQwen2ForRetrieval(config=config).to("cpu").eval()
     print("Created model with new config and randomly initialized weights")
 
     # NOTE: The new model was initialized with float32 weights. We need to convert it to the desired precision.
@@ -201,6 +216,7 @@ def convert_colqwen2_weights_to_hf(
         help="Name or path of the original VLM backbone model",
         default=None,
     )
+
     args = parser.parse_args()
 
     convert_colqwen2_weights_to_hf(
diff --git a/src/transformers/models/colqwen2/modeling_colqwen2.py b/src/transformers/models/colqwen2/modeling_colqwen2.py
index 0c22fb99c887..c3a6c04ee4db 100644
--- a/src/transformers/models/colqwen2/modeling_colqwen2.py
+++ b/src/transformers/models/colqwen2/modeling_colqwen2.py
@@ -172,7 +172,6 @@ def forward(
             inputs_embeds = self.vlm.language_model.embed_tokens(input_ids)
 
             if pixel_values is not None:
-                pixel_values = pixel_values.type(self.vlm.visual.get_dtype())
                 image_embeds = self.vlm.visual(pixel_values, grid_thw=image_grid_thw)
                 image_mask = (
                     (input_ids == self.config.vlm_config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
diff --git a/src/transformers/models/colqwen2/modular_colqwen2.py b/src/transformers/models/colqwen2/modular_colqwen2.py
index adea1617e459..072591abbab8 100644
--- a/src/transformers/models/colqwen2/modular_colqwen2.py
+++ b/src/transformers/models/colqwen2/modular_colqwen2.py
@@ -359,7 +359,6 @@ def forward(
             inputs_embeds = self.vlm.language_model.embed_tokens(input_ids)
 
             if pixel_values is not None:
-                pixel_values = pixel_values.type(self.vlm.visual.get_dtype())
                 image_embeds = self.vlm.visual(pixel_values, grid_thw=image_grid_thw)
                 image_mask = (
                     (input_ids == self.config.vlm_config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
diff --git a/tests/models/colqwen2/test_modeling_colqwen2.py b/tests/models/colqwen2/test_modeling_colqwen2.py
index 790cf639c985..4d9d24703682 100644
--- a/tests/models/colqwen2/test_modeling_colqwen2.py
+++ b/tests/models/colqwen2/test_modeling_colqwen2.py
@@ -335,12 +335,61 @@ def test_model_integration_test(self):
                     [15.6562, 12.2656, 20.2969],
                 ],
                 ("cuda", 8): [
-                    [15.0703, 8.7422, 15.0312],
-                    [9.5078, 16.8906, 10.6250],
-                    [15.6484, 12.3984, 20.4688],
+                    [16.2812, 8.3672, 14.5703],
+                    [9.4922, 17.1875, 10.3281],
+                    [15.0312, 11.3984, 20.1719],
                 ],
             }
         )
         expected_scores = torch.tensor(expectations.get_expectation(), dtype=scores.dtype)
 
         assert torch.allclose(scores, expected_scores, atol=1e-3), f"Expected scores {expected_scores}, got {scores}"
+
+    @slow
+    def test_model_integration_test_2(self):
+        """
+        Test if the model is able to retrieve the correct pages for a small and easy dataset.
+        This test uses a ColQwen2.5 checkpoint that is compatible with the ColQwen2 architecture.
+ """ + model = ColQwen2ForRetrieval.from_pretrained( + "Sahil-Kabir/colqwen2.5-v0.2-hf", + device_map=torch_device, + dtype=torch.bfloat16, + ).eval() + processor = ColQwen2Processor.from_pretrained("Sahil-Kabir/colqwen2.5-v0.2-hf", trust_remote_code=True) + + # Load the test dataset + ds = load_dataset("hf-internal-testing/document-visual-retrieval-test", split="test") + + # Preprocess the examples + batch_images = processor(images=list(ds["image"])).to(torch_device) + batch_queries = processor(text=list(ds["query"])).to(torch_device) + + with torch.inference_mode(): + image_embeddings = model(**batch_images).embeddings + query_embeddings = model(**batch_queries).embeddings + + # Compute retrieval scores + scores = processor.score_retrieval( + query_embeddings=query_embeddings, + passage_embeddings=image_embeddings, + ) + + assert scores.ndim == 2, f"Expected 2D tensor, got {scores.ndim}" + assert scores.shape == (len(ds), len(ds)), f"Expected shape {(len(ds), len(ds))}, got {scores.shape}" + + # Check if the maximum scores per row are in the diagonal of the matrix score + self.assertTrue((scores.argmax(axis=1) == torch.arange(len(ds), device=scores.device)).all()) + # Further validation: fine-grained check, with a hardcoded score from the original Hf implementation. + expectations = Expectations( + { + ("cuda", 8): [ + [16.3750, 10.9375, 14.7500], + [11.3750, 16.8750, 12.0625], + [15.3125, 13.1250, 21.5000], + ] + } + ) + expected_scores = torch.tensor(expectations.get_expectation(), dtype=scores.dtype) + + assert torch.allclose(scores, expected_scores, atol=0.15), f"Expected scores {expected_scores}, got {scores}"