
[fix] Use last_hidden_state key from get_image_features for llama4 #43882

Merged
vasqu merged 1 commit into huggingface:main from tomaarsen:fix/llama4 on Feb 10, 2026

Conversation

@tomaarsen
Member

What does this PR do?

Resolves #42564 (comment)
#42564 updated get_image_features for Llama4, but it erroneously started using pooler_output instead of the previous last_hidden_state. Additionally, the Llama4VisionModel output has been updated to BaseModelOutputWithPooling to match many vision models.
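
The failure mode can be sketched in isolation. This is a hypothetical, standalone illustration: the dataclasses below only mirror the names of transformers' output classes, and the two `get_image_features_*` helpers are invented stand-ins for the pre-fix and post-fix behaviour, not the library code:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class BaseModelOutput:
    # Token-level features for every patch/position.
    last_hidden_state: list

@dataclass
class BaseModelOutputWithPooling(BaseModelOutput):
    # Extra pooled representation; absent on the plain base class.
    pooler_output: Optional[list] = None

def get_image_features_broken(vision_output):
    # Pre-fix behaviour: assumes a pooled output always exists.
    return vision_output.pooler_output  # AttributeError on plain BaseModelOutput

def get_image_features_fixed(vision_output):
    # Post-fix behaviour: use the token-level features, as before #42564.
    return vision_output.last_hidden_state

plain = BaseModelOutput(last_hidden_state=[[1.0, 2.0]])
print(get_image_features_fixed(plain))   # works: [[1.0, 2.0]]
try:
    get_image_features_broken(plain)
except AttributeError as e:
    print("broken path:", e)
```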

Reproducer

import torch

from transformers import AutoProcessor, Llama4ForConditionalGeneration

model_id = "hf-internal-testing/tiny-random-llama4"
processor = AutoProcessor.from_pretrained(model_id)
model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    attn_implementation="sdpa",  # flex attention / flash_attention_2 do not work, debugging...
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# Fix meta tensor: convert to float32 with random weights
if model.vision_model.rotary_embedding.freqs_ci.is_meta:
    shape = model.vision_model.rotary_embedding.freqs_ci.shape
    # Note: Ideally should compute proper RoPE frequencies, but using random as requested
    model.vision_model.rotary_embedding.freqs_ci = torch.randn(*shape, dtype=torch.float32, device=model.device)

url1 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
url2 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png"
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": url1},
            {"type": "image", "url": url2},
            {"type": "text", "text": "Can you describe how these two images are similar, and how they differ?"},
        ]
    },
]

inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)

outputs = model.generate(
    **inputs,
    max_new_tokens=32,
)

response = processor.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])[0]
print(response)
print(outputs[0])

On main:

Traceback (most recent call last):
  File "c:\code\transformers\demo_llama4.py", line 41, in <module>
    outputs = model.generate(
              ^^^^^^^^^^^^^^^
  File "C:\Users\tom\.conda\envs\transformers\Lib\site-packages\torch\utils\_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "C:\code\transformers\src\transformers\generation\utils.py", line 2638, in generate
    result = decoding_method(
             ^^^^^^^^^^^^^^^^
  File "C:\code\transformers\src\transformers\generation\utils.py", line 2833, in _sample
    outputs = self._prefill(input_ids, generation_config, model_kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\code\transformers\src\transformers\generation\utils.py", line 3822, in _prefill
    return self(**model_inputs, return_dict=True)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom\.conda\envs\transformers\Lib\site-packages\torch\nn\modules\module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom\.conda\envs\transformers\Lib\site-packages\torch\nn\modules\module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\code\transformers\src\transformers\utils\generic.py", line 1016, in wrapper
    outputs = func(self, *args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\code\transformers\src\transformers\models\llama4\modeling_llama4.py", line 1327, in forward
    ).pooler_output
      ^^^^^^^^^^^^^
AttributeError: 'BaseModelOutput' object has no attribute 'pooler_output'

On this PR:

广场 meno बतздрав 공기грамمعграмمعграмمعграмمعграмمعграмمعграмمعграмمعграмمعграмمعграмمعграмمعграмمعграм
tensor([200000, 200005,   1556,  ...,  96359,  20938,  96359], device='cuda:0')

(Gibberish obviously as this is a tiny-random model, but it works now!)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@zucchini-nlp @vasqu

Thank you to @Mecoli1219 for reporting this.

  • Tom Aarsen

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: llama4

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@zucchini-nlp (Member) left a comment

Btw, do we not have llama4 fast tests running to check this?

Comment on lines -1169 to 1171:

-        return BaseModelOutput(
+        return BaseModelOutputWithPooling(
             last_hidden_state=hidden_state,
             hidden_states=hidden_states,
Member

I think the last_hidden_state is the one after layernorm_post, and the pooler output is after the adapter. Though it'll be a breaking change...
Fine with leaving it as is, thanks.
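
The ordering being debated can be sketched without the real model. Everything below is a toy stand-in: `layernorm_post` and `adapter` only borrow their names from the comment above, and the transforms are placeholders, not Llama4's actual ops:

```python
def layernorm_post(hidden_state):
    # Toy stand-in for the vision tower's post-layernorm (just mean-centres).
    mean = sum(hidden_state) / len(hidden_state)
    return [v - mean for v in hidden_state]

def adapter(hidden_state):
    # Toy stand-in for the multimodal adapter/projector.
    return [2.0 * v for v in hidden_state]

hidden_state = [1.0, 2.0, 3.0]

# The "ideal" (but breaking) split suggested in the comment:
#   last_hidden_state = output of layernorm_post
#   pooler_output     = output of the adapter
last_hidden_state = layernorm_post(hidden_state)
pooler_output = adapter(last_hidden_state)

print(last_hidden_state)  # [-1.0, 0.0, 1.0]
print(pooler_output)      # [-2.0, 0.0, 2.0]
```

The merged PR deliberately keeps the pre-existing assignment rather than adopting this split, since changing it post-v5 would break downstream consumers of `last_hidden_state`.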

Member Author

You're right :/ Damn, I wish I'd spotted that before v5, as it's indeed breaking if we improve it. I did make roughly this change for a few other architectures pre-v5.

@tomaarsen
Member Author

Btw, do we not have llama4 fast tests running to check this?

No, I was surprised as well. We only have slow tests.

@zucchini-nlp
Member

zucchini-nlp commented Feb 10, 2026

Super weird tbh. The model arch is pretty straightforward, so there should be no reason to skip the ModelCommonTest mixin.
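
For illustration, a shared fast-test mixin works roughly like this. The class and method names below are invented stand-ins for transformers' real common-test machinery, not its actual test code:

```python
import unittest

class CommonModelTests:
    """Toy stand-in for a shared fast-test mixin: checks every model test class
    inherits for free, run on a tiny randomly initialised model."""

    def test_output_has_last_hidden_state(self):
        # A cheap structural check of this shape would have caught the
        # pooler_output regression without any slow/GPU tests.
        out = self.build_tiny_model_output()
        self.assertIn("last_hidden_state", out)

class Llama4FastTest(CommonModelTests, unittest.TestCase):
    def build_tiny_model_output(self):
        # The real suite would run a tiny forward pass here; we fake the output.
        return {"last_hidden_state": [[0.0]], "hidden_states": None}

suite = unittest.defaultTestLoader.loadTestsFromTestCase(Llama4FastTest)
result = unittest.TextTestRunner(verbosity=0).run(suite)
print(result.wasSuccessful())  # True
```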

@vasqu (Contributor) left a comment

Thanks, yeah, we don't have fast tests for llama4 (which I also wondered about, but I think it was done in a rush, so it happens).

@ArthurZucker
Collaborator

Yep, we had an insane rush; it was the first BIG BIG model 😢

@zucchini-nlp
Member

zucchini-nlp commented Feb 10, 2026

I see, it would be great to add it, but I don't want to block this PR. Will add it to my todo notes, maybe one day 🙃

@vasqu vasqu merged commit 71d7fc5 into huggingface:main Feb 10, 2026
25 checks passed
Tcc0403 added a commit to linkedin/Liger-Kernel that referenced this pull request Feb 10, 2026
## Summary
Fix convergence tests for Transformers v5, including updating some
mismatched variables, increasing the tolerance for some tests, etc.
The only remaining error is expected and should be fixed in
[Transformers#43882](huggingface/transformers#43882).
## Related Issues & PRs
- #978 
## Testing Done

- Hardware Type: H100
- [ ] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence on v4 & v5

---------

Co-authored-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com>
jiosephlee pushed a commit to jiosephlee/transformers_latest that referenced this pull request Feb 11, 2026
…ma4 (huggingface#43882)

Use last_hidden_state key from get_image_features for llama4
