# Untangle config inheritance #41541
**Merged**: ArthurZucker merged 65 commits into `huggingface:main` from `zucchini-nlp:config-inheritance` on Jan 16, 2026.
## Commits

65 commits, all by zucchini-nlp unless noted otherwise (commit messages kept verbatim):

- `a3a8726` remove from base
- `7569f17` delete
- `83db459` fetcher fix
- `66225ab` missing values
- `6d322fa` update
- `cd1c645` is decoder missing
- `a744da3` forgot to add
- `212e609` add special tokens with default `None` in text models
- `c45264c` fsmt has unused subconfig, fix it!
- `589a776` update
- `cb03af1` Merge branch 'main' into config-inheritance
- `a7ea9dc` Merge branch 'main' into config-inheritance
- `9520541` fix
- `a94cd75` add missig token id defaults
- `338558c` fix more tests
- `0e6f6f7` tie_word_embeddings
- `fb0c58d` tiny fixes
- `05699a7` more test fixes
- `a913528` fix docstrings
- `87e610d` fix copies
- `ad1930e` fix style?
- `57b1736` rebase main
- `74a4a46` fix copied again
- `9afe474` merge main
- `f79588e` fix copies
- `7d3c3cf` fix examples
- `6565152` delete left over print stmt
- `69efa1f` Merge branch 'main' into config-inheritance
- `d909306` splitnter
- `796b312` Merge branch 'main' into config-inheritance
- `3d01b44` this defi will fix a bunch decoder-only models
- `d696e05` it's gonna be so much fun to fix issues after refactors on main...
- `9c87fd5` make style
- `3a41439` fix copies
- `c831852` WTF, I rebased 5 min ago?!
- `83bc532` not all models are supposed to have an attr for `tie_word_embeddings`!
- `1be91a1` merge main
- `f034540` comment out
- `5f803ff` fix
- `4ef5d93` more fixes
- `73d8d24` fix copies
- `137493b` docstring and non-model tests
- `c1f0aae` update
- `b8ed5b3` fix repo consistency
- `be62176` merge main
- `e3333fb` style
- `2fba81a` fix
- `2cc234a` Merge branch 'main' into config-inheritance
- `68545e7` revert
- `233c986` Merge branch 'main' into config-inheritance
- `2fd964b` remove unused attr
- `66dd842` fix repo
- `fbe85de` fix test
- `1391a5e` Merge branch 'main' into config-inheritance
- `f76536e` fix a few tests, more tests
- `4fdf142` fix gemma & llava
- `d046c0f` style
- `6c6e720` gemma3n also
- `9fe2176` Merge branch 'main commit c0d2e26f' into config-inheritance (ydshieh)
- `2d4da5f` merge main
- `840e8ea` new models as well
- `5cc58bd` skip the test
- `9909482` Merge branch 'main' into config-inheritance
- `b2c7337` Merge branch 'main' into config-inheritance
- `0e9d3d2` Merge branch 'main' into config-inheritance

## Changes
```diff
@@ -910,7 +910,7 @@ def get_extended_attention_mask(
         # Provided a padding mask of dimensions [batch_size, seq_length]
         # - if the model is a decoder, apply a causal mask in addition to the padding mask
         # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
-        if self.config.is_decoder:
+        if getattr(self.config, "is_decoder", None):
            extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
                input_shape, attention_mask
            )
```
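The switch from `self.config.is_decoder` to `getattr(self.config, "is_decoder", None)` matters once `is_decoder` is no longer inherited by every config class. A minimal sketch of the pattern, using toy config classes rather than the actual transformers ones:

```python
# Toy configs, not the real transformers classes: one vision-only config
# with no `is_decoder` attribute at all, one decoder config that sets it.

class VisionConfig:
    hidden_size = 768  # vision-only: `is_decoder` was never defined


class DecoderConfig:
    hidden_size = 768
    is_decoder = True


def needs_causal_mask(config):
    # Direct access (`config.is_decoder`) would raise AttributeError for
    # VisionConfig; `getattr` with a default treats a missing attribute
    # the same as "not a decoder".
    return bool(getattr(config, "is_decoder", None))


print(needs_causal_mask(DecoderConfig()))  # True
print(needs_causal_mask(VisionConfig()))   # False
```
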
```diff
@@ -2392,7 +2392,10 @@ def get_expanded_tied_weights_keys(self, all_submodels: bool = False) -> dict:
         tied_mapping = self._tied_weights_keys
         # If the config does not specify any tying, return empty dict
-        if not self.config.tie_word_embeddings:
+        # NOTE: not all modules have `tie_word_embeddings` attr, for example vision-only
+        # modules do not have any word embeddings!
+        tie_word_embeddings = getattr(self.config, "tie_word_embeddings", False)
+        if not tie_word_embeddings:
            return {}
        # If None, return empty dict
        elif tied_mapping is None:
```

> **zucchini-nlp** (Member, Author) commented on lines +2395 to +2398: important! Please review here. There are two big changes:
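Here the chosen default is `False` rather than `None`: a config that never defines `tie_word_embeddings` (such as a vision tower, which has no word embeddings to tie) behaves exactly like one that explicitly disabled tying. A simplified sketch of that guard, with toy configs and an illustrative mapping:

```python
# Toy configs, not the real transformers classes.

class TextConfig:
    tie_word_embeddings = True


class VisionTowerConfig:
    pass  # vision-only: no word embeddings, hence no tying attribute


def expanded_tied_weights_keys(config, tied_mapping):
    # Defaulting to False: a missing attribute means "nothing to tie",
    # instead of an AttributeError on vision-only configs.
    tie_word_embeddings = getattr(config, "tie_word_embeddings", False)
    if not tie_word_embeddings:
        return {}
    elif tied_mapping is None:
        return {}
    return dict(tied_mapping)


mapping = {"lm_head.weight": "model.embed_tokens.weight"}
print(expanded_tied_weights_keys(VisionTowerConfig(), mapping))  # {}
print(expanded_tied_weights_keys(TextConfig(), mapping))
```
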
```diff
@@ -2642,10 +2645,7 @@ def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None, mean
            new_num_tokens = new_embeddings.weight.shape[0]

        # if word embeddings are not tied, make sure that lm head is resized as well
-        if (
-            self.get_output_embeddings() is not None
-            and not self.config.get_text_config(decoder=True).tie_word_embeddings
-        ):
+        if self.get_output_embeddings() is not None:
            old_lm_head = self.get_output_embeddings()
            if isinstance(old_lm_head, torch.nn.Embedding):
                new_lm_head = self._get_resized_embeddings(old_lm_head, new_num_tokens, mean_resizing=mean_resizing)
```
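The effect of this hunk is that the LM head is now resized whenever one exists, no longer gated on `tie_word_embeddings`. A toy sketch of the idea, using plain Python lists in place of `torch.nn.Embedding` (all names below are illustrative, not the transformers API):

```python
def resize_rows(matrix, new_rows, fill=0.0):
    """Grow or shrink a list-of-lists 'embedding matrix' to new_rows rows."""
    cols = len(matrix[0]) if matrix else 0
    resized = [row[:] for row in matrix[:new_rows]]
    while len(resized) < new_rows:
        resized.append([fill] * cols)
    return resized


class ToyModel:
    def __init__(self, vocab, dim, tied):
        self.embed = [[0.1] * dim for _ in range((vocab))]
        # with tied weights the head is the very same object as the embedding
        self.lm_head = self.embed if tied else [[0.2] * dim for _ in range(vocab)]

    def resize_token_embeddings(self, new_vocab):
        self.embed = resize_rows(self.embed, new_vocab)
        # Resize the head unconditionally: an untied head genuinely needs it,
        # and (assumption) a tied head is simply re-tied to the new embedding
        # afterwards, so the extra resize is harmless.
        if self.lm_head is not None:
            self.lm_head = resize_rows(self.lm_head, new_vocab)


m = ToyModel(vocab=4, dim=2, tied=False)
m.resize_token_embeddings(6)
print(len(m.embed), len(m.lm_head))  # 6 6
```
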
```diff
@@ -4302,15 +4302,17 @@ def warn_if_padding_and_no_attention_mask(self, input_ids, attention_mask):
        # If the pad token is equal to either BOS, EOS, or SEP, we do not know whether the user should use an
        # attention_mask or not. In this case, we should still show a warning because this is a rare case.
+        # NOTE: `sep_token_id` is not used in all models and it can be absent in the config
+        sep_token_id = getattr(self.config, "sep_token_id", None)
        if (
            (self.config.bos_token_id is not None and self.config.bos_token_id == self.config.pad_token_id)
            or (self.config.eos_token_id is not None and self.config.eos_token_id == self.config.pad_token_id)
-            or (self.config.sep_token_id is not None and self.config.sep_token_id == self.config.pad_token_id)
+            or (sep_token_id is not None and sep_token_id == self.config.pad_token_id)
        ):
            warn_string += (
                f"\nYou may ignore this warning if your `pad_token_id` ({self.config.pad_token_id}) is identical "
                f"to the `bos_token_id` ({self.config.bos_token_id}), `eos_token_id` ({self.config.eos_token_id}), "
-                f"or the `sep_token_id` ({self.config.sep_token_id}), and your input is not padded."
+                f"or the `sep_token_id` ({sep_token_id}), and your input is not padded."
            )

        logger.warning_once(warn_string)
```
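The same defensive pattern applies here: `sep_token_id` exists mostly in encoder-style configs (e.g. BERT-like models) and is absent from most decoder-only configs, so it is read once with `getattr` before the comparison chain. A self-contained sketch of the check, with a hypothetical `Cfg` stand-in for the real config class:

```python
# Hypothetical minimal config holder, not the transformers PretrainedConfig.
class Cfg:
    def __init__(self, pad, bos=None, eos=None, **extra):
        self.pad_token_id, self.bos_token_id, self.eos_token_id = pad, bos, eos
        for k, v in extra.items():
            setattr(self, k, v)  # e.g. sep_token_id, only when provided


def pad_matches_special_token(config):
    # `sep_token_id` may be absent entirely; treat "missing" as "unset"
    sep_token_id = getattr(config, "sep_token_id", None)
    return (
        (config.bos_token_id is not None and config.bos_token_id == config.pad_token_id)
        or (config.eos_token_id is not None and config.eos_token_id == config.pad_token_id)
        or (sep_token_id is not None and sep_token_id == config.pad_token_id)
    )


print(pad_matches_special_token(Cfg(pad=0, eos=0)))               # True
print(pad_matches_special_token(Cfg(pad=0, eos=2)))               # False
print(pad_matches_special_token(Cfg(pad=102, sep_token_id=102)))  # True
```
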
> **Review comment:** I am deleting `tie_encoder_decoder`. I can't find any model on the Hub that uses it, and the attribute was added only for the custom `EncoderDecoderModel` class. Note that we can no longer tie the encoder to the decoder in `EncoderDecoderModel`, no matter what the value of this attribute is.
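For context, an illustrative sketch (not the removed transformers code) of what `tie_encoder_decoder=True` effectively did: share encoder parameters with identically named decoder parameters so the two stayed in sync. The function and parameter-dict representation below are assumptions for illustration only:

```python
def tie_encoder_to_decoder(encoder_params, decoder_params):
    """Share each encoder parameter with a same-named, same-sized decoder one."""
    tied = dict(decoder_params)
    for name, param in encoder_params.items():
        if name in tied and len(tied[name]) == len(param):
            tied[name] = param  # same object => updates stay in sync
    return tied


enc = {"layer.0.weight": [1.0, 2.0]}
dec = {"layer.0.weight": [0.0, 0.0], "cross_attn.weight": [0.5]}
tied = tie_encoder_to_decoder(enc, dec)
print(tied["layer.0.weight"] is enc["layer.0.weight"])  # True
```

Decoder-only parameters with no encoder counterpart (like the cross-attention weight above) are left untouched, which is why such tying only ever worked for architecturally matching encoder/decoder pairs.
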