Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
d5909a7
add embedding getter
molbap Dec 2, 2025
a9eb634
modify your own logic
molbap Dec 2, 2025
b520cc7
a common test
molbap Dec 2, 2025
7ce45fe
some adapters are not PreTrainedModel s
molbap Dec 2, 2025
d41e204
few fixes
molbap Dec 2, 2025
0e93a61
implement correct-ish fix?
molbap Dec 2, 2025
de8ff71
fixup
molbap Dec 2, 2025
b2618b3
this is needed likely
molbap Dec 2, 2025
5d61150
woops
molbap Dec 3, 2025
ef55499
solving some cross-imports issues here and there
molbap Dec 3, 2025
44ab4c6
more ximports issues
molbap Dec 3, 2025
fe89c1c
finally
molbap Dec 3, 2025
2920d00
revert changes
molbap Dec 3, 2025
b8ccd0f
fixups
molbap Dec 3, 2025
b5ae5a6
improve message
molbap Dec 3, 2025
d209ff5
add common tests for input_ids first
molbap Dec 3, 2025
79665d4
increase test coverage
molbap Dec 3, 2025
844c707
Merge branch 'main' into fix_enable_grads_again
molbap Dec 4, 2025
fcc84a4
bigger update for GC
molbap Dec 4, 2025
e970fad
copies
molbap Dec 4, 2025
b4f5c15
mlcd is getting on my nerves a bit
molbap Dec 4, 2025
0246a70
ah yes
molbap Dec 4, 2025
81940dd
for BC
molbap Dec 4, 2025
284189a
break a couple modelings
molbap Dec 5, 2025
1079eef
Merge branch 'main' into fix_enable_grads_again
molbap Dec 5, 2025
f479598
simplify with base_model
molbap Dec 5, 2025
73b4f5d
fix copies for torch checkpointing
molbap Dec 5, 2025
0e7086f
simplify this model
molbap Dec 5, 2025
18d44ba
Merge branch 'main' into fix_enable_grads_again
molbap Dec 17, 2025
00cc669
improve messages
molbap Dec 17, 2025
d9d7442
Merge branch 'main' into fix_enable_grads_again
molbap Dec 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
is_librosa_available,
is_mistral_common_available,
is_mlx_available,
is_numba_available,
is_pretty_midi_available,
)

Expand Down
82 changes: 47 additions & 35 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1063,54 +1063,52 @@ def get_input_embeddings(self) -> nn.Module:
`nn.Module`: A torch module mapping vocabulary to hidden states.
"""

# 1) Check if the model has an attribute named 'embed_tokens' (the standard input embedding layer
# for most NLP models), and if so, return it.

name = getattr(self, "_input_embed_layer", "embed_tokens")

# 1) Direct attribute (most NLP models).
if (default_embedding := getattr(self, name, None)) is not None:
return default_embedding
# 2) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
# 2) Nested embeddings (e.g., self.embeddings.patch_embedding for vision/audio models).
if hasattr(self, "embeddings") and hasattr(self.embeddings, name):
return getattr(self.embeddings, name)
# 3) Encoder/decoder wrappers (e.g., `self.model.embed_tokens` or similar overrides).
if hasattr(self, "model") and hasattr(self.model, name):
return getattr(self.model, name)

if hasattr(self, "model") and hasattr(self.model, "embed_tokens"):
return self.model.embed_tokens
if hasattr(self, "base_model"):
base_model = self.base_model
if base_model is not None and base_model is not self:
return base_model.get_input_embeddings()

# 3) vanilla decoder‑only architectures
elif hasattr(self, "embed_tokens"):
return self.embed_tokens
else:
base_model = getattr(self, "base_model_prefix", None)
if base_model is not None:
base_model = getattr(self, base_model, None)
if base_model is not None and base_model is not self:
return base_model.get_input_embeddings()
raise NotImplementedError(
f"`get_input_embeddings` not auto‑handled for {self.__class__.__name__}; "
"please override in the subclass."
)
raise NotImplementedError(
f"`get_input_embeddings` not auto‑handled for {self.__class__.__name__}; please override in the subclass."
)

def set_input_embeddings(self, value: nn.Module):
"""Fallback setter that handles **~70%** of models in the code-base.

Order of attempts:
1. `self.model.embed_tokens`
2. `self.embed_tokens`
3. delegate to the *base model* if one exists
4. otherwise raise `NotImplementedError` so subclasses still can (and
1. `self.<_input_embed_layer>` (direct attribute)
2. `self.embeddings.<_input_embed_layer>` (nested embeddings for vision/audio models)
3. `self.model.<_input_embed_layer>` (encoder/decoder models)
4. delegate to the *base model* if one exists
5. otherwise raise `NotImplementedError` so subclasses still can (and
should) override for exotic layouts.
"""

# 1) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
name = getattr(self, "_input_embed_layer", "embed_tokens")
if hasattr(self, "model") and hasattr(self.model, name):
setattr(self.model, name, value)
# 2) as well as vanilla decoder‑only architectures
elif hasattr(self, name):
# 1) Direct attribute (most NLP models)
if hasattr(self, name):
setattr(self, name, value)
# 3) recurse once into the registered *base* model (e.g. for encoder/decoder)
elif getattr(self, self.base_model_prefix, self) is not self:
base_model = getattr(self, self.base_model_prefix, self)
base_model.set_input_embeddings(value)
# 2) Nested embeddings (e.g., self.embeddings.patch_embedding for vision models)
elif hasattr(self, "embeddings") and hasattr(self.embeddings, name):
setattr(self.embeddings, name, value)
# 3) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
elif hasattr(self, "model") and hasattr(self.model, name):
setattr(self.model, name, value)
# 4) recurse once into the registered *base* model (e.g. for encoder/decoder)
elif hasattr(self, "base_model") and self.base_model is not self:
self.base_model.set_input_embeddings(value)
else:
raise NotImplementedError(
f"`set_input_embeddings` not auto‑handled for {self.__class__.__name__}; please override in the subclass."
Expand Down Expand Up @@ -2043,14 +2041,18 @@ def make_inputs_require_grads(module, input, output):

hooks = []
seen_modules = set()
found_embeddings = False

for module in self.modules():
if not (isinstance(module, PreTrainedModel) and hasattr(module, "get_input_embeddings")):
continue

input_embeddings = module.get_input_embeddings()
try:
input_embeddings = module.get_input_embeddings()
except NotImplementedError:
continue
Comment on lines +2050 to +2053

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no simple way around this unfortunately

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I think with the warning below it is more explicit


if input_embeddings is None:
if input_embeddings is None or not hasattr(input_embeddings, "register_forward_hook"):
continue

embedding_id = id(input_embeddings)
Expand All @@ -2059,11 +2061,18 @@ def make_inputs_require_grads(module, input, output):

seen_modules.add(embedding_id)
hooks.append(input_embeddings.register_forward_hook(make_inputs_require_grads))
found_embeddings = True

self._require_grads_hooks = hooks
if hooks:
# for BC
self._require_grads_hook = hooks[0]
if not found_embeddings:
logger.warning_once(
f"{self.__class__.__name__} does not expose input embeddings. Gradients cannot flow back to the token "
"embeddings when using adapters or gradient checkpointing. Override `get_input_embeddings` to fully "
"support those features, or set `_input_embed_layer` to the attribute name that holds the embeddings."
)

def disable_input_require_grads(self):
"""
Expand Down Expand Up @@ -3000,7 +3009,10 @@ def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
"Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
)

if getattr(self, "_hf_peft_config_loaded", False):
needs_embedding_grads = self.main_input_name == "input_ids"
# we use that also to detect whether or not we have to raise if embeddings are missing (the submodel might not have embeddings at all)
enable_input_grads = needs_embedding_grads or getattr(self, "_hf_peft_config_loaded", False)
if enable_input_grads:
Comment on lines +3012 to +3015

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, for my understanding, why do we always need to enable grads when doing GC training with text models?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we don't always, but we do with reentrant checkpointing. IIUC it's not to actually use these gradients; it's that torch.utils.checkpoint needs at least one input and one output to actually have gradients, otherwise the checkpointed part will not have a gradient.

# When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
# we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
# When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
Expand Down
10 changes: 4 additions & 6 deletions src/transformers/models/align/modeling_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,9 +781,9 @@ def forward(
all_hidden_states = all_hidden_states + (hidden_states,)

layer_outputs = layer_module(
hidden_states=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
hidden_states,
attention_mask,
output_attentions,
**kwargs,
)

Expand Down Expand Up @@ -976,6 +976,7 @@ class AlignVisionModel(AlignPreTrainedModel):
main_input_name = "pixel_values"
input_modalities = ("image",)
supports_gradient_checkpointing = False
_input_embed_layer = "convolution"
_no_split_modules = ["AlignVisionBlock"]

def __init__(self, config: AlignVisionConfig):
Expand All @@ -995,9 +996,6 @@ def __init__(self, config: AlignVisionConfig):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.convolution

@can_return_tuple
@auto_docstring
def forward(
Expand Down
6 changes: 3 additions & 3 deletions src/transformers/models/altclip/modeling_altclip.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,9 +393,9 @@ def forward(
all_hidden_states = all_hidden_states + (hidden_states,)

layer_outputs = layer_module(
hidden_states=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
hidden_states,
attention_mask,
output_attentions,
**kwargs,
)

Expand Down
6 changes: 3 additions & 3 deletions src/transformers/models/chinese_clip/modeling_chinese_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,9 +638,9 @@ def forward(
all_hidden_states = all_hidden_states + (hidden_states,)

layer_outputs = layer_module(
hidden_states=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
hidden_states,
attention_mask,
output_attentions,
**kwargs,
)

Expand Down
6 changes: 3 additions & 3 deletions src/transformers/models/clap/modeling_clap.py
Original file line number Diff line number Diff line change
Expand Up @@ -1266,9 +1266,9 @@ def forward(
all_hidden_states = all_hidden_states + (hidden_states,)

layer_outputs = layer_module(
hidden_states=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
hidden_states,
attention_mask,
output_attentions,
**kwargs,
)

Expand Down
9 changes: 4 additions & 5 deletions src/transformers/models/internvl/modeling_internvl.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,10 +209,9 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
)

embeddings = self.projection(pixel_values.to(self.projection.weight.dtype))
patch_height, patch_width = embeddings.shape[2], embeddings.shape[3]
embeddings = embeddings.flatten(2).transpose(1, 2)

return embeddings, (patch_height, patch_width)
return embeddings

Comment thread
molbap marked this conversation as resolved.

# Based on timm implementation, which can be found here:
Expand Down Expand Up @@ -291,7 +290,7 @@ def forward(
bool_masked_pos: Optional[torch.BoolTensor] = None,
) -> torch.Tensor:
_, _, height, width = pixel_values.shape
embeddings, (patch_height, patch_width) = self.patch_embeddings(pixel_values)
embeddings = self.patch_embeddings(pixel_values)
batch_size, seq_len, _ = embeddings.size()

if bool_masked_pos is not None:
Expand All @@ -308,7 +307,7 @@ def forward(

embeddings = self.dropout(embeddings)

return embeddings, (patch_height, patch_width)
return embeddings


class InternVLVisionMLP(nn.Module):
Expand Down Expand Up @@ -455,7 +454,7 @@ def forward(
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
"""
embedding_output, _ = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)

encoder_outputs = self.encoder(embedding_output)
sequence_output = encoder_outputs[0]
Expand Down
14 changes: 5 additions & 9 deletions src/transformers/models/internvl/modular_internvl.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_int
from ...utils.generic import check_model_inputs
from ..clip.modeling_clip import CLIPMLP
from ..janus.modeling_janus import JanusVisionAttention
Expand All @@ -44,9 +44,6 @@
from .configuration_internvl import InternVLConfig, InternVLVisionConfig


logger = logging.get_logger(__name__)


def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
Expand Down Expand Up @@ -177,10 +174,9 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
)

embeddings = self.projection(pixel_values.to(self.projection.weight.dtype))
patch_height, patch_width = embeddings.shape[2], embeddings.shape[3]
embeddings = embeddings.flatten(2).transpose(1, 2)

return embeddings, (patch_height, patch_width)
return embeddings


# Based on timm implementation, which can be found here:
Expand Down Expand Up @@ -259,7 +255,7 @@ def forward(
bool_masked_pos: Optional[torch.BoolTensor] = None,
) -> torch.Tensor:
_, _, height, width = pixel_values.shape
embeddings, (patch_height, patch_width) = self.patch_embeddings(pixel_values)
embeddings = self.patch_embeddings(pixel_values)
batch_size, seq_len, _ = embeddings.size()

if bool_masked_pos is not None:
Expand All @@ -276,7 +272,7 @@ def forward(

embeddings = self.dropout(embeddings)

return embeddings, (patch_height, patch_width)
return embeddings


class InternVLVisionMLP(CLIPMLP):
Expand Down Expand Up @@ -412,7 +408,7 @@ def forward(
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
"""
embedding_output, _ = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)

encoder_outputs = self.encoder(embedding_output)
sequence_output = encoder_outputs[0]
Expand Down
6 changes: 3 additions & 3 deletions src/transformers/models/layoutlm/modeling_layoutlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,9 +337,9 @@ def forward(
all_hidden_states = all_hidden_states + (hidden_states,)

layer_outputs = layer_module(
hidden_states=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
hidden_states,
attention_mask,
output_attentions,
**kwargs,
)

Expand Down
18 changes: 18 additions & 0 deletions src/transformers/models/layoutlmv3/modeling_layoutlmv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,12 @@ def __init__(self, config):

self.post_init()

def get_input_embeddings(self):
return self.layoutlmv3.get_input_embeddings()

def set_input_embeddings(self, value):
self.layoutlmv3.set_input_embeddings(value)

@auto_docstring
def forward(
self,
Expand Down Expand Up @@ -984,6 +990,12 @@ def __init__(self, config):

self.post_init()

def get_input_embeddings(self):
return self.layoutlmv3.get_input_embeddings()

def set_input_embeddings(self, value):
self.layoutlmv3.set_input_embeddings(value)

@auto_docstring
def forward(
self,
Expand Down Expand Up @@ -1104,6 +1116,12 @@ def __init__(self, config):

self.post_init()

def get_input_embeddings(self):
return self.layoutlmv3.get_input_embeddings()

def set_input_embeddings(self, value):
self.layoutlmv3.set_input_embeddings(value)

@auto_docstring
def forward(
self,
Expand Down
Loading