@@ -1902,7 +1902,97 @@ def floating_point_ops(
1902 1902         return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)
1903 1903
1904 1904
1905- class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
1905+ class EmbeddingAccessMixin:
1906+     """
1907+     Base utilities to regroup getters and setters for embeddings.
1908+     Introduces the `_input_embed_layer` attribute, which indicates
1909+     where the input embeddings come from and where they
1910+     should be set.
1911+     """
1912+
1913+     _input_embed_layer = "embed_tokens"  # default layer that holds the input embeddings
1914+
1915+     def get_input_embeddings(self) -> nn.Module:
1916+         """
1917+         Returns the model's input embeddings.
1918+
1919+         Returns:
1920+             `nn.Module`: A torch module mapping vocabulary to hidden states.
1921+         """
1922+
1923+         # 1) Check if the model has an attribute named 'embed_tokens' (the standard input embedding layer
1924+         # for most NLP models), and if so, return it.
1925+
1926+         name = getattr(self, "_input_embed_layer", "embed_tokens")
1927+
1928+         if (default_embedding := getattr(self, name, None)) is not None:
1929+             return default_embedding
1930+         # 2) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
1931+
1932+         if hasattr(self, "model") and hasattr(self.model, "embed_tokens"):
1933+             return self.model.embed_tokens
1934+
1935+         # 3) vanilla decoder-only architectures
1936+         elif hasattr(self, "embed_tokens"):
1937+             return self.embed_tokens
1938+         else:
1939+             base_model = getattr(self, "base_model_prefix", None)
1940+             if base_model is not None:
1941+                 base_model = getattr(self, base_model, None)
1942+             if base_model is not None and base_model is not self:
1943+                 return base_model.get_input_embeddings()
1944+             raise NotImplementedError(
1945+                 f"`get_input_embeddings` not auto-handled for {self.__class__.__name__}; "
1946+                 "please override in the subclass."
1947+             )
1948+
1949+     def set_input_embeddings(self, value: nn.Module):
1950+         """Fallback setter that handles **~70%** of models in the code-base.
1951+
1952+         Order of attempts:
1953+             1. `self.model.embed_tokens`
1954+             2. `self.embed_tokens`
1955+             3. delegate to the *base model* if one exists
1956+             4. otherwise raise `NotImplementedError` so subclasses still can (and
1957+                should) override for exotic layouts.
1958+         """
1959+
1960+         # 1) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
1961+         name = getattr(self, "_input_embed_layer", "embed_tokens")
1962+         if hasattr(self, "model") and hasattr(self.model, name):
1963+             setattr(self.model, name, value)
1964+         # 2) as well as vanilla decoder-only architectures
1965+         elif hasattr(self, name):
1966+             setattr(self, name, value)
1967+         # 3) recurse once into the registered *base* model (e.g. for encoder/decoder)
1968+         elif getattr(self, self.base_model_prefix, self) is not self:
1969+             base_model = getattr(self, self.base_model_prefix, self)
1970+             base_model.set_input_embeddings(value)
1971+         else:
1972+             raise NotImplementedError(
1973+                 f"`set_input_embeddings` not auto-handled for {self.__class__.__name__}; please override in the subclass."
1974+             )
1975+
1976+     def get_output_embeddings(self):
1977+         if not hasattr(self, "lm_head"):
1978+             return None
1979+         try:
1980+             # Speech / vision backbones raise here, so we return None.
1981+             # (Is this a legitimate use of `get_input_embeddings`?)
1982+             self.get_input_embeddings()
1983+         except NotImplementedError:
1984+             return None
1985+         return self.lm_head
1986+
1987+     def set_output_embeddings(self, new_embeddings):
1988+         """
1989+         Sets the model's output embeddings, defaulting to assigning `new_embeddings` to `lm_head`.
1990+         """
1991+         if getattr(self, "lm_head", None) is not None:
1992+             self.lm_head = new_embeddings
1993+
1994+
1995+ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
1906 1996     r"""
1907 1997     Base class for all models.
1908 1998
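The resolution order of the new mixin is easiest to see on a toy model. The sketch below is illustrative only and not part of the diff: `TinyDecoder` and `TinyForCausalLM` are hypothetical classes, and the import assumes `EmbeddingAccessMixin` is exposed from `transformers.modeling_utils` once this change lands.

import torch.nn as nn
from transformers.modeling_utils import EmbeddingAccessMixin  # assumes this PR is merged

class TinyDecoder(nn.Module, EmbeddingAccessMixin):
    def __init__(self):
        super().__init__()
        # found directly via the default `_input_embed_layer` ("embed_tokens")
        self.embed_tokens = nn.Embedding(32, 8)

class TinyForCausalLM(nn.Module, EmbeddingAccessMixin):
    def __init__(self):
        super().__init__()
        self.model = TinyDecoder()                 # resolved through `self.model.embed_tokens`
        self.lm_head = nn.Linear(8, 32, bias=False)

model = TinyForCausalLM()
print(model.get_input_embeddings())   # Embedding(32, 8), taken from model.embed_tokens
model.set_input_embeddings(nn.Embedding(64, 8))
print(model.get_input_embeddings())   # the freshly set Embedding(64, 8)
print(model.get_output_embeddings())  # lm_head, since get_input_embeddings() did not raise

Models with more exotic layouts still override these methods, exactly as the `NotImplementedError` branches above suggest.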
@@ -2004,6 +2094,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
2004 2094     _supports_attention_backend = False
2005 2095     _can_record_outputs = None
2006 2096
2097+     # This attribute sets the default parameter to be
2098+
2007 2099     @property
2008 2100     @torch._dynamo.allow_in_graph
2009 2101     def can_record_outputs(self) -> dict[str, OutputRecorder]:
@@ -2267,6 +2359,101 @@ def _from_config(cls, config, **kwargs):
2267 2359
2268 2360         return model
2269 2361
2362+     @classmethod
2363+     def _check_attn_implementation(cls, attn_implementation: Union[str, dict]) -> Union[str, dict]:
2364+         """
2365+         Checks that the requested attention implementation exists and tries to fetch the kernel from the Hub
2366+         if `attn_implementation` matches the hf kernels pattern.
2367+         """
2368+         if isinstance(attn_implementation, str) and re.match(r"^[^/:]+/[^/:]+:[^/:]+$", attn_implementation):
2369+             if not is_kernels_available():
2370+                 raise ValueError("kernels is not installed. Please install it with `pip install kernels`.")
2371+
2372+             # Extract repo_id and kernel_name from the string
2373+             repo_id, kernel_name = attn_implementation.split(":")
2374+             kernel_name = kernel_name.strip()
2375+             repo_id = repo_id.strip()
2376+
2377+             try:
2378+                 kernel = get_kernel(repo_id)
2379+                 ALL_ATTENTION_FUNCTIONS.register(f"kernel_{repo_id.replace('/', '_')}", getattr(kernel, kernel_name))
2380+                 attn_implementation = f"kernel_{repo_id.replace('/', '_')}"
2381+             except FileNotFoundError as e:
2382+                 logger.warning(
2383+                     f"Could not find a kernel repository '{repo_id}' compatible with your device in the hub: {e}. Using eager attention implementation instead."
2384+                 )
2385+                 attn_implementation = None  # try to dispatch to SDPA and fall back to eager if not available
2386+             except AttributeError:
2387+                 raise ValueError(
2388+                     "The kernel function name or class specified in the `attn_implementation` argument is not valid. "
2389+                     "Please check the documentation for the correct format, "
2390+                     "and check that the kernel exports the class and the function correctly."
2391+                 )
2392+         if (
2393+             not isinstance(attn_implementation, dict)
2394+             and attn_implementation not in ["eager", None] + ALL_ATTENTION_FUNCTIONS.valid_keys()
2395+         ):
2396+             message = f'Specified `attn_implementation="{attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
2397+             # check `supports_flash_attn_2` for BC with custom code. TODO: remove after a few releases
2398+             if cls._supports_flash_attn or getattr(cls, "_supports_flash_attn_2", False):
2399+                 message += (
2400+                     ', `"attn_implementation=flash_attention_3"` (implementation using flash attention 3)'
2401+                     ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
2402+                 )
2403+             if cls._supports_sdpa:
2404+                 message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)'
2405+             if cls._supports_flex_attn:
2406+                 message += ', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)'
2407+             raise ValueError(message + ".")
2408+
2409+         return attn_implementation
2410+
2411+     def set_attention_implementation(self, attn_implementation: Union[str, dict]):
2412+         """
2413+         Checks and dispatches to the requested attention implementation.
2414+         """
2415+         requested_attn_implementation = self._check_attn_implementation(attn_implementation)
2416+
2417+         # Composite models consisting of several PretrainedModels can specify the attention implementation as a dict
2418+         # whose keys are sub-config names. But most users will pass a single `str`, which is then dispatched to all sub-models.
2419+         # See https://github.com/huggingface/transformers/pull/32238
2420+         for key in self.config.sub_configs.keys():
2421+             sub_config = getattr(self.config, key)
2422+             curr_attn_implementation = (
2423+                 requested_attn_implementation
2424+                 if not isinstance(requested_attn_implementation, dict)
2425+                 else requested_attn_implementation.get(key, None)
2426+             )
2427+             # For models with a backbone, the sub-config might not be initialized. Set the requested attn
2428+             # if the config has no attn pre-set and the requested attn is not `None` (i.e. not the default attn)
2429+             if (
2430+                 sub_config is not None
2431+                 and sub_config._attn_implementation_internal is None
2432+                 and curr_attn_implementation is not None
2433+             ):
2434+                 sub_config._attn_implementation_internal = curr_attn_implementation
2435+
2436+         if requested_attn_implementation == "flash_attention_3" and self._flash_attn_3_can_dispatch():
2437+             self.config._attn_implementation = "flash_attention_3"
2438+         elif requested_attn_implementation == "flash_attention_2" and self._flash_attn_2_can_dispatch():
2439+             self.config._attn_implementation = "flash_attention_2"
2440+         elif requested_attn_implementation == "flex_attention" and self._flex_attn_can_dispatch():
2441+             self.config._attn_implementation = "flex_attention"
2442+         elif (
2443+             requested_attn_implementation in [None, "sdpa"]
2444+             and not is_torch_xla_available()
2445+             and self._sdpa_can_dispatch(hard_check_only=requested_attn_implementation is not None)
2446+         ):
2447+             self.config._attn_implementation = "sdpa"
2448+         elif requested_attn_implementation in ALL_ATTENTION_FUNCTIONS.valid_keys():
2449+             self.config._attn_implementation = requested_attn_implementation
2450+         elif isinstance(requested_attn_implementation, dict):
2451+             self.config._attn_implementation = requested_attn_implementation.get("", None)
2452+         else:
2453+             self.config._attn_implementation = "eager"
2454+
2455+         self.config._attn_implementation_autoset = True
2456+
2270 2457     @classmethod
2271 2458     def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype:
2272 2459         """
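The three shapes `attn_implementation` can take in this hunk are easiest to see from the caller's side. The sketch below is a hypothetical usage example, not part of the diff: `my-org/my-attn-kernels` and its `flash_attention` export do not exist on the Hub, and it assumes `from_pretrained` routes the argument through `_check_attn_implementation` / `set_attention_implementation` as introduced here.

from transformers import AutoModelForCausalLM, AutoModelForImageTextToText

# 1) A plain string: validated against ALL_ATTENTION_FUNCTIONS and dispatched to the model
#    and all of its sub-configs.
model = AutoModelForCausalLM.from_pretrained("gpt2", attn_implementation="sdpa")

# 2) A Hub kernel in "repo_id:kernel_name" form; per the code above it would be registered
#    as "kernel_my-org_my-attn-kernels" before dispatch (hypothetical repo).
model = AutoModelForCausalLM.from_pretrained(
    "gpt2", attn_implementation="my-org/my-attn-kernels:flash_attention"
)

# 3) A dict keyed by sub-config name, for composite models; sub-configs without a pre-set
#    attention pick up their entry, and missing keys fall back to the default resolution.
model = AutoModelForImageTextToText.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    attn_implementation={"text_config": "flash_attention_2", "vision_config": "sdpa"},
)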
@@ -2769,41 +2956,6 @@ def disable_input_require_grads(self):
2769 2956         """
2770 2957         self._require_grads_hook.remove()
2771 2958
2772-     def get_input_embeddings(self) -> nn.Module:
2773-         """
2774-         Returns the model's input embeddings.
2775-
2776-         Returns:
2777-             `nn.Module`: A torch module mapping vocabulary to hidden states.
2778-         """
2779-         base_model = getattr(self, self.base_model_prefix, self)
2780-         if base_model is not self:
2781-             return base_model.get_input_embeddings()
2782-         else:
2783-             raise NotImplementedError
2784-
2785-     def set_input_embeddings(self, value: nn.Module):
2786-         """
2787-         Set model's input embeddings.
2788-
2789-         Args:
2790-             value (`nn.Module`): A module mapping vocabulary to hidden states.
2791-         """
2792-         base_model = getattr(self, self.base_model_prefix, self)
2793-         if base_model is not self:
2794-             base_model.set_input_embeddings(value)
2795-         else:
2796-             raise NotImplementedError
2797-
2798-     def get_output_embeddings(self) -> nn.Module:
2799-         """
2800-         Returns the model's output embeddings.
2801-
2802-         Returns:
2803-             `nn.Module`: A torch module mapping hidden states to vocabulary.
2804-         """
2805-         return None  # Overwrite for models with output embeddings
2806-
2807 2959
2808 2960     def _init_weights(self, module):
2809 2961         """