-
Notifications
You must be signed in to change notification settings - Fork 33.7k
Widen match condition for _can_record_outputs
#43762
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -13,6 +13,7 @@ | |||||||||
| # limitations under the License. | ||||||||||
| import copy | ||||||||||
| import glob | ||||||||||
| import importlib | ||||||||||
| import json | ||||||||||
| import os | ||||||||||
| import os.path | ||||||||||
|
|
@@ -34,6 +35,7 @@ | |||||||||
| from parameterized import parameterized | ||||||||||
| from pytest import mark | ||||||||||
|
|
||||||||||
| import transformers.models.clip.modeling_clip as modeling_clip | ||||||||||
| from transformers import ( | ||||||||||
| AutoConfig, | ||||||||||
| AutoModel, | ||||||||||
|
|
@@ -43,6 +45,8 @@ | |||||||||
| BartForConditionalGeneration, | ||||||||||
| BartModel, | ||||||||||
| CLIPTextModelWithProjection, | ||||||||||
| CLIPVisionConfig, | ||||||||||
| CLIPVisionModel, | ||||||||||
| DynamicCache, | ||||||||||
| GPT2Config, | ||||||||||
| GPT2LMHeadModel, | ||||||||||
|
|
@@ -3735,3 +3739,29 @@ def test_vision_language_model(self): | |||||||||
| assert image_encoder is model.model.vision_tower, ( | ||||||||||
| f"LLaVA get_encoder(modality='image') should return vision_tower, got {type(image_encoder)}" | ||||||||||
| ) | ||||||||||
|
|
||||||||||
|
|
||||||||||
| @require_torch | ||||||||||
| class TestCheckModelInputsReload(unittest.TestCase): | ||||||||||
| # See https://github.com/linkedin/Liger-Kernel/pull/1061 and | ||||||||||
| # https://github.com/huggingface/transformers/issues/43761 | ||||||||||
| def test_hidden_states_after_module_reload(self): | ||||||||||
| config = CLIPVisionConfig( | ||||||||||
| hidden_size=32, | ||||||||||
| intermediate_size=64, | ||||||||||
| num_hidden_layers=2, | ||||||||||
| num_attention_heads=4, | ||||||||||
| image_size=30, | ||||||||||
| patch_size=10, | ||||||||||
| ) | ||||||||||
| pixel_values = torch.randn(1, 3, 30, 30, device=torch_device) | ||||||||||
|
|
||||||||||
| model = CLIPVisionModel(config).to(torch_device) | ||||||||||
| outputs = model(pixel_values=pixel_values, output_hidden_states=True) | ||||||||||
| self.assertIsNotNone(outputs.hidden_states) | ||||||||||
|
|
||||||||||
| importlib.reload(modeling_clip) | ||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe we can move it under transformers/tests/test_modeling_common.py Lines 1754 to 1757 in ace7c37
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah I hesitated, I didn't want to add yet another test for 400+ models 🫣 but you may be right
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. haha right! |
||||||||||
|
|
||||||||||
| model_after_reload = CLIPVisionModel(config).to(torch_device) | ||||||||||
| outputs_after_reload = model_after_reload(pixel_values=pixel_values, output_hidden_states=True) | ||||||||||
| self.assertIsNotNone(outputs_after_reload.hidden_states) | ||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This line of check for
class_namewas added for AyaVision, but I think it's not needed anymore. At least tests don't complain 😄There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
aah :D makes sense, wdyt of the added check though? not sure how to handle reloading more elegantly
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's the best we can do with
importlib.reloadand shouldn't backfire with false positives. Lemme approve, forgot to