[BC] Update get_(text|image|audio|video)_features methods to return BaseModelOutputWithPooling#42564

Merged
ArthurZucker merged 143 commits into huggingface:main from tomaarsen:feat/normalize_get_features_methods
Jan 23, 2026

Conversation

@tomaarsen
Member

@tomaarsen tomaarsen commented Dec 2, 2025

What does this PR do?

  • Add return_dict to get_text_features, get_image_features, get_audio_features, get_video_features methods to return 'BaseModelOutputWithPooling' by default

Fixes #42401

The architectures supporting get_image_features are all extremely different, with wildly different outputs for the get_image_features methods:

  • 2d outputs,
  • 3d outputs,
  • lists of 2d outputs (due to non-matching shapes),
  • an existing 'return_attentions' resulting in a 2-tuple,
  • an existing 'return_dict' resulting in a 3-tuple (???),
  • high quality image embeddings,
  • low quality image embeddings,
  • deepstack image embeddings,
  • etc. etc. etc.

This PR aims to normalize these, so that users and third party libraries can perform inference on individual modalities (e.g. via get_image_features) despite the full model requiring multiple modalities.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@zucchini-nlp @ArthurZucker @Cyrilvallez

  • Tom Aarsen

…ModelOutputWithPooling'

Added to all architectures except blip-2, which has a very different structure here. It uses 'Blip2TextModelWithProjection' to get these embeddings/features, but that class isn't as simple to use
…eModelOutputWithPooling'

Well, the architectures supporting get_image_features are all extremely different, with wildly different outputs for the get_image_features methods. 2d outputs, 3d outputs, lists of 2d outputs (due to non-matching shapes), existing 'return_attentions' resulting in returning 2-tuple, existing 'return_dict' resulting in returning 3-tuples (???), high quality image embeddings, low quality image embeddings, deepstack image embeddings, etc. etc. etc.

And I only went through like 70-80% of all architectures with get_image_features before I gave up.

Standardisation of all of these sounds like a lost cause.
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Member

@zucchini-nlp zucchini-nlp left a comment


We discussed this internally and decided to add last_hidden_state to all models, as the last state from the vision block. The pooled embeddings will stay differently shaped, as is

For the last hidden state, the shapes are already more standardized, with a few major options. The only special cases might be qwen-like models, where each image encoding has a different sequence length and thus the outputs are concatenated as length*dim
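For those qwen-like concatenated outputs, the per-image encodings can be recovered whenever the per-image sequence lengths are known. A minimal sketch, using plain Python lists as stand-ins for tensors (with real tensors, `torch.split(features, seq_lengths, dim=0)` does the same):

```python
def split_concatenated_features(features, seq_lengths):
    """Split a flat list of per-patch feature vectors, as produced by
    qwen-like encoders that concatenate variable-length image encodings
    into one (total_length, dim) output, back into one chunk per image."""
    assert len(features) == sum(seq_lengths), "lengths must cover all rows"
    chunks, start = [], 0
    for length in seq_lengths:
        chunks.append(features[start:start + length])
        start += length
    return chunks

# Two images with 4 and 6 patches each (dim=1 for brevity)
features = [[float(i)] for i in range(10)]
per_image = split_concatenated_features(features, [4, 6])
print([len(chunk) for chunk in per_image])  # [4, 6]
```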

@tomaarsen
Member Author

The initial work on all 4 modalities is done, with a handful of exceptions. There are about 2 or 3 breaking architectures, specifically architectures that already supported return_dict and return_attentions. Typings, docstrings, and tests still have to be added, but I'm curious whether this has a chance of being merged before I continue with those.

  • Tom Aarsen

Member

@zucchini-nlp zucchini-nlp left a comment


Thanks a lot for the changes, I see there are a few tricky models that do not fit neatly with BaseModelOutput

To wrap it up: to make this work, we first need to ensure that all vision encoders are capable of returning a dict the way that PreTrainedModels do, i.e. by checking config.return_dict and returning attentions, hidden states, pooled output, etc. Then we can have get_image_features return the same dict that was output by the encoder (optionally with the pooled output updated in VLMs). That will preserve all fields of the vision encoder output

I think the current state of the PR is already doing it with a few non-standard models. I left comments under those models so lmk if that makes sense

…model_inputs

The changes in check_model_inputs aren't the clearest/prettiest, but they work well for now.
@tomaarsen
Member Author

I've pushed a proposal in 9a251ce that takes this in a bit of a different direction by adopting the modern TransformersKwargs and check_model_inputs. I updated the latter to allow setting the pooler_output as the default, unless the user explicitly uses return_dict=True (which returns a ModelOutput subclass) or return_dict=None (which uses the model config's return_dict to determine whether to output a ModelOutput or the pooled embeddings).

I can extend this to more architectures, but want to get your view on this first.
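The proposed three-way behavior can be sketched roughly like this — a simplified stand-in, not the actual check_model_inputs implementation; the sentinel for "argument not passed" is an assumption about how the dispatch could be modeled:

```python
_UNSET = object()  # sentinel: distinguishes "not passed" from an explicit None

def resolve_features_output(model_output, pooled, config_return_dict, return_dict=_UNSET):
    """Sketch of the proposed dispatch for get_..._features:
    - return_dict omitted: return the pooled embeddings (the default)
    - return_dict=True:    return the full ModelOutput
    - return_dict=None:    let config.return_dict decide between the two
    """
    if return_dict is _UNSET:
        return pooled
    if return_dict is None:
        return_dict = config_return_dict
    return model_output if return_dict else pooled

output = {"last_hidden_state": "...", "pooler_output": "POOLED"}
print(resolve_features_output(output, "POOLED", config_return_dict=True))  # POOLED
print(resolve_features_output(output, "POOLED", config_return_dict=False, return_dict=True))
```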

Usage:

from transformers import AutoModel, AutoProcessor
from transformers.image_utils import load_image
import torch

model = AutoModel.from_pretrained("openai/clip-vit-large-patch14", attn_implementation="eager")
processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = load_image(url)
image_inputs = processor(images=image, return_tensors="pt")
text_inputs = processor(text=["a photo of a cat"], return_tensors="pt")
joint_inputs = processor(text=["a photo of a cat"], images=image, return_tensors="pt")

def print_output(output):
    if isinstance(output, torch.Tensor):
        print("Output is a tensor with shape:", output.shape)
    else:
        print("Output is a ModelOutput with attributes:")
        for key, value in output.items():
            if isinstance(value, torch.Tensor):
                print(f"  {key}: tensor with shape {value.shape}")
            else:
                print(f"  {key}: {type(value)}")
    print()

with torch.inference_mode():
    image_features = model.get_image_features(**image_inputs)
    print("model.get_image_features(**image_inputs) outputs:")
    print_output(image_features)

    image_features = model.get_image_features(**image_inputs, return_dict=True)
    print("model.get_image_features(**image_inputs, return_dict=True) outputs:")
    print_output(image_features)

    image_features = model.get_image_features(**image_inputs, return_dict=True, output_hidden_states=True, output_attentions=True)
    print("model.get_image_features(**image_inputs, return_dict=True, output_hidden_states=True, output_attentions=True) outputs:")
    print_output(image_features)

    text_features = model.get_text_features(**text_inputs)
    print("model.get_text_features(**text_inputs) outputs:")
    print_output(text_features)

    text_features = model.get_text_features(**text_inputs, return_dict=True)
    print("model.get_text_features(**text_inputs, return_dict=True) outputs:")
    print_output(text_features)

    text_features = model.get_text_features(**text_inputs, return_dict=True, output_hidden_states=True, output_attentions=True)
    print("model.get_text_features(**text_inputs, return_dict=True, output_hidden_states=True, output_attentions=True) outputs:")
    print_output(text_features)

Outputs:

model.get_image_features(**image_inputs) outputs:
Output is a tensor with shape: torch.Size([1, 768])

model.get_image_features(**image_inputs, return_dict=True) outputs:
Output is a ModelOutput with attributes:
  last_hidden_state: tensor with shape torch.Size([1, 257, 1024])
  pooler_output: tensor with shape torch.Size([1, 768])

model.get_image_features(**image_inputs, return_dict=True, output_hidden_states=True, output_attentions=True) outputs:
Output is a ModelOutput with attributes:
  last_hidden_state: tensor with shape torch.Size([1, 257, 1024])
  pooler_output: tensor with shape torch.Size([1, 768])
  hidden_states: <class 'tuple'>
  attentions: <class 'tuple'>

model.get_text_features(**text_inputs) outputs:
Output is a tensor with shape: torch.Size([1, 768])

model.get_text_features(**text_inputs, return_dict=True) outputs:
Output is a ModelOutput with attributes:
  last_hidden_state: tensor with shape torch.Size([1, 7, 768])
  pooler_output: tensor with shape torch.Size([1, 768])

model.get_text_features(**text_inputs, return_dict=True, output_hidden_states=True, output_attentions=True) outputs:
Output is a ModelOutput with attributes:
  last_hidden_state: tensor with shape torch.Size([1, 7, 768])
  pooler_output: tensor with shape torch.Size([1, 768])
  hidden_states: <class 'tuple'>
  attentions: <class 'tuple'>
  • Tom Aarsen

….._features methods

This commit updates all get_text_features methods, even blip_2, which was previously not yet attempted
A handful of outliers that aren't updated yet, e.g. if there's 2+ ModelOutput classes that are viable, or the vq-based ones

For context, the other modeling file classes haven't been updated with the new get_..._features format, nor have the tests
@tomaarsen
Member Author

tomaarsen commented Dec 16, 2025

For context, these are the TODOs at this point:

  • Unfinished architectures
    • fuyu get_image_features: I don't think Fuyu has a real Vision Encoder beyond just a single Linear
    • blip_2 get_image_features: The new format misses the query_outputs/qformer_outputs, should use a new ModelOutput subclass somehow.
    • instructblip get_image_features: The new format misses the query_outputs/qformer_outputs, should use a new ModelOutput subclass somehow.
    • instructblipvideo get_video_features: See above
    • kosmos2 get_image_features: The new format misses the projection_attentions, should use a new ModelOutput subclass somehow.
    • ovis2 get_image_features: The new format misses the visual_indicator_features, should use a new ModelOutput subclass somehow.
    • deepseek_vl_hybrid get_image_features: This method produces both low_res_vision_encodings and high_res_vision_encodings, should use a new ModelOutput subclass somehow to combine them.
    • chameleon get_image_features: Update the VQVAE class to output the hidden states before quantization.
    • emu3 get_image_features: Update the VQVAE class to output the hidden states before quantization.
  • Update all architecture classes to accept the new output format
  • Add and/or update tests for the new output format
  • Update docstrings
  • Update type hints

  • Tom Aarsen

The Fuyu architecture doesn't have an image encoder:
> Architecturally, Fuyu is a vanilla decoder-only transformer - there is no image encoder.
@tomaarsen
Member Author

tomaarsen commented Jan 22, 2026

I'm writing up the docs now, and I'm just very unhappy with the outputs. For example, the pooler_output has so many different formats, e.g. this is just for images:
- (batch_size, hidden_size)
- (batch_size, 1, hidden_size)
- (batch_size, num_patches, hidden_size)
- a tuple of (num_patches, hidden_size)
- (batch_size, height, width, hidden_size)
- (batch_size * num_patches, hidden_size)
- a list of (num_patches, hidden_size)
- (batch_size * num_frames, hidden_size)

Only get_text_features has some consistency with (batch_size, hidden_size); image is just a total nightmare, and audio and video are inconsistent as well. If the goal is easier downstream usage (i.e. more black-boxy), then this PR is barely helping. But post-processing the last_hidden_states/pooler_outputs so that they match across all architectures isn't exactly viable either.

Sentence Transformers has always been able to treat transformers as a black box, because the text architectures are so consistently integrated: last_hidden_states is (batch_size, sequence_length, hidden_size) and then that can be pooled into (batch_size, hidden_size).
Even with this PR merged, projects like Sentence Transformers would still have to create per-architecture edge cases to handle the weird formats used, and that kind of maintenance nightmare is exactly what I'd like to avoid. I think third parties won't be very interested in these methods because they're inconsistent in their output format.
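The consistent text-side contract described above — pooling (batch_size, sequence_length, hidden_size) down to (batch_size, hidden_size) — can be sketched with masked mean pooling; plain Python lists stand in for tensors here:

```python
def mean_pool(last_hidden_state, attention_mask):
    """Masked mean pooling: average only the unmasked token embeddings,
    turning (batch_size, seq_len, hidden_size) into (batch_size, hidden_size).
    This is the black-box treatment that consistent text encoders allow."""
    pooled = []
    for token_embeddings, mask in zip(last_hidden_state, attention_mask):
        kept = [emb for emb, m in zip(token_embeddings, mask) if m]
        pooled.append([sum(dim) / len(kept) for dim in zip(*kept)])
    return pooled

# One sequence of 3 tokens (last one is padding), hidden_size=2
hidden = [[[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]]
mask = [[1, 1, 0]]
print(mean_pool(hidden, mask))  # [[2.0, 3.0]]
```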

I think I've been fighting for image/audio/video model normalization in a battle that was already lost 2-3 years ago. I don't know if we'll ever be able to black box-ify those architectures.

I think we can still move forward with this PR, as the updated methods on some of the architectures are still valuable for Sentence Transformers, but I'll simply have to write edge cases per architecture.
I'll just hope that if future e.g. multimodal embedding models require monomodal inputs (e.g. a text only search query, or an audio only search query), that the forward in the transformers integration is designed to support that. For example via apply_chat_template always forcing the existence of input_ids and having all other modalities as optional (e.g. qwen3_vl). That would solve a lot of my problems.

But I really don't think that these methods are the future of inferencing with monomodal inputs on multimodal models, and I don't know if we should introduce the generalized variants that @ArthurZucker brought up. I think the qwen3_vl-route is much more promising. So perhaps we should ditch the docs part, not make the feature bigger than it warrants, and aim for that "modality-optional" path like qwen3_vl instead. It feels much cleaner.

I can work on the migration guide.

  • Tom Aarsen

@tomaarsen
Member Author

Migration guide is ready. I don't feel like we should aim for docs here; the output formats are, in my opinion, too half-baked to warrant advertising this as a feature apart from the API Reference.

  • Tom Aarsen

Contributor

@vasqu vasqu left a comment


My last few nits ~

I'm fine with keeping this in the migration guide only. It is a bit sad, but we should establish some kind of standard in the future, e.g. like qwen3 vl as you mentioned. While older models will be harder to fix for now, we can keep this in mind so we don't further "diversify" if possible.

class_file = sys.modules[cls.__module__].__file__
- with open(class_file, "r") as f:
+ with open(class_file, "r", encoding="utf8") as f:
Contributor


Windows failing?

Member Author


Yes, Python on Windows is notorious for requiring encoding="utf8" when opening files: without it, open falls back to the locale encoding (often cp1252), which fails on characters it can't map:

FAILED tests/models/aya_vision/test_modeling_aya_vision.py::AyaVisionModelTest::test_get_image_features_attentions - UnicodeDecodeError: 'charmap' codec can't decode byte 0x8d in position 17271: character maps to <undefined>
FAILED tests/models/chinese_clip/test_modeling_chinese_clip.py::ChineseCLIPModelTest::test_get_image_features_attentions - UnicodeDecodeError: 'charmap' codec can't decode byte 0x9d in position 43871: character maps to <undefined>
FAILED tests/models/chinese_clip/test_modeling_chinese_clip.py::ChineseCLIPModelTest::test_get_text_features_attentions - UnicodeDecodeError: 'charmap' codec can't decode byte 0x9d in position 43871: character maps to <undefined>
    def decode(self, input, final=False):
>       return codecs.charmap_decode(input,self.errors,decoding_table)[0]
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E       UnicodeDecodeError: 'charmap' codec can't decode byte 0x9d in position 43871: character maps to <undefined>

In short, if the modeling class file contains e.g. Chinese characters, then you won't be able to use model.set_attn_implementation, which isn't ideal.

I can definitely move this to a separate PR if you prefer.
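The failure mode is easy to reproduce. This minimal sketch writes a source file with non-ASCII content and reads it back with the encoding pinned; on Windows, an unpinned `open(path, "r")` would use the locale encoding and can raise a UnicodeDecodeError like the ones above:

```python
import os
import tempfile

# Write a "modeling file" containing non-ASCII characters (e.g. a Chinese comment)
path = os.path.join(tempfile.mkdtemp(), "modeling_demo.py")
with open(path, "w", encoding="utf8") as f:
    f.write("# 中文注释: attention implementation\nHIDDEN_SIZE = 768\n")

# Pinning the encoding makes the read portable; a bare open(path, "r") uses
# locale.getpreferredencoding(), which on Windows is often cp1252 and fails here.
with open(path, "r", encoding="utf8") as f:
    source = f.read()
print("HIDDEN_SIZE" in source and "中文注释" in source)  # True
```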

Contributor


Would be nicer yes tbh. A bit unrelated to this PR, let's keep a clean history

Member Author


PR in #43433, would be very nice to include in v5.

Comment on lines +4708 to +4709
def _video_features_get_expected_num_hidden_states(self, model_tester=None):
return self._video_features_get_expected_num_attentions(model_tester) + 1
Contributor


Yea no worries, the PR is already very big as it is. We can refactor this later on as well

@tomaarsen
Member Author

I believe this PR is good to go now. I know we're looking to merge this and loads of other PRs today, so I'll try to stay on top of merge conflicts. Ping me here or on Slack if there's any.

cc @ArthurZucker

  • Tom Aarsen

@github-actions
Contributor

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=42564&sha=d8e786

@ArthurZucker ArthurZucker merged commit 55dadb8 into huggingface:main Jan 23, 2026
15 of 25 checks passed
SangbumChoi pushed a commit to SangbumChoi/transformers that referenced this pull request Jan 23, 2026
…rn `BaseModelOutputWithPooling` (huggingface#42564)

* Add return_dict to get_text_features methods to allow returning 'BaseModelOutputWithPooling'

Added to all architectures except blip-2, which has a much different structure here. It uses 'Blip2TextModelWithProjection' to get these embeddings/features, but this class isn't as simple to use

* Add return_dict to get_image_features methods to allow returning 'BaseModelOutputWithPooling'

Well, the architectures supporting get_image_features are all extremely different, with wildly different outputs for the get_image_features methods. 2d outputs, 3d outputs, lists of 2d outputs (due to non-matching shapes), existing 'return_attentions' resulting in returning 2-tuple, existing 'return_dict' resulting in returning 3-tuples (???), high quality image embeddings, low quality image embeddings, deepstack image embeddings, etc. etc. etc.

And I only went through like 70-80% of all architectures with get_image_features before I gave up.

Standardisation of all of these sounds like a lost cause.

* make fixup

* Ignore discrepancies for pooler_output, focus on last_hidden_state

* Update get_image_features for the missing architectures

* Update all get_audio_features

* Update get_video_features, except instructblipvideo

Should be fine though, as that 'get_video_features' doesn't live on the AutoModel class, but on the AutoModelForConditionalGeneration class

* Run ruff formatting

* Patch Glm4v VisionModel forward with BaseModelOutputWithPooling

* Patch instructblip, although backwards incompatibility stands

* Patch Kosmos2 and Ovis2

* Reformat Ovis2

* Avoid now-deprecated return_attentions

* Remove NumFrames

* Proposal to simplify get_..._features via TransformersKwargs & check_model_inputs

The changes in check_model_inputs aren't the clearest/prettiest, but they work well for now.

* Revert check_model_inputs, adopt can_return_tuple, accept BC on get_..._features methods

This commit updates all get_text_features methods, even blip_2, which was previously not yet attempted

* Fix typo: can_return_dict -> can_return_tuple

* Adopt can_return_tuple for many get_image_features

A handful of outliers that aren't updated yet, e.g. if there's 2+ ModelOutput classes that are viable, or the vq-based ones

For context, the other modeling file classes haven't been updated with the new get_..._features format, nor have the tests

* Update all get_audio_features, some edge cases handled (e.g. gemma3n)

* Update most get_video_features, some edge cases remain, e.g. instructblipvideo

* Patch Fuyu, just return BaseModelOutputWithPooling without pooler

The Fuyu architecture doesn't have an image encoder:
> Architecturally, Fuyu is a vanilla decoder-only transformer - there is no image encoder.

* Introduce ModelOutput subclass for Chameleon, patch get_image_features

* Update modeling files with new output formats for get_..._features

* Update fast_vlm modeling forward from modular llava to remove image_sizes

* Update colqwen2 its self.vlm.model.visual call to expect BaseModelOutput

* Replace prior return_dict with check_model_inputs on qwen2_5_vl its VisionTransformer

* Use BaseModelOutputWithProjectionAttentions for Kosmos2 to allow returning the projection attentions

* Update Emu akin to Chameleon

* Update the blip architectures with a naive fix

A better solution might be to remove the qformer etc. calls from the get_image/video_features and run those separately in the forward methods.

* Convert remaining modulars (emu3, janus), patch emu3

* Patch blip test

* Update deepseek_vl using a new BaseModelOutputWithHighResVisionEncodings

* Remove 'copied' for blip_2, instructblip and kosmos2 as they required custom changes

* Patch qwen3_vl and qwen3_vl_moe, where I used last_hidden_state instead of pooler_output

* Run repo-consistency

* Use kwargs["output_hidden_states"] = True to hardcode output_hidden_states where needed

* Update new GlmAsr get_audio_features on ForConditionalGeneration

* Run make style

* Try to add _can_record_outputs to florence2

* Override JanusVisionModel.forward to avoid bad q-former copy from Blip2

* Import missing BaseModelOutput

* Pop deprecated 'return_attentions', setting 'return_dict' won't be useful iiuc

* Reintroduce kwargs filtering in llava etc. for safety re. image_sizes

We also don't need to incorporate code cleanup etc. in this PR, we should keep it as minimal as possible and leave these kinds of lines intact.

* Use BaseModelOutputWithPooling superclass consistently for custom get_..._features outputs

* Update Blip-2 family and its BaseModelOutputWithVisionQformerOutputs

To use both a vision_outputs and qformer_outputs as keys in the BaseModelOutputWithPooling subclass, despite some duplication.

* Update glm4v _can_record_outputs

* Remove check_model_inputs in granite_speech

I could also use can_return_tuple, but this might be problematic if `return_dict=False` in the config

* Run make style

* Add _can_record_outputs to Ovis2VisionModel

* Update get_text_features/get_video_features from pe_video

* Update missing case on sam3

* Update get_text_features type hints to Union[tuple, BaseModelOutputWithPooling]

Blip-2 and Clvp are the only exceptions

* Add _can_record_inputs to qwen2_5_omni and qwen2_5_vl

* Update get_image_features and get_video_features on ernie4_5_vl_moe

Can we even use BaseModelOutputWithPooling for these? It's a MoE model

* Update get_image_features type hints to Union[tuple, BaseModelOutputWithPooling]

With a handful of exceptions

* Remove @auto_docstring from pe_video, it's seemingly not used on that arch

(or well documented)

* Update get_video_features type hints to Union[tuple, BaseModelOutputWithPooling]

Only exceptions for BaseModelOutputWithDeepstackFeatures

* Fix pe_video import issue

* Update forward, test, and docstring for sam3

* Update get_audio_features type hints to Union[tuple, BaseModelOutputWithPooling]

Also update BaseModelOutput to BaseModelOutputWithPooling in several places, leaving room for a potential pooled embedding to be computed by get_audio_features

* Add simple test case for get_text_features

Fails on CLIP, MetaCLIP, Siglip, Siglip2 as they use 'self.text_model = text_model.text_model', bypassing the TextModel that has `check_model_inputs` cc @zucchini-nlp related to huggingface#42564

* First attempt to get get_image_features under test, still 26 failures

* Resolve several test failures, progress still slow and inconsistent

* Split up get_..._features tests more, should be simpler to disable/customize specific parts per arch

* Fix emu3 tests, also track non-temporal ResNet in hidden_states

* Patch chameleon, emu3, ernie4_5, janus

* Skip output_attentions for FastVLM, timm doesn't accept it

But I'm not sure how to handle the output_hidden_states case

* Patch groupvit, instructblip, ovis2

plus style

* Patch paddleocr_vl, qwen2_5_omni, qwen2_5_vl, qwen2_vl, and skip test for perception_lm

perception_lm is still problematic with output_hidden_states, akin to fast_vlm

* Patch qwen3_omni_moe, sam family, edgetam

P.s. edgetam had incorrect _can_record_outputs

Now, all issues that remain with get_image_features are due to 1) CLIP family issue and 2) unclarity with expected output_hidden_states for timm-based models

* Kill now unused BaseModelOutputWithFeatureMaps

* Remove left-over return_dict from prior attempt

* Allow for output_hidden_states in theory, but skip impossible tests

The tests are failing as edgetam doesn't output hidden_states. It used to, because of a broken TimmWrapper in _can_return_outputs.

* Introduce tests for get_audio_features, fixed all architectures

* Introduce tests for get_video_features, only ernie4_5_vl_moe is failing

It's failing as the split_sizes gets made too small, such that the video_embeds doesn't sum to the split_sizes anymore. I'm not sure how to best tackle it.

I also removed the get_video_features from PaddleOCR_vl, as I don't think it's meant to be used with video

* Call post_init on GraniteSpeechCTCEncoder, which was given a PreTrainedModel subclass

* Update llava_onevision test suite, only create video pixel_values in new method

Instead of in the common one, as that negatively affects other tests (as there's no video tokens in the inputs_ids then)

* Create custom video input for ernie4_5_vl_moe

* Skip CLIP family tests; they don't support output_hidden_states/output_attentions due to bug

* Breaking: update Blip2Model.get_text_features to no longer output logits

* Satisfy test_num_layers_is_small test for align

* Test against last_hidden_state against batch_size and hidden_size

19 failures, mostly if architectures merge the first dimension with e.g. num_frames for videos, or swap dimensions from the norm with the hidden_state at index 1 in a 4d-tensor

I don't think it's reasonable to expect these to be 'fixed', they would require drastic changes in the architectures or somewhat arbitrary changes in the post-processing of the hidden states.

* Skip last_hidden_state shape tests for unusual cases

E.g. when batch_size is merged with num_frames or num_patches, or hidden_size is in index -3 instead of index -1

* Update docstrings via auto_docstring for all get_..._features methods

Also add to e.g. aria.md to ensure that get_..._features methods are documented

* Ensure all auto_doc arguments are documented

* Remove redundant docstrings

* Also patch the new glm_image for get_image_features/output_hidden_states

* Update modular files as per check_docstring rules ...

... to avoid modular/check_docstring conflicts. Modular would propagate changes from modular to modeling files, and then check_docstring would complain and update the modeling files only. This created an unstable state where one of the two scripts was unhappy. I resolved this by manually tracking down the check_docstring issues in the modular files.

* Update glm-image dates via fix-repo

* FloatTensor -> LongTensor for image_tokens

* Add simple last_hidden_state description, fix output typing of Gemma3nAudioEncoder.forward

* Add missing `-> tuple | BaseModel...` on check_model_inputs

Using ``check_model_inputs[^\n]*\n\s*def forward\([^\)]*\):``

* Ensure forward typing with check_model_inputs is `-> tuple | BaseModel...`

Using ``check_model_inputs[^\n]*\n\s*def forward\([^\)]+\) -> (?!tuple | )``

* Undo accidental rename of Ovis2VisionAttention

* Fix incorrect type hints for blip family

* Patch get_image_features for lighton_ocr

* Explicitly use Ovis2VisionAttention in Ovis2VisionEncoderLayer in modular

* Update use of get_image_features for lighton_ocr

Forgot to run tests to verify that it worked, oops

* Rerun python utils/add_dates.py

Not sure which script removed the date... :/

* Remove tie_last_hidden_states=False from check_model_inputs from ...

forward methods that previously did not return a BaseModelOutput

* Revert accidental metaclip import change

* Add missing return_dict=True in get_..._features methods

* Add `output_hidden_states=True` in InternVL get_image_features

Only if needed

* Add missing docstring for llava_next_video get_video_features

* Quick clean-up in _video_features_prepare_config_and_inputs test helper

* model.set_attn_implementation instead of config._attn_implementation

Note: there are about ~10 other places that use config._attn_implementation in this test file alone

* Add simple docstring to some helper methods re. inputs.

It's not extremely useful I think, as it has to be somewhat generic due to the large differences in the architectures

* Explain why get_..._features test inputs are overridden

* Undo incorrect return_dict=True change in deepseek_vl_hybrid

I added return_dict to get_low_res_image_features and get_high_res_image_features calls, but these methods already set return_dict automatically

* Revert accidental metaclip import change

* Adopt **vision_outputs in instructblip, but mess remains

* Avoid kwargs["output_hidden_states"] = True in get_..._features methods

* Update check_model_inputs to default vision args based on config

* Unrelated but important: patch set_attn_implementation for Windows

idem with set_experts_implementation

* Revert output_hidden_states changes on InternVL

On this architecture, it seems cleaner to go the `kwargs["output_hidden_states"] = True` route, as a simple `output_hidden_states=vision_feature_layer != -1` prevents setting the `output_hidden_states` to True if requested for downstream use.

* Extend d9001cc (check_model_inputs); remove more vision_feature_layer defaulting

* Patch unusual bug: llava_next_video used self.vision_feature_layer

Doesn't seem like this was being used elsewhere, so I can just update it to use the local variant like elsewhere

* Add unused use_cache to TimmWrapperModel to patch FastVLM

FastVLM now forwards this argument due to the check_model_inputs, and TimmWrapper can't use it

* Update check_config_attributes to allow for vision attributes

And rerun fix-repo

* Add tests for config.return_dict=False

Also; siglip had "nested" check_model_inputs: the VisionModel and VisionTransformer (below it) both used `check_model_inputs`. This means that the VisionModel.forward eats the 'return_dict=True', and the lower VisionTransformer.forward its `check_model_inputs` uses the config.return_dict=False to turn the output to a tuple.

The siglip/clip/metaclip family is still broken due to the `text_model = text_model.text_model` bypassing the class with the `check_model_inputs`.

* permute and quantize separately for the comment

* Ditch shared custom_args for ernie4_5_vl_moe

* Move Ernie4_5_VL_MoeVisionAttention next to VisionBlock

* Add missing "attentions" from Florence2 _can_record_outputs

* Clarify kwargs.get("image_sizes") in modeling_llava

* Remove commented skip_test_image_features_output_shape in chameleon tests

* Add a migration guide under 'Library-wide changes with lesser impact'

* Parameterize get_..._features tests  with return_dict (True, False, None)

* Add comment re. TimmWrapper _can_record_outputs

* Shrink Gemma3nAudioEncoderModelOutput with auto_docstring & superclass

* Revert "Unrelated but important: patch set_attn_implementation for Windows"

This reverts commit 0923216.
@Mecoli1219

Mecoli1219 commented Feb 10, 2026

Hi @tomaarsen, I’m testing the Llama 4 implementation in transformers (v5) and encountered the following error during the forward pass of Llama4ForConditionalGeneration:

AttributeError: 'BaseModelOutput' object has no attribute 'pooler_output'

After digging into the source, I suspect this might be related to this PR. It appears that Llama4ForConditionalGeneration now expects get_image_features to return a pooler_output, but the underlying Llama4VisionModel (or the get_image_features method itself) still returns a standard BaseModelOutput.

Does this align with the intended architectural changes for Llama 4, or should get_image_features be updated to handle the issue safely?

@tomaarsen
Member Author

@Mecoli1219 Thank you for reporting this, it seems like this is an unintended regression, my apologies. Ideally, the Llama4VisionModel should be updated to return BaseModelOutputWithPooling, I suspect.

  • Tom Aarsen
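Until such an upstream fix lands, a downstream workaround could guard the attribute access. A hedged sketch — `get_pooled` and its mean-pool fallback are illustrative helpers, not transformers API:

```python
def get_pooled(output):
    """Return `pooler_output` when the model emits a BaseModelOutputWithPooling,
    otherwise fall back to mean-pooling the last hidden state over the
    sequence dimension (dim=1) for plain BaseModelOutput instances."""
    pooled = getattr(output, "pooler_output", None)
    if pooled is not None:
        return pooled
    return output.last_hidden_state.mean(dim=1)
```

The proper fix remains updating Llama4VisionModel itself to return BaseModelOutputWithPooling.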

aslonnie added a commit to ray-project/ray that referenced this pull request Feb 26, 2026
Regression caused by an upgrade of the transformers lib version;
specifically, there is a behavior change caused by
specifically there is a behavior change caused by
huggingface/transformers#42564

Signed-off-by: abrar <abrar@anyscale.com>
Co-authored-by: Lonnie Liu <95255098+aslonnie@users.noreply.github.com>


Development

Successfully merging this pull request may close these issues.

The get_(text|image|audio|video)_features methods have inconsistent output formats, needs aligning for Sentence Transformers

6 participants