
Fix CLIPVisionModel hidden_states issue in Llava convergence test for transformers v5#1061

Merged
Tcc0403 merged 1 commit into linkedin:transformers-5.0.0rc1 from yukiu00:fix-llava on Feb 7, 2026
Conversation

@yukiu00 (Contributor) commented Feb 3, 2026

Fixes #1011

Summary

  • Fix TypeError: 'NoneType' object is not subscriptable error in Llava convergence test caused by image_outputs.hidden_states returning None
  • Remove unnecessary importlib.reload(modeling_clip) from revert_liger_kernel_to_llava() function

Problem

In transformers v5, calling importlib.reload(modeling_clip) breaks CLIPVisionModel's output_hidden_states functionality. When the Llava convergence test runs:

  1. First run (without Liger): creates model, runs test, then calls revert_liger_kernel_to_llava() which reloads modeling_clip
  2. Second run (with Liger): creates new model, but CLIPVisionModel.forward() now returns hidden_states=None even when output_hidden_states=True is passed

This causes the error at:

selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
# TypeError: 'NoneType' object is not subscriptable
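
The reload pitfall described above can be illustrated with a throwaway module (the names demo_clip and VisionModel below are made up for this sketch, not transformers' actual code): importlib.reload re-executes the module source, which rebuilds every module-level class, so any hooks or registrations attached to the old class objects are silently lost.

```python
import importlib
import pathlib
import sys
import tempfile

# Create a throwaway module on disk (a stand-in for transformers'
# modeling_clip; the name demo_clip is hypothetical).
tmp = pathlib.Path(tempfile.mkdtemp())
(tmp / "demo_clip.py").write_text("class VisionModel:\n    pass\n")
sys.path.insert(0, str(tmp))

import demo_clip

OldVisionModel = demo_clip.VisionModel
model = OldVisionModel()  # instance created before the reload

importlib.reload(demo_clip)  # re-executes the module, rebuilding every class

# The reloaded module holds brand-new class objects, so any state attached
# to the old classes (decorator registrations, recorded hooks) is gone,
# and old instances no longer match the new classes.
print(OldVisionModel is demo_clip.VisionModel)   # False
print(isinstance(model, demo_clip.VisionModel))  # False
```

This is the general hazard of reloading a module other code still holds references into; the specific transformers v5 mechanism is discussed further down in the thread.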

Solution

Remove importlib.reload(modeling_clip) from revert_liger_kernel_to_llava(). This is safe because:

  • Liger kernel does not patch modeling_clip when model=None (which is the case in convergence tests)
  • Only modeling_llava and modeling_llama need to be reloaded to revert Liger patches
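
The patch-then-reload-to-revert mechanism these bullets rely on can be sketched with a safe stdlib target (json stands in for modeling_llava / modeling_llama here; nothing in this block is Liger's actual code): monkey-patching rebinds a module-level attribute, and importlib.reload re-executes the module body, restoring the original binding.

```python
import importlib
import json

# Monkey-patch a module-level function, mirroring the style of patching
# Liger applies to modeling_llava / modeling_llama.
original_loads = json.loads
json.loads = lambda s, **kw: ("patched", original_loads(s, **kw))

patched_result = json.loads('{"a": 1}')
print(patched_result)  # ('patched', {'a': 1})

# importlib.reload re-executes json/__init__.py, rebinding json.loads to
# its original definition and thereby undoing the patch.
importlib.reload(json)

restored_result = json.loads('{"a": 1}')
print(restored_result)  # {'a': 1}
```

This is why reloading only the modules Liger actually patched is sufficient, and why reloading an unpatched module like modeling_clip is pure downside.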

Test plan

  • python -m pytest test/convergence/bf16/test_mini_models_multimodal.py -k llava passes
  • Verified hidden_states is no longer None after the fix

Remove importlib.reload(modeling_clip) from revert_liger_kernel_to_llava()
as it breaks CLIPVisionModel's output_hidden_states functionality in
transformers v5. Liger kernel does not patch modeling_clip when model=None,
so reloading it is unnecessary and causes hidden_states to return None.
@yukiu00 yukiu00 marked this pull request as ready for review February 3, 2026 16:37
@Mecoli1219 (Collaborator) left a comment

LGTM. btw, do we need to reload llama? I think liger kernel doesn't patch modeling_llama here.
@Tcc0403 I've tested it already. Feel free to approve and merge it.

@yukiu00 (Contributor, Author) commented Feb 5, 2026

@Mecoli1219

Yes, we need to reload modeling_llama. In the Llava convergence test, apply_liger_kernel_to_llama(**kwargs) is called before apply_liger_kernel_to_llava(). This patches modeling_llama's apply_rotary_pos_emb, LlamaRMSNorm, LlamaMLP, and LlamaForCausalLM.forward, so the reload is necessary to revert those patches.
Only modeling_clip is safe to remove from the reload list, as it's not patched when model=None.

@Tcc0403 (Collaborator) commented Feb 5, 2026

> LGTM. btw, do we need to reload llama? I think liger kernel doesn't patch modeling_llama here.

Yes as @yukiu00 explained above, there's a llama patch inside llava patch.

Thanks for the fix, but do you know the root cause of CLIPVisionModel.forward() returning hidden_states=None after the reload?

> Second run (with Liger): creates new model, but CLIPVisionModel.forward() now returns hidden_states=None even when output_hidden_states=True is passed

@Tcc0403 (Collaborator) commented Feb 5, 2026

I just checked the transformers changes to CLIPEncoder:

v5.0.0
https://github.com/huggingface/transformers/blob/08810b1e278938278c50153ee1edfd7a20a759da/src/transformers/models/clip/modeling_clip.py#L485

v4.57.6
https://github.com/huggingface/transformers/blob/753d61104116eefc8ffc977327b441ee0c8d599f/src/transformers/models/clip/modeling_clip.py#L506

Somehow the output_hidden_states argument was removed from .forward() in v5. Would you like to report it to the transformers team and ask whether it is an intended change?

Update: output_hidden_states is in the TransformersKwargs fields but is never used after the modernization PR.

pivot PR huggingface/transformers#41546
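
The failure mode described in this comment can be sketched as follows (VisionOutput and forward are illustrative names, not transformers' actual code): a forward() that accepts output_hidden_states through **kwargs but never threads it through leaves the output's hidden_states field as None, and downstream indexing then raises exactly the TypeError this PR works around.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

# Hypothetical sketch of the regression: the flag is accepted but unused.
@dataclass
class VisionOutput:
    last_hidden_state: list
    hidden_states: Optional[Tuple] = None  # never populated below

def forward(pixel_values, **kwargs):
    # output_hidden_states arrives in kwargs but is silently dropped here,
    # so hidden_states keeps its None default.
    return VisionOutput(last_hidden_state=pixel_values)

out = forward([0.1, 0.2], output_hidden_states=True)
print(out.hidden_states)  # None
# Indexing it, as Llava does with hidden_states[vision_feature_layer],
# raises: TypeError: 'NoneType' object is not subscriptable
```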

@yukiu00 (Contributor, Author) commented Feb 5, 2026

Thanks for the deep dive into the root cause! I've just opened an issue on the transformers repository to report this regression.

Here is the link: huggingface/transformers#43761

I referenced this PR and the context you provided. Hopefully, they can clarify if this was intended or a bug.

@yukiu00 (Contributor, Author) commented Feb 5, 2026

FYI: huggingface/transformers#43762

@yukiu00 (Contributor, Author) commented Feb 5, 2026

@Tcc0403 FYI: I tested the transformers fix branch (fix_recording_on_reload) and confirmed it resolves the root cause. Both the minimal reproduction code and Llava convergence test pass now.

@Tcc0403 (Collaborator) commented Feb 7, 2026

Can you check whether huggingface/transformers#43765 resolves the issue?

@yukiu00 (Contributor, Author) commented Feb 7, 2026

@Tcc0403 Confirmed. transformers main (5.2.0.dev0, includes #43765) resolves the issue.

$ python -m pytest test/convergence/bf16/test_mini_models_multimodal.py -k llava -xvs
1 passed, 11 deselected in 18.55s

@Tcc0403 (Collaborator) commented Feb 7, 2026

Thanks!

@Tcc0403 Tcc0403 merged commit 576a2e3 into linkedin:transformers-5.0.0rc1 Feb 7, 2026