Merged
Changes from 12 commits
Commits (39)
d107f9e
simplify common get/set
molbap Jul 10, 2025
98e5467
remove some noise
molbap Jul 10, 2025
cc8b91c
change some 5 years old modeling utils
molbap Jul 10, 2025
077af3d
update examples
molbap Jul 10, 2025
7d9b16c
fix copies
molbap Jul 10, 2025
32c7c00
revert some changes
molbap Jul 10, 2025
33e0f06
fixes, gah
molbap Jul 10, 2025
91537aa
format
molbap Jul 10, 2025
57aac34
move to Mixin
molbap Jul 11, 2025
39322e6
remove smolvlm specific require grad
molbap Jul 11, 2025
951d269
Merge branch 'main' into draft_VLM_apis
molbap Jul 11, 2025
d2dd129
skip
molbap Jul 11, 2025
92e2296
force defaults
molbap Jul 17, 2025
00b676f
Merge branch 'main' into draft_VLM_apis
molbap Jul 17, 2025
c5a3676
remodularise some stuff
molbap Jul 17, 2025
1265140
remodularise more stuff
molbap Jul 17, 2025
b776481
add safety for audio models
molbap Jul 17, 2025
d9693a3
style
molbap Jul 17, 2025
4ab8399
have a correct fallback, you daft donkey
molbap Jul 17, 2025
2bee084
remove this argh
molbap Jul 17, 2025
f0fe775
change heuristic for audio models
molbap Jul 18, 2025
773a9ad
Merge branch 'main' into draft_VLM_apis
molbap Jul 18, 2025
a233db2
fixup
molbap Jul 18, 2025
d0fbe20
revert
molbap Jul 18, 2025
5436239
this works
molbap Jul 18, 2025
41221d8
revert again
molbap Jul 18, 2025
8556767
🧠
molbap Jul 18, 2025
21820d7
aaah ESM has two modelings aaah
molbap Jul 18, 2025
7b614f8
Merge branch 'main' into draft_VLM_apis
molbap Jul 18, 2025
429e119
Merge branch 'main' into draft_VLM_apis
molbap Jul 21, 2025
6e1fbdd
add informative but short comment
molbap Jul 21, 2025
807eaf0
Merge branch 'draft_VLM_apis' of github.com:huggingface/transformers …
molbap Jul 21, 2025
3edfa42
add `input_embed_layer` mixin attribute
molbap Jul 21, 2025
b1af42a
style
molbap Jul 21, 2025
8455e0f
Merge branch 'main' into draft_VLM_apis
molbap Jul 21, 2025
153a0b0
walrus has low precedence
molbap Jul 21, 2025
31404b4
modular fix
molbap Jul 21, 2025
822dd44
this was breaking parser
molbap Jul 21, 2025
395d3f9
Merge branch 'main' into draft_VLM_apis
molbap Jul 21, 2025
12 changes: 0 additions & 12 deletions examples/modular-transformers/modeling_my_new_model2.py
@@ -334,12 +334,6 @@ def __init__(self, config: MyNewModel2Config):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.embed_tokens

def set_input_embeddings(self, value):
self.embed_tokens = value

@check_model_inputs
@auto_docstring
def forward(
@@ -434,12 +428,6 @@ def __init__(self, config):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.model.embed_tokens

def set_input_embeddings(self, value):
self.model.embed_tokens = value

@can_return_tuple
@auto_docstring
def forward(
6 changes: 0 additions & 6 deletions examples/modular-transformers/modeling_new_task_model.py
@@ -390,12 +390,6 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)

def get_output_embeddings(self):
return self.lm_head

def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings

def set_decoder(self, decoder):
self.model.set_decoder(decoder)

6 changes: 0 additions & 6 deletions examples/modular-transformers/modeling_super.py
@@ -333,12 +333,6 @@ def __init__(self, config: SuperConfig):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.embed_tokens

def set_input_embeddings(self, value):
self.embed_tokens = value

@check_model_inputs
@auto_docstring
def forward(
106 changes: 72 additions & 34 deletions src/transformers/modeling_utils.py
@@ -1898,8 +1898,79 @@ def floating_point_ops(

return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)

class EmbeddingAccessMixin:
"""
    Base utilities that group the common getters and setters for input and output embeddings.
"""
def get_input_embeddings(self) -> nn.Module:
"""
Returns the model's input embeddings.

Returns:
`nn.Module`: A torch module mapping vocabulary to hidden states.
"""
# 1) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
if hasattr(self, "model") and hasattr(self.model, "embed_tokens"):
return self.model.embed_tokens

        # 2) vanilla decoder-only architectures
elif hasattr(self, "embed_tokens"):
return self.embed_tokens
elif getattr(self, self.base_model_prefix, self) is not self:
base_model = getattr(self, self.base_model_prefix, self)
return base_model.get_input_embeddings()
else:
raise NotImplementedError(
f"`get_input_embeddings` not auto‑handled for {self.__class__.__name__}; "
"please override in the subclass."
)

def set_input_embeddings(self, value: nn.Module):
"""Fallback setter that handles **~70 %** of models in the code‑base.

Order of attempts:
1. `self.model.embed_tokens`
2. `self.embed_tokens`
3. delegate to the *base model* if one exists
4. otherwise raise `NotImplementedError` so subclasses still can (and
should) override for exotic layouts.
"""
# 1) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
if hasattr(self, "model") and hasattr(self.model, "embed_tokens"):
self.model.embed_tokens = value

        # 2) vanilla decoder-only architectures
elif hasattr(self, "embed_tokens"):
self.embed_tokens = value

# 3) recurse once into the registered *base* model (e.g. for encoder/decoder)
elif getattr(self, self.base_model_prefix, self) is not self:
base_model = getattr(self, self.base_model_prefix, self)
base_model.set_input_embeddings(value)

else:
raise NotImplementedError(
f"`set_input_embeddings` not auto‑handled for {self.__class__.__name__}; please override in the subclass."
)

    def get_output_embeddings(self) -> nn.Module:
        """
        Returns the model's output embeddings, defaulting to `lm_head`.

        Returns:
            `nn.Module`: A torch module mapping hidden states to vocabulary.
        """
        return getattr(self, "lm_head", None)

    def set_output_embeddings(self, new_embeddings):
        """
        Sets the model's output embeddings, defaulting to assigning `new_embeddings` to `lm_head`.
        """
        if getattr(self, "lm_head", None) is not None:
            self.lm_head = new_embeddings


class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
r"""
Base class for all models.

@@ -2736,40 +2807,7 @@ def disable_input_require_grads(self):
"""
self._require_grads_hook.remove()

def get_input_embeddings(self) -> nn.Module:
"""
Returns the model's input embeddings.

Returns:
`nn.Module`: A torch module mapping vocabulary to hidden states.
"""
base_model = getattr(self, self.base_model_prefix, self)
if base_model is not self:
return base_model.get_input_embeddings()
else:
raise NotImplementedError

def set_input_embeddings(self, value: nn.Module):
"""
Set model's input embeddings.

Args:
value (`nn.Module`): A module mapping vocabulary to hidden states.
"""
base_model = getattr(self, self.base_model_prefix, self)
if base_model is not self:
base_model.set_input_embeddings(value)
else:
raise NotImplementedError

def get_output_embeddings(self) -> nn.Module:
"""
Returns the model's output embeddings.

Returns:
`nn.Module`: A torch module mapping hidden states to vocabulary.
"""
return None # Overwrite for models with output embeddings

def _init_weights(self, module):
"""
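A minimal sketch (not part of the diff) of how the mixin's fallbacks resolve for the common `self.model.embed_tokens` layout; the toy classes, sizes, and the import path are illustrative assumptions, not taken from this PR:

import torch.nn as nn

from transformers.modeling_utils import EmbeddingAccessMixin  # assumed import path once this branch lands


class ToyTextModel(nn.Module):
    def __init__(self, vocab_size=32, hidden_size=8):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)  # found via the wrapper's `self.model.embed_tokens` branch


class ToyForCausalLM(EmbeddingAccessMixin, nn.Module):
    def __init__(self):
        super().__init__()
        self.model = ToyTextModel()
        self.lm_head = nn.Linear(8, 32, bias=False)  # picked up by the `lm_head` default


model = ToyForCausalLM()
assert model.get_input_embeddings() is model.model.embed_tokens  # fallback 1: `self.model.embed_tokens`
model.set_input_embeddings(nn.Embedding(64, 8))                  # same branch, setter side
assert model.get_output_embeddings() is model.lm_head            # defaults to `lm_head`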
30 changes: 0 additions & 30 deletions src/transformers/models/arcee/modeling_arcee.py
@@ -358,12 +358,6 @@ def __init__(self, config: ArceeConfig):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.embed_tokens

def set_input_embeddings(self, value):
self.embed_tokens = value

@check_model_inputs
@auto_docstring
def forward(
@@ -440,18 +434,6 @@ def __init__(self, config):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.model.embed_tokens

def set_input_embeddings(self, value):
self.model.embed_tokens = value

def get_output_embeddings(self):
return self.lm_head

def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings

def set_decoder(self, decoder):
self.model = decoder

@@ -535,12 +517,6 @@ def __init__(self, config):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.model.embed_tokens

def set_input_embeddings(self, value):
self.model.embed_tokens = value

@can_return_tuple
@auto_docstring
def forward(
@@ -687,12 +663,6 @@ def __init__(self, config):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.model.embed_tokens

def set_input_embeddings(self, value):
self.model.embed_tokens = value

@can_return_tuple
@auto_docstring
def forward(
21 changes: 0 additions & 21 deletions src/transformers/models/aria/modeling_aria.py
@@ -743,12 +743,6 @@ def __init__(self, config: AriaTextConfig):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.embed_tokens

def set_input_embeddings(self, value):
self.embed_tokens = value

@check_model_inputs
@auto_docstring
def forward(
@@ -825,18 +819,6 @@ def __init__(self, config: AriaTextConfig):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.model.embed_tokens

def set_input_embeddings(self, value):
self.model.embed_tokens = value

def get_output_embeddings(self):
return self.lm_head

def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings

def set_decoder(self, decoder):
self.model = decoder

@@ -1149,9 +1131,6 @@ def set_input_embeddings(self, value):
def get_output_embeddings(self) -> nn.Module:
return self.lm_head

def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings

def set_decoder(self, decoder):
self.model.set_decoder(decoder)

3 changes: 0 additions & 3 deletions src/transformers/models/aya_vision/modeling_aya_vision.py
@@ -370,9 +370,6 @@ def set_input_embeddings(self, value):
def get_output_embeddings(self) -> nn.Module:
return self.lm_head

def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings

def set_decoder(self, decoder):
self.model.set_decoder(decoder)

18 changes: 0 additions & 18 deletions src/transformers/models/bamba/modeling_bamba.py
@@ -1084,12 +1084,6 @@ def __init__(self, config: BambaConfig):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.embed_tokens

def set_input_embeddings(self, value):
self.embed_tokens = value

@can_return_tuple
@auto_docstring
def forward(
@@ -1334,18 +1328,6 @@ def __init__(self, config):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.model.embed_tokens

def set_input_embeddings(self, value):
self.model.embed_tokens = value

def get_output_embeddings(self):
return self.lm_head

def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings

def set_decoder(self, decoder):
self.model = decoder

6 changes: 0 additions & 6 deletions src/transformers/models/bamba/modular_bamba.py
@@ -855,12 +855,6 @@ def __init__(self, config: BambaConfig):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.embed_tokens

def set_input_embeddings(self, value):
self.embed_tokens = value

@can_return_tuple
@auto_docstring
def forward(
24 changes: 0 additions & 24 deletions src/transformers/models/bart/modeling_bart.py
@@ -769,12 +769,6 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.embed_tokens

def set_input_embeddings(self, value):
self.embed_tokens = value

def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@@ -935,12 +929,6 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.embed_tokens

def set_input_embeddings(self, value):
self.embed_tokens = value

def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@@ -1380,12 +1368,6 @@ def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
self.register_buffer("final_logits_bias", new_bias)

def get_output_embeddings(self):
return self.lm_head

def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings

def _tie_weights(self):
if self.config.tie_word_embeddings:
self.model._tie_weights()
@@ -1880,12 +1862,6 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.model.decoder.embed_tokens = value

def get_output_embeddings(self):
return self.lm_head

def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings

def set_decoder(self, decoder):
self.model.decoder = decoder

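The per-model `get_output_embeddings`/`set_output_embeddings` overrides removed above fall back to the mixin's `lm_head` default, so the public embedding API should behave as before. A hedged sanity-check sketch (the checkpoint name is only an example; `resize_token_embeddings` already routes through these accessors):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # example checkpoint only
old_vocab = model.get_input_embeddings().weight.shape[0]
model.resize_token_embeddings(old_vocab + 16)  # internally uses get/set_input_embeddings and get/set_output_embeddings
assert model.get_input_embeddings().weight.shape[0] == old_vocab + 16
assert model.get_output_embeddings().weight.shape[0] == old_vocab + 16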