
Conversation

@YushunXiang
Contributor

@YushunXiang YushunXiang commented Jun 10, 2025

What does this PR do?

In the v4.52.1 release of the transformers library, PR #37033 by @zucchini-nlp introduced a breaking change: class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin) was refactored into class PaliGemmaModel(PaliGemmaPreTrainedModel), which breaks the get_image_features call at line 218 of huggingface/lerobot/common/policies/pi0/paligemma_with_expert.py.

This pull request adds a new get_image_features method across multiple generative model implementations in the src/transformers/models directory. The method provides a standardized interface for extracting image features from models, with variations in parameters depending on the specific model's requirements.

I modified 6 files, adding a get_image_features method to the corresponding <model name>ForConditionalGeneration class:

  • src/transformers/models/idefics2/modeling_idefics2.py
  • src/transformers/models/llava/modeling_llava.py
  • src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
  • src/transformers/models/chameleon/modeling_chameleon.py
  • src/transformers/models/paligemma/modeling_paligemma.py
  • src/transformers/models/video_llava/modeling_video_llava.py

and used make fix-copies to generate the other 13 modeling files.
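The delegation the PR adds can be sketched in isolation. The dummy classes below stand in for the real transformers classes and only illustrate the wiring, not the actual feature extraction:

```python
class PaliGemmaModel:  # stand-in for the refactored base model
    def get_image_features(self, pixel_values):
        # the real method runs the vision tower and multimodal projector;
        # here we just tag the call so the delegation is visible
        return ("features", pixel_values)


class PaliGemmaForConditionalGeneration:  # stand-in for the generative wrapper
    def __init__(self):
        self.model = PaliGemmaModel()

    def get_image_features(self, pixel_values):
        # restored public entry point: forward to the inner base model
        return self.model.get_image_features(pixel_values)
```

Downstream code such as lerobot can then keep calling get_image_features on the generative class exactly as before the refactor.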

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@amyeroberts, @qubvel, @zucchini-nlp

Member

@zucchini-nlp zucchini-nlp left a comment


Nice catch! Indeed it is a breaking change and we need to allow access through the generative model as well

Can you update other VLMs as well, since the changes were done on all models?

@YushunXiang
Contributor Author

Nice catch! Indeed it is a breaking change and we need to allow access through the generative model as well

Can you update other VLMs as well, since the changes were done on all models?

Yes, I would like to be a contributor to transformers.

Besides, I have a question:

if hasattr(self.paligemma, "get_image_features"):
    return self.paligemma.get_image_features(image)
else:
    return self.paligemma.model.get_image_features(image)

With this modification I can run the generated models regardless of the transformers version, but is it a good approach?

@zucchini-nlp
Member

zucchini-nlp commented Jun 10, 2025

With this modification I can run the generated models no matter what transformers version, but is it a good modification?

Yeah, it can be used as a hacky workaround in your code for transformers==v4.52. For the transformers codebase, your solution looks good and can be propagated to all models. We'll then include your fixes in the next release :)

Btw, after the fixes you need to run make fix-copies to make our CI happy

@YushunXiang
Contributor Author

YushunXiang commented Jun 10, 2025

I modified 6 files, adding a get_image_features method to the corresponding <model name>ForConditionalGeneration class:

  • src/transformers/models/idefics2/modeling_idefics2.py
  • src/transformers/models/llava/modeling_llava.py
  • src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
  • src/transformers/models/chameleon/modeling_chameleon.py
  • src/transformers/models/paligemma/modeling_paligemma.py
  • src/transformers/models/video_llava/modeling_video_llava.py

and used make fix-copies to generate the other 13 modeling files.

Member

@zucchini-nlp zucchini-nlp left a comment


Thanks a lot ❤️

Just left a few comments about missing parts; some models use different helpers for the video modality or for quantized vision tokens.

Comment on lines +1232 to +1234
def get_image_features(self, pixel_values):
    return self.model.get_image_features(pixel_values)

Member


for chameleon and emu3, the helper was get_image_tokens. Can we propagate that too?
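For those models the same delegation pattern would wrap the token helper instead of the feature helper. A minimal sketch with stand-in classes (the real Chameleon helper quantizes pixels into discrete VQ token ids; the values here are placeholders):

```python
class ChameleonModel:  # stand-in for the base model
    def get_image_tokens(self, pixel_values):
        # the real helper runs a VQ-VAE and returns discrete token ids;
        # placeholder ids are returned here for illustration
        return [101, 102, 103]


class ChameleonForConditionalGeneration:  # stand-in for the generative wrapper
    def __init__(self):
        self.model = ChameleonModel()

    def get_image_tokens(self, pixel_values):
        # same thin delegation, but for the token-based helper
        return self.model.get_image_tokens(pixel_values)
```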

    vision_feature_layer: Optional[Union[int, List[int]]] = None,
    vision_feature_select_strategy: Optional[str] = None,
):
    return self.model.get_image_features(pixel_values_images, vision_feature_layer, vision_feature_select_strategy)
Member


Qwen-VL has no vision_feature_layer and vision_feature_select_strategy, naming should be the same as in base model
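A sketch of what matching the base model's signature looks like for Qwen2-VL. The parameter name image_grid_thw is drawn from the Qwen2-VL interface (treat it as an assumption here), and the classes are dummies for illustration:

```python
class Qwen2VLModel:  # stand-in for the base model
    def get_image_features(self, pixel_values, image_grid_thw=None):
        # Qwen2-VL's vision path takes grid shapes, not a feature layer index
        return ("features", image_grid_thw)


class Qwen2VLForConditionalGeneration:  # stand-in for the generative wrapper
    def __init__(self):
        self.model = Qwen2VLModel()

    def get_image_features(self, pixel_values, image_grid_thw=None):
        # mirror the base model's own parameter names exactly
        return self.model.get_image_features(pixel_values, image_grid_thw)
```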

Contributor Author


Oh, this is my fault. Thanks!

Comment on lines 1390 to 1397
def get_image_features(
    self,
    pixel_values_images: torch.FloatTensor,
    vision_feature_layer: Optional[Union[int, List[int]]] = None,
    vision_feature_select_strategy: Optional[str] = None,
):
    return self.model.get_image_features(pixel_values_images, vision_feature_layer, vision_feature_select_strategy)

Member


same

def get_decoder(self):
    return self.model

def get_image_features(
Member


Can we also add get_video_features where it exists? I think it's only for llava-onevision, llava-next-video and video-llava.
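For those models the wrapper would gain a second delegating method alongside the image one. A minimal sketch with stand-in classes:

```python
class VideoLlavaModel:  # stand-in for the base model
    def get_image_features(self, pixel_values_images):
        return "image_feats"  # placeholder for extracted image features

    def get_video_features(self, pixel_values_videos):
        return "video_feats"  # placeholder for extracted video features


class VideoLlavaForConditionalGeneration:  # stand-in for the generative wrapper
    def __init__(self):
        self.model = VideoLlavaModel()

    def get_image_features(self, pixel_values_images):
        return self.model.get_image_features(pixel_values_images)

    def get_video_features(self, pixel_values_videos):
        # video counterpart; only models with a video pathway define it
        return self.model.get_video_features(pixel_values_videos)
```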

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

… across multiple models

modified:
- modeling_chameleon.py
- modeling_llava_next.py
- modular_llava_next_video.py
- modeling_qwen2_vl.py

and generate the:
- modeling_llava_next_video.py
- modeling_llava_onevision.py
- modeling_qwen2_5_vl.py
Member

@zucchini-nlp zucchini-nlp left a comment


Perfect! I think the last rebase went wrong; I see unrelated commits in the history. Can you fix that and we'll merge :)

@YushunXiang
Contributor Author

Perfect! I think the last rebase went wrong; I see unrelated commits in the history. Can you fix that and we'll merge :)

Sure, it is my fault. I ran

git fetch upstream
git rebase upstream/main

which caused the problem.

@YushunXiang
Copy link
Contributor Author

YushunXiang commented Jun 11, 2025

Commit 67461fb
implement get_image_features method in Aria, Mistral3, and VipLlava models with updated parameters.

Let me explain the reason for this commit.

For example, in the file src/transformers/models/aria/modular_aria.py, AriaForConditionalGeneration inherits from LlavaForConditionalGeneration. However, the get_image_features interface in AriaModel is inconsistent with the get_image_features interface in LlavaForConditionalGeneration.

class AriaModel(LlavaModel):
    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        pixel_mask: Optional[torch.FloatTensor] = None,
        vision_feature_layer: int = -1,
    ):
        ...

class LlavaModel(LlavaPreTrainedModel):
    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        vision_feature_layer: Optional[Union[int, List[int]]] = None,
        vision_feature_select_strategy: Optional[str] = None,
        **kwargs,
    ):
        ...

If I do not add the get_image_features method to the AriaForConditionalGeneration class in modular_aria.py, and then use make fix-copies to generate the modeling files, it generates something wrong like:

class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        vision_feature_layer: Optional[Union[int, List[int]]] = None,
        vision_feature_select_strategy: Optional[str] = None,
        **kwargs,
    ):
        return self.model.get_image_features(
            pixel_values=pixel_values,
            vision_feature_layer=vision_feature_layer,
            vision_feature_select_strategy=vision_feature_select_strategy,
            **kwargs,
        )
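The commit therefore adds an explicit override whose parameters match AriaModel, so fix-copies keeps it instead of copying LLaVA's incompatible version. A self-contained sketch of that fix, with dummy classes in place of the real transformers ones:

```python
class AriaModel:  # stand-in for the base model
    def get_image_features(self, pixel_values, pixel_mask=None, vision_feature_layer=-1):
        # Aria's own interface: a pixel mask and an int feature-layer index
        return ("features", pixel_mask, vision_feature_layer)


class AriaForConditionalGeneration:  # stand-in for the generative wrapper
    def __init__(self):
        self.model = AriaModel()

    # explicit override: the signature mirrors AriaModel exactly, so the
    # modular code generator does not substitute LLaVA's parameter list
    def get_image_features(self, pixel_values, pixel_mask=None, vision_feature_layer=-1):
        return self.model.get_image_features(pixel_values, pixel_mask, vision_feature_layer)
```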

@zucchini-nlp zucchini-nlp enabled auto-merge (squash) June 11, 2025 10:14
@zucchini-nlp zucchini-nlp merged commit 56a7cf5 into huggingface:main Jun 11, 2025
15 checks passed
@YushunXiang YushunXiang deleted the fix-paligemma branch June 11, 2025 11:35
lmarshall12 pushed a commit to lmarshall12/transformers that referenced this pull request Jun 12, 2025
…ation (huggingface#38730)

* fix: Add method to retrieve image features in PaliGemmaForConditionalGeneration

* feat: Add get_image_features method to multiple models for image feature extraction

* fix: reformat the files with ruff.

* feat: Add methods for packing and retrieving image and video features across multiple models

modified:
- modeling_chameleon.py
- modeling_llava_next.py
- modular_llava_next_video.py
- modeling_qwen2_vl.py

and generate the:
- modeling_llava_next_video.py
- modeling_llava_onevision.py
- modeling_qwen2_5_vl.py

* feat: Implement get_image_features method in Aria, Mistral3, and VipLlava models with updated parameters

* fix: reformatted the code with fix-style