Merged
Changes from all commits
Commits
39 commits
d107f9e
simplify common get/set
molbap Jul 10, 2025
98e5467
remove some noise
molbap Jul 10, 2025
cc8b91c
change some 5 years old modeling utils
molbap Jul 10, 2025
077af3d
update examples
molbap Jul 10, 2025
7d9b16c
fix copies
molbap Jul 10, 2025
32c7c00
revert some changes
molbap Jul 10, 2025
33e0f06
fixes, gah
molbap Jul 10, 2025
91537aa
format
molbap Jul 10, 2025
57aac34
move to Mixin
molbap Jul 11, 2025
39322e6
remove smolvlm specific require grad
molbap Jul 11, 2025
951d269
Merge branch 'main' into draft_VLM_apis
molbap Jul 11, 2025
d2dd129
skip
molbap Jul 11, 2025
92e2296
force defaults
molbap Jul 17, 2025
00b676f
Merge branch 'main' into draft_VLM_apis
molbap Jul 17, 2025
c5a3676
remodularise some stuff
molbap Jul 17, 2025
1265140
remodularise more stuff
molbap Jul 17, 2025
b776481
add safety for audio models
molbap Jul 17, 2025
d9693a3
style
molbap Jul 17, 2025
4ab8399
have a correct fallback, you daft donkey
molbap Jul 17, 2025
2bee084
remove this argh
molbap Jul 17, 2025
f0fe775
change heuristic for audio models
molbap Jul 18, 2025
773a9ad
Merge branch 'main' into draft_VLM_apis
molbap Jul 18, 2025
a233db2
fixup
molbap Jul 18, 2025
d0fbe20
revert
molbap Jul 18, 2025
5436239
this works
molbap Jul 18, 2025
41221d8
revert again
molbap Jul 18, 2025
8556767
🧠
molbap Jul 18, 2025
21820d7
aaah ESM has two modelings aaah
molbap Jul 18, 2025
7b614f8
Merge branch 'main' into draft_VLM_apis
molbap Jul 18, 2025
429e119
Merge branch 'main' into draft_VLM_apis
molbap Jul 21, 2025
6e1fbdd
add informative but short comment
molbap Jul 21, 2025
807eaf0
Merge branch 'draft_VLM_apis' of github.com:huggingface/transformers …
molbap Jul 21, 2025
3edfa42
add `input_embed_layer` mixin attribute
molbap Jul 21, 2025
b1af42a
style
molbap Jul 21, 2025
8455e0f
Merge branch 'main' into draft_VLM_apis
molbap Jul 21, 2025
153a0b0
walrus has low precedence
molbap Jul 21, 2025
31404b4
modular fix
molbap Jul 21, 2025
822dd44
this was breaking parser
molbap Jul 21, 2025
395d3f9
Merge branch 'main' into draft_VLM_apis
molbap Jul 21, 2025
12 changes: 0 additions & 12 deletions examples/modular-transformers/modeling_my_new_model2.py
@@ -333,12 +333,6 @@ def __init__(self, config: MyNewModel2Config):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.embed_tokens

def set_input_embeddings(self, value):
self.embed_tokens = value

@check_model_inputs
@auto_docstring
def forward(
@@ -433,12 +427,6 @@ def __init__(self, config):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.model.embed_tokens

def set_input_embeddings(self, value):
self.model.embed_tokens = value

@can_return_tuple
@auto_docstring
def forward(
6 changes: 0 additions & 6 deletions examples/modular-transformers/modeling_new_task_model.py
@@ -389,12 +389,6 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)

def get_output_embeddings(self):
return self.lm_head

def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings

def set_decoder(self, decoder):
self.model.set_decoder(decoder)

6 changes: 0 additions & 6 deletions examples/modular-transformers/modeling_super.py
@@ -332,12 +332,6 @@ def __init__(self, config: SuperConfig):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.embed_tokens

def set_input_embeddings(self, value):
self.embed_tokens = value

@check_model_inputs
@auto_docstring
def forward(
224 changes: 188 additions & 36 deletions src/transformers/modeling_utils.py
@@ -1902,7 +1902,97 @@ def floating_point_ops(
return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)


class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
class EmbeddingAccessMixin:
"""
Base utilities to regroup getters and setters for embeddings.
Introduces the `_input_embed_layer` attribute, which indicates
where the input embeddings come from and where they
should be set.
"""

_input_embed_layer = "embed_tokens" # default layer that holds input embeddings.

def get_input_embeddings(self) -> nn.Module:
"""
Returns the model's input embeddings.

Returns:
`nn.Module`: A torch module mapping vocabulary to hidden states.
"""

# 1) Check the attribute named by `_input_embed_layer` (defaults to `embed_tokens`, the standard
# input embedding layer for most NLP models), and if so, return it.

name = getattr(self, "_input_embed_layer", "embed_tokens")

if (default_embedding := getattr(self, name, None)) is not None:
return default_embedding
# 2) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`

if hasattr(self, "model") and hasattr(self.model, "embed_tokens"):
return self.model.embed_tokens

# 3) vanilla decoder‑only architectures
elif hasattr(self, "embed_tokens"):
return self.embed_tokens
else:
base_model = getattr(self, "base_model_prefix", None)
if base_model is not None:
base_model = getattr(self, base_model, None)
if base_model is not None and base_model is not self:
return base_model.get_input_embeddings()
raise NotImplementedError(
f"`get_input_embeddings` not auto‑handled for {self.__class__.__name__}; "
"please override in the subclass."
)

def set_input_embeddings(self, value: nn.Module):
"""Fallback setter that handles **~70 %** of models in the code‑base.

Order of attempts:
1. `self.model.embed_tokens`
2. `self.embed_tokens`
3. delegate to the *base model* if one exists
4. otherwise raise `NotImplementedError` so subclasses still can (and
should) override for exotic layouts.
"""

# 1) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
name = getattr(self, "_input_embed_layer", "embed_tokens")
if hasattr(self, "model") and hasattr(self.model, name):
setattr(self.model, name, value)
# 2) vanilla decoder-only architectures that hold the embedding layer directly
elif hasattr(self, name):
setattr(self, name, value)
# 3) recurse once into the registered *base* model (e.g. for encoder/decoder)
elif getattr(self, self.base_model_prefix, self) is not self:
base_model = getattr(self, self.base_model_prefix, self)
base_model.set_input_embeddings(value)
else:
raise NotImplementedError(
f"`set_input_embeddings` not auto‑handled for {self.__class__.__name__}; please override in the subclass."
)

def get_output_embeddings(self):
if not hasattr(self, "lm_head"):
return None
try:
# Speech / vision backbones raise here, so we return None.
# Legit use of get_input_embs?
self.get_input_embeddings()
except NotImplementedError:
return None
return self.lm_head

def set_output_embeddings(self, new_embeddings):
"""
Sets the model's output embeddings; by default, assigns `new_embeddings` to `lm_head`.
"""
if getattr(self, "lm_head", None) is not None:
self.lm_head = new_embeddings


class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
r"""
Base class for all models.

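Note: below is a minimal sketch (not part of this diff) of how the new `EmbeddingAccessMixin` resolves embeddings. `ToyDecoder` and `ToyForCausalLM` are hypothetical stand-ins, and the import path assumes the mixin lands in `modeling_utils` as shown above.

import torch.nn as nn
from transformers.modeling_utils import EmbeddingAccessMixin  # assumed import path, per this diff

class ToyDecoder(EmbeddingAccessMixin, nn.Module):
    def __init__(self, vocab_size=100, hidden_size=16):
        super().__init__()
        # Matches the default `_input_embed_layer = "embed_tokens"`, so case 1) applies.
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)

class ToyForCausalLM(EmbeddingAccessMixin, nn.Module):
    def __init__(self):
        super().__init__()
        self.model = ToyDecoder()                      # case 2): found via `self.model.embed_tokens`
        self.lm_head = nn.Linear(16, 100, bias=False)  # picked up by get/set_output_embeddings

model = ToyForCausalLM()
assert model.get_input_embeddings() is model.model.embed_tokens
model.set_input_embeddings(nn.Embedding(200, 16))      # rebinds `model.model.embed_tokens`
assert model.get_output_embeddings() is model.lm_head

Models whose embeddings live under a different attribute only need to override `_input_embed_layer`; the fallback chain (named attribute, then `self.model`, then the registered base model) covers the rest.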
@@ -2004,6 +2094,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
_supports_attention_backend = False
_can_record_outputs = None

# This attribute sets the default parameter to be

@property
@torch._dynamo.allow_in_graph
def can_record_outputs(self) -> dict[str, OutputRecorder]:
@@ -2267,6 +2359,101 @@ def _from_config(cls, config, **kwargs):

return model

@classmethod
def _check_attn_implementation(cls, attn_implementation: Union[str, dict]) -> Union[str, dict]:
"""
Checks that the requested attention implementation exists, and tries to fetch the kernel from the Hub
if `attn_implementation` matches the HF kernels pattern (`repo_id:kernel_name`).
"""
if isinstance(attn_implementation, str) and re.match(r"^[^/:]+/[^/:]+:[^/:]+$", attn_implementation):
if not is_kernels_available():
raise ValueError("kernels is not installed. Please install it with `pip install kernels`.")

# Extract repo_id and kernel_name from the string
repo_id, kernel_name = attn_implementation.split(":")
kernel_name = kernel_name.strip()
repo_id = repo_id.strip()

try:
kernel = get_kernel(repo_id)
ALL_ATTENTION_FUNCTIONS.register(f"kernel_{repo_id.replace('/', '_')}", getattr(kernel, kernel_name))
attn_implementation = f"kernel_{repo_id.replace('/', '_')}"
except FileNotFoundError as e:
logger.warning(
f"Could not find a kernel repository '{repo_id}' compatible with your devicein the hub: {e}. Using eager attention implementation instead."
)
attn_implementation = None # try to dispatch SDPA and fallback eager if not available
except AttributeError:
raise ValueError(
"The kernel function name or class specified in the `attn_implementation` argument is not valid. "
"Please check the documentation for the correct format, "
"and check that the kernel exports the class and the function correctly."
)
if (
not isinstance(attn_implementation, dict)
and attn_implementation not in ["eager", None] + ALL_ATTENTION_FUNCTIONS.valid_keys()
):
message = f'Specified `attn_implementation="{attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
# check `supports_flash_attn_2` for BC with custom code. TODO: remove after a few releases
if cls._supports_flash_attn or getattr(cls, "_supports_flash_attn_2", False):
message += (
', `"attn_implementation=flash_attention_3"` (implementation using flash attention 3)'
', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
)
if cls._supports_sdpa:
message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)'
if cls._supports_flex_attn:
message += ', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)'
raise ValueError(message + ".")

return attn_implementation

def set_attention_implementation(self, attn_implementation: Union[str, dict]):
"""
Checks and dispatches to the requested attention implementation.
"""
requested_attn_implementation = self._check_attn_implementation(attn_implementation)

# Composite models consisting of several PreTrainedModels can specify the attention implementation as a dict where
# keys are sub-config names. Most users, however, pass a single `str`, which should then be dispatched to all sub-models.
# See https://github.com/huggingface/transformers/pull/32238
for key in self.config.sub_configs.keys():
sub_config = getattr(self.config, key)
curr_attn_implementation = (
requested_attn_implementation
if not isinstance(requested_attn_implementation, dict)
else requested_attn_implementation.get(key, None)
)
# For models with a backbone, the sub-config might not be initialized. Set the requested attention implementation
# if the config has no attention pre-set and the requested attention is not `None` (i.e. not the default attention)
if (
sub_config is not None
and sub_config._attn_implementation_internal is None
and curr_attn_implementation is not None
):
sub_config._attn_implementation_internal = curr_attn_implementation

if requested_attn_implementation == "flash_attention_3" and self._flash_attn_3_can_dispatch():
self.config._attn_implementation = "flash_attention_3"
if requested_attn_implementation == "flash_attention_2" and self._flash_attn_2_can_dispatch():
self.config._attn_implementation = "flash_attention_2"
elif requested_attn_implementation == "flex_attention" and self._flex_attn_can_dispatch():
self.config._attn_implementation = "flex_attention"
elif (
requested_attn_implementation in [None, "sdpa"]
and not is_torch_xla_available()
and self._sdpa_can_dispatch(hard_check_only=requested_attn_implementation is not None)
):
self.config._attn_implementation = "sdpa"
elif requested_attn_implementation in ALL_ATTENTION_FUNCTIONS.valid_keys():
self.config._attn_implementation = requested_attn_implementation
elif isinstance(requested_attn_implementation, dict):
self.config._attn_implementation = requested_attn_implementation.get("", None)
else:
self.config._attn_implementation = "eager"

self.config._attn_implementation_autoset = True

@classmethod
def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype:
"""
@@ -2769,41 +2956,6 @@ def disable_input_require_grads(self):
"""
self._require_grads_hook.remove()

def get_input_embeddings(self) -> nn.Module:
"""
Returns the model's input embeddings.

Returns:
`nn.Module`: A torch module mapping vocabulary to hidden states.
"""
base_model = getattr(self, self.base_model_prefix, self)
if base_model is not self:
return base_model.get_input_embeddings()
else:
raise NotImplementedError

def set_input_embeddings(self, value: nn.Module):
"""
Set model's input embeddings.

Args:
value (`nn.Module`): A module mapping vocabulary to hidden states.
"""
base_model = getattr(self, self.base_model_prefix, self)
if base_model is not self:
base_model.set_input_embeddings(value)
else:
raise NotImplementedError

def get_output_embeddings(self) -> nn.Module:
"""
Returns the model's output embeddings.

Returns:
`nn.Module`: A torch module mapping hidden states to vocabulary.
"""
return None # Overwrite for models with output embeddings

def _init_weights(self, module):
"""
Initialize the weights. This method should be overridden by derived class and is
30 changes: 0 additions & 30 deletions src/transformers/models/arcee/modeling_arcee.py
@@ -356,12 +356,6 @@ def __init__(self, config: ArceeConfig):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.embed_tokens

def set_input_embeddings(self, value):
self.embed_tokens = value

@check_model_inputs
@auto_docstring
def forward(
@@ -438,18 +432,6 @@ def __init__(self, config):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.model.embed_tokens

def set_input_embeddings(self, value):
self.model.embed_tokens = value

def get_output_embeddings(self):
return self.lm_head

def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings

def set_decoder(self, decoder):
self.model = decoder

@@ -533,12 +515,6 @@ def __init__(self, config):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.model.embed_tokens

def set_input_embeddings(self, value):
self.model.embed_tokens = value

@can_return_tuple
@auto_docstring
def forward(
@@ -685,12 +661,6 @@ def __init__(self, config):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.model.embed_tokens

def set_input_embeddings(self, value):
self.model.embed_tokens = value

@can_return_tuple
@auto_docstring
def forward(