
Widen match condition for _can_record_outputs#43762

Closed
molbap wants to merge 2 commits into main from fix_recording_on_reload

Conversation


@molbap molbap commented Feb 5, 2026

In case of module reloading, we currently lose the recording hooks for hidden states and attentions. This PR widens the matching condition a bit.

Should fix #43761; also mentioned in linkedin/Liger-Kernel#1061.

Reproducer for CLIP:

import importlib, torch
from transformers import CLIPVisionConfig, CLIPVisionModel
import transformers.models.clip.modeling_clip as modeling_clip

config = CLIPVisionConfig(hidden_size=32, intermediate_size=64, num_hidden_layers=2, num_attention_heads=4, image_size=30, patch_size=10)
pixel_values = torch.randn(1, 3, 30, 30)

m1 = CLIPVisionModel(config)
o1 = m1(pixel_values=pixel_values, output_hidden_states=True)
print("before reload:", o1.hidden_states is None)  # False: hidden states are recorded

importlib.reload(modeling_clip)

m2 = CLIPVisionModel(config)
o2 = m2(pixel_values=pixel_values, output_hidden_states=True)
print("after reload:", o2.hidden_states is None)  # True on main: hooks were not attached

On main, the model instantiated after the reload currently fails to record hidden states; this is a fix proposal.

Added a test as well (specific to CLIP, which is likely enough).

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

for key, specs in capture_tasks:
    # The second check is for multimodals where only backbone layer suffix is available
    if (specs.target_class is not None and isinstance(module, specs.target_class)) or (
        specs.class_name is not None and name.endswith(specs.class_name)
Member:

This check on class_name was added for AyaVision, but I think it's not needed anymore. At least the tests don't complain 😄

Contributor Author:

aah :D makes sense, wdyt of the added check though? not sure how to handle reloading more elegantly

Member:

I think it's the best we can do with importlib.reload, and it shouldn't backfire with false positives. Lemme approve, I forgot to
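For illustration, the name-based fallback discussed above can be sketched as follows. CaptureSpec, EncoderLayer, and matches are hypothetical stand-ins for the real transformers internals, not the actual PR diff:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class CaptureSpec:
    # hypothetical stand-in for the spec objects iterated in capture_tasks
    target_class: Optional[type] = None
    class_name: Optional[str] = None

class EncoderLayer:
    """Dummy submodule class for the demo."""

def matches(name: str, module: object, specs: CaptureSpec) -> bool:
    """Return True when a submodule should get a recording hook attached."""
    if specs.target_class is not None:
        if isinstance(module, specs.target_class):
            return True
        # Widened check: after importlib.reload the class object differs
        # but its name does not, so fall back to comparing class names.
        if type(module).__name__ == specs.target_class.__name__:
            return True
    # Name-suffix match for multimodals where only the backbone suffix is known.
    return specs.class_name is not None and name.endswith(specs.class_name)

# A "reloaded" twin of EncoderLayer: same name, different class object.
ReloadedLayer = type("EncoderLayer", (), {})
spec = CaptureSpec(target_class=EncoderLayer)
print(matches("encoder.layers.0", EncoderLayer(), spec))   # True (isinstance)
print(matches("encoder.layers.0", ReloadedLayer(), spec))  # True (name fallback)
```

Comparing by __name__ is coarser than identity, but as noted above it is unlikely to produce false positives in practice, since submodule class names within one model are distinct.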

outputs = model(pixel_values=pixel_values, output_hidden_states=True)
self.assertIsNotNone(outputs.hidden_states)

importlib.reload(modeling_clip)
Member:

maybe we can move it under test_modeling_common.py, make something similar to this test, and reload the module at some point before checking that config.output_attentions works?

def test_attention_outputs(self):
    if not self.has_attentions:
        self.skipTest(reason="Model does not output attentions")

Contributor Author:

yeah I hesitated, I didn't want to add yet another test for 400+ models 🫣 but you may be right

Member:

haha right!

@ArthurZucker (Collaborator) left a comment:

great catch actually! 🤗


molbap commented Feb 6, 2026

I want to think that great minds think alike 👀 @Cyrilvallez's PR #43765 solves more problems with check_model_inputs

@molbap molbap closed this Feb 6, 2026


Development

Successfully merging this pull request may close these issues.

[v5 regression] CLIPVisionModel.forward returns hidden_states=None even when output_hidden_states=True
