Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
d5909a7
add embedding getter
molbap Dec 2, 2025
a9eb634
modify your own logic
molbap Dec 2, 2025
b520cc7
a common test
molbap Dec 2, 2025
7ce45fe
some adapters are not PreTrainedModel s
molbap Dec 2, 2025
d41e204
few fixes
molbap Dec 2, 2025
0e93a61
implement correct-ish fix?
molbap Dec 2, 2025
de8ff71
fixup
molbap Dec 2, 2025
b2618b3
this is needed likely
molbap Dec 2, 2025
5d61150
woops
molbap Dec 3, 2025
ef55499
solving some cross-imports issues here and there
molbap Dec 3, 2025
44ab4c6
more ximports issues
molbap Dec 3, 2025
fe89c1c
finally
molbap Dec 3, 2025
2920d00
revert changes
molbap Dec 3, 2025
b8ccd0f
fixups
molbap Dec 3, 2025
b5ae5a6
improve message
molbap Dec 3, 2025
d209ff5
add common tests for input_ids first
molbap Dec 3, 2025
79665d4
increase test coverage
molbap Dec 3, 2025
844c707
Merge branch 'main' into fix_enable_grads_again
molbap Dec 4, 2025
fcc84a4
bigger update for GC
molbap Dec 4, 2025
e970fad
copies
molbap Dec 4, 2025
b4f5c15
mlcd is getting on my nerves a bit
molbap Dec 4, 2025
0246a70
ah yes
molbap Dec 4, 2025
81940dd
for BC
molbap Dec 4, 2025
284189a
break a couple modelings
molbap Dec 5, 2025
1079eef
Merge branch 'main' into fix_enable_grads_again
molbap Dec 5, 2025
f479598
simplify with base_model
molbap Dec 5, 2025
73b4f5d
fix copies for torch checkpointing
molbap Dec 5, 2025
0e7086f
simplify this model
molbap Dec 5, 2025
18d44ba
Merge branch 'main' into fix_enable_grads_again
molbap Dec 17, 2025
00cc669
improve messages
molbap Dec 17, 2025
d9d7442
Merge branch 'main' into fix_enable_grads_again
molbap Dec 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
is_librosa_available,
is_mistral_common_available,
is_mlx_available,
is_numba_available,
is_pretty_midi_available,
)

Expand Down
65 changes: 34 additions & 31 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,51 +969,51 @@ def get_input_embeddings(self) -> nn.Module:
`nn.Module`: A torch module mapping vocabulary to hidden states.
"""

# 1) Check if the model has an attribute named 'embed_tokens' (the standard input embedding layer
# for most NLP models), and if so, return it.

name = getattr(self, "_input_embed_layer", "embed_tokens")

# 1) Direct attribute (most NLP models).
if (default_embedding := getattr(self, name, None)) is not None:
return default_embedding
# 2) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
# 2) Nested embeddings (e.g., self.embeddings.patch_embedding for vision/audio models).
if hasattr(self, "embeddings") and hasattr(self.embeddings, name):
return getattr(self.embeddings, name)
# 3) Encoder/decoder wrappers (e.g., `self.model.embed_tokens` or similar overrides).
if hasattr(self, "model") and hasattr(self.model, name):
return getattr(self.model, name)

if hasattr(self, "model") and hasattr(self.model, "embed_tokens"):
return self.model.embed_tokens
base_model = getattr(self, "base_model_prefix", None)
if base_model is not None:
base_model = getattr(self, base_model, None)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: self.base_model property has the same functionality

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

true!

if base_model is not None and base_model is not self:
return base_model.get_input_embeddings()

# 3) vanilla decoder‑only architectures
elif hasattr(self, "embed_tokens"):
return self.embed_tokens
else:
base_model = getattr(self, "base_model_prefix", None)
if base_model is not None:
base_model = getattr(self, base_model, None)
if base_model is not None and base_model is not self:
return base_model.get_input_embeddings()
raise NotImplementedError(
f"`get_input_embeddings` not auto‑handled for {self.__class__.__name__}; "
"please override in the subclass."
)
raise NotImplementedError(
f"`get_input_embeddings` not auto‑handled for {self.__class__.__name__}; please override in the subclass."
)

def set_input_embeddings(self, value: nn.Module):
"""Fallback setter that handles **~70%** of models in the code-base.

Order of attempts:
1. `self.model.embed_tokens`
2. `self.embed_tokens`
3. delegate to the *base model* if one exists
4. otherwise raise `NotImplementedError` so subclasses still can (and
1. `self.<_input_embed_layer>` (direct attribute)
2. `self.embeddings.<_input_embed_layer>` (nested embeddings for vision/audio models)
3. `self.model.<_input_embed_layer>` (encoder/decoder models)
4. delegate to the *base model* if one exists
5. otherwise raise `NotImplementedError` so subclasses still can (and
should) override for exotic layouts.
"""

# 1) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
name = getattr(self, "_input_embed_layer", "embed_tokens")
if hasattr(self, "model") and hasattr(self.model, name):
setattr(self.model, name, value)
# 2) as well as vanilla decoder‑only architectures
elif hasattr(self, name):
# 1) Direct attribute (most NLP models)
if hasattr(self, name):
setattr(self, name, value)
# 3) recurse once into the registered *base* model (e.g. for encoder/decoder)
# 2) Nested embeddings (e.g., self.embeddings.patch_embedding for vision models)
elif hasattr(self, "embeddings") and hasattr(self.embeddings, name):
setattr(self.embeddings, name, value)
# 3) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
elif hasattr(self, "model") and hasattr(self.model, name):
setattr(self.model, name, value)
# 4) recurse once into the registered *base* model (e.g. for encoder/decoder)
elif getattr(self, self.base_model_prefix, self) is not self:
base_model = getattr(self, self.base_model_prefix, self)
base_model.set_input_embeddings(value)
Expand Down Expand Up @@ -1983,9 +1983,12 @@ def make_inputs_require_grads(module, input, output):
if not (isinstance(module, PreTrainedModel) and hasattr(module, "get_input_embeddings")):
continue

input_embeddings = module.get_input_embeddings()
try:
input_embeddings = module.get_input_embeddings()
except NotImplementedError:
continue
Comment on lines +2050 to +2053

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no simple way around this unfortunately

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oke, I think with the warning below, it is more explicit


if input_embeddings is None:
if input_embeddings is None or not hasattr(input_embeddings, "register_forward_hook"):
continue

embedding_id = id(input_embeddings)
Expand Down
4 changes: 1 addition & 3 deletions src/transformers/models/align/modeling_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -976,6 +976,7 @@ class AlignVisionModel(AlignPreTrainedModel):
main_input_name = "pixel_values"
input_modalities = ("image",)
supports_gradient_checkpointing = False
_input_embed_layer = "convolution"

def __init__(self, config: AlignVisionConfig):
super().__init__(config)
Expand All @@ -994,9 +995,6 @@ def __init__(self, config: AlignVisionConfig):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.convolution

@can_return_tuple
@auto_docstring
def forward(
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/fast_vlm/modeling_fast_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def forward(self, image_features):
class FastVlmPreTrainedModel(PreTrainedModel):
config: FastVlmConfig
base_model_prefix = "model"
input_modalities = ["image", "text"]
input_modalities = ("image", "text")
supports_gradient_checkpointing = True
_skip_keys_device_placement = "past_key_values"

Expand Down
18 changes: 18 additions & 0 deletions src/transformers/models/layoutlmv3/modeling_layoutlmv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,6 +883,12 @@ def __init__(self, config):

self.post_init()

def get_input_embeddings(self):
return self.layoutlmv3.get_input_embeddings()

def set_input_embeddings(self, value):
self.layoutlmv3.set_input_embeddings(value)

@auto_docstring
def forward(
self,
Expand Down Expand Up @@ -982,6 +988,12 @@ def __init__(self, config):

self.post_init()

def get_input_embeddings(self):
return self.layoutlmv3.get_input_embeddings()

def set_input_embeddings(self, value):
self.layoutlmv3.set_input_embeddings(value)

@auto_docstring
def forward(
self,
Expand Down Expand Up @@ -1101,6 +1113,12 @@ def __init__(self, config):

self.post_init()

def get_input_embeddings(self):
return self.layoutlmv3.get_input_embeddings()

def set_input_embeddings(self, value):
self.layoutlmv3.set_input_embeddings(value)

@auto_docstring
def forward(
self,
Expand Down
12 changes: 11 additions & 1 deletion src/transformers/models/poolformer/modeling_poolformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,11 @@ def __init__(self, config):
self.post_init()

def get_input_embeddings(self):
return self.embeddings.patch_embeddings
# Input embeddings correspond to the very first patch-embedding stage.
return self.encoder.patch_embeddings[0]

def set_input_embeddings(self, value):
self.encoder.patch_embeddings[0] = value

@auto_docstring
def forward(
Expand Down Expand Up @@ -332,6 +336,12 @@ def __init__(self, config):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.poolformer.get_input_embeddings()

def set_input_embeddings(self, value):
self.poolformer.set_input_embeddings(value)

@auto_docstring
def forward(
self,
Expand Down
19 changes: 17 additions & 2 deletions src/transformers/models/siglip/modeling_siglip.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,9 +502,11 @@ def forward(
return BaseModelOutput(last_hidden_state=hidden_states)


class SiglipTextTransformer(nn.Module):
class SiglipTextTransformer(SiglipPreTrainedModel):
_input_embed_layer = "token_embedding"

def __init__(self, config: SiglipTextConfig):
super().__init__()
super().__init__(config)
self.config = config
embed_dim = config.hidden_size
self.embeddings = SiglipTextEmbeddings(config)
Expand Down Expand Up @@ -614,6 +616,7 @@ def forward(


class SiglipVisionTransformer(SiglipPreTrainedModel):
_input_embed_layer = "patch_embedding"
_can_record_outputs = {
"hidden_states": SiglipEncoderLayer,
"attentions": SiglipAttention,
Expand Down Expand Up @@ -774,6 +777,12 @@ def __init__(self, config: SiglipConfig):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self) -> nn.Module:
return self.text_model.embeddings.token_embedding

def set_input_embeddings(self, value: nn.Module):
self.text_model.embeddings.token_embedding = value

@filter_out_non_signature_kwargs()
@auto_docstring
def get_text_features(
Expand Down Expand Up @@ -969,6 +978,12 @@ def __init__(self, config: SiglipConfig) -> None:
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.patch_embedding

def set_input_embeddings(self, value: nn.Module):
self.vision_model.embeddings.patch_embedding = value

@check_model_inputs
@auto_docstring
def forward(
Expand Down
19 changes: 17 additions & 2 deletions src/transformers/models/siglip2/modeling_siglip2.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,7 @@ def forward(


class Siglip2VisionTransformer(Siglip2PreTrainedModel):
_input_embed_layer = "patch_embedding"
_can_record_outputs = {
"hidden_states": Siglip2EncoderLayer,
"attentions": Siglip2Attention,
Expand Down Expand Up @@ -588,9 +589,11 @@ def forward(
return embeddings


class Siglip2TextTransformer(nn.Module):
class Siglip2TextTransformer(Siglip2PreTrainedModel):
_input_embed_layer = "token_embedding"

def __init__(self, config: Siglip2TextConfig):
super().__init__()
super().__init__(config)
self.config = config
embed_dim = config.hidden_size
self.embeddings = Siglip2TextEmbeddings(config)
Expand Down Expand Up @@ -831,6 +834,12 @@ def __init__(self, config: Siglip2Config):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self) -> nn.Module:
return self.text_model.embeddings.token_embedding

def set_input_embeddings(self, value: nn.Module):
self.text_model.embeddings.token_embedding = value

@filter_out_non_signature_kwargs()
@auto_docstring
def get_text_features(
Expand Down Expand Up @@ -1048,6 +1057,12 @@ def __init__(self, config: Siglip2Config) -> None:
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.patch_embedding

def set_input_embeddings(self, value: nn.Module):
self.vision_model.embeddings.patch_embedding = value

@check_model_inputs
@auto_docstring
def forward(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -913,6 +913,7 @@ class SwitchTransformersModel(SwitchTransformersPreTrainedModel):
"encoder.embed_tokens.weight": "shared.weight",
"decoder.embed_tokens.weight": "shared.weight",
}
_input_embed_layer = "shared"

def __init__(self, config: SwitchTransformersConfig):
super().__init__(config)
Expand All @@ -932,9 +933,6 @@ def __init__(self, config: SwitchTransformersConfig):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.shared

def set_input_embeddings(self, new_embeddings):
self.shared = new_embeddings
self.encoder.set_input_embeddings(new_embeddings)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,7 @@ class SwitchTransformersModel(SwitchTransformersPreTrainedModel):
"encoder.embed_tokens.weight": "shared.weight",
"decoder.embed_tokens.weight": "shared.weight",
}
_input_embed_layer = "shared"

def __init__(self, config: SwitchTransformersConfig):
super().__init__(config)
Expand All @@ -688,9 +689,6 @@ def __init__(self, config: SwitchTransformersConfig):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.shared

def set_input_embeddings(self, new_embeddings):
self.shared = new_embeddings
self.encoder.set_input_embeddings(new_embeddings)
Expand Down
14 changes: 7 additions & 7 deletions src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,13 @@ def _timm_model_supports_gradient_checkpointing(self):
def _set_gradient_checkpointing(self, enable: bool = True, *args, **kwargs):
self.timm_model.set_grad_checkpointing(enable)

def get_input_embeddings(self):
# TIMM backbones operate directly on images and do not expose token embeddings.
return None

def set_input_embeddings(self, value):
raise NotImplementedError("TimmWrapper models do not own token embeddings and cannot set them.")


class TimmWrapperModel(TimmWrapperPreTrainedModel):
"""
Expand All @@ -150,13 +157,6 @@ def __init__(self, config: TimmWrapperConfig):
self.timm_model = _create_timm_model_with_error_handling(config, num_classes=0, **extra_init_kwargs)
self.post_init()

def get_input_embeddings(self):
# Vision backbones from timm do not expose token embeddings, so there is nothing to return.
return None

def set_input_embeddings(self, value):
raise NotImplementedError("TimmWrapperModel does not own token embeddings and cannot set them.")

@auto_docstring
def forward(
self,
Expand Down
8 changes: 8 additions & 0 deletions src/transformers/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@
is_mistral_common_available,
is_natten_available,
is_nltk_available,
is_numba_available,
is_onnx_available,
is_openai_available,
is_optimum_available,
Expand Down Expand Up @@ -1386,6 +1387,13 @@ def require_pyctcdecode(test_case):
return unittest.skipUnless(is_pyctcdecode_available(), "test requires pyctcdecode")(test_case)


def require_numba(test_case):
"""
Decorator marking a test that requires numba
"""
return unittest.skipUnless(is_numba_available(), "test requires numba")(test_case)


def require_librosa(test_case):
"""
Decorator marking a test that requires librosa
Expand Down
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@
is_ninja_available,
is_nltk_available,
is_num2words_available,
is_numba_available,
is_onnx_available,
is_openai_available,
is_optimum_available,
Expand Down
Loading