-
Notifications
You must be signed in to change notification settings - Fork 33.5k
Make gradient-checkpoint enabling tolerant of models without get_input_embeddings #42558
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
d5909a7
a9eb634
b520cc7
7ce45fe
d41e204
0e93a61
de8ff71
b2618b3
5d61150
ef55499
44ab4c6
fe89c1c
2920d00
b8ccd0f
b5ae5a6
d209ff5
79665d4
844c707
fcc84a4
e970fad
b4f5c15
0246a70
81940dd
284189a
1079eef
f479598
73b4f5d
0e7086f
18d44ba
00cc669
d9d7442
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1063,54 +1063,52 @@ def get_input_embeddings(self) -> nn.Module: | |
| `nn.Module`: A torch module mapping vocabulary to hidden states. | ||
| """ | ||
|
|
||
| # 1) Check if the model has an attribute named 'embed_tokens' (the standard input embedding layer | ||
| # for most NLP models), and if so, return it. | ||
|
|
||
| name = getattr(self, "_input_embed_layer", "embed_tokens") | ||
|
|
||
| # 1) Direct attribute (most NLP models). | ||
| if (default_embedding := getattr(self, name, None)) is not None: | ||
| return default_embedding | ||
| # 2) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration` | ||
| # 2) Nested embeddings (e.g., self.embeddings.patch_embedding for vision/audio models). | ||
| if hasattr(self, "embeddings") and hasattr(self.embeddings, name): | ||
| return getattr(self.embeddings, name) | ||
| # 3) Encoder/decoder wrappers (e.g., `self.model.embed_tokens` or similar overrides). | ||
| if hasattr(self, "model") and hasattr(self.model, name): | ||
| return getattr(self.model, name) | ||
|
|
||
| if hasattr(self, "model") and hasattr(self.model, "embed_tokens"): | ||
| return self.model.embed_tokens | ||
| if hasattr(self, "base_model"): | ||
| base_model = self.base_model | ||
| if base_model is not None and base_model is not self: | ||
| return base_model.get_input_embeddings() | ||
|
|
||
| # 3) vanilla decoder‑only architectures | ||
| elif hasattr(self, "embed_tokens"): | ||
| return self.embed_tokens | ||
| else: | ||
| base_model = getattr(self, "base_model_prefix", None) | ||
| if base_model is not None: | ||
| base_model = getattr(self, base_model, None) | ||
| if base_model is not None and base_model is not self: | ||
| return base_model.get_input_embeddings() | ||
| raise NotImplementedError( | ||
| f"`get_input_embeddings` not auto‑handled for {self.__class__.__name__}; " | ||
| "please override in the subclass." | ||
| ) | ||
| raise NotImplementedError( | ||
| f"`get_input_embeddings` not auto‑handled for {self.__class__.__name__}; please override in the subclass." | ||
| ) | ||
|
|
||
| def set_input_embeddings(self, value: nn.Module): | ||
| """Fallback setter that handles **~70%** of models in the code-base. | ||
|
|
||
| Order of attempts: | ||
| 1. `self.model.embed_tokens` | ||
| 2. `self.embed_tokens` | ||
| 3. delegate to the *base model* if one exists | ||
| 4. otherwise raise `NotImplementedError` so subclasses still can (and | ||
| 1. `self.<_input_embed_layer>` (direct attribute) | ||
| 2. `self.embeddings.<_input_embed_layer>` (nested embeddings for vision/audio models) | ||
| 3. `self.model.<_input_embed_layer>` (encoder/decoder models) | ||
| 4. delegate to the *base model* if one exists | ||
| 5. otherwise raise `NotImplementedError` so subclasses still can (and | ||
| should) override for exotic layouts. | ||
| """ | ||
|
|
||
| # 1) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration` | ||
| name = getattr(self, "_input_embed_layer", "embed_tokens") | ||
| if hasattr(self, "model") and hasattr(self.model, name): | ||
| setattr(self.model, name, value) | ||
| # 2) as well as vanilla decoder‑only architectures | ||
| elif hasattr(self, name): | ||
| # 1) Direct attribute (most NLP models) | ||
| if hasattr(self, name): | ||
| setattr(self, name, value) | ||
| # 3) recurse once into the registered *base* model (e.g. for encoder/decoder) | ||
| elif getattr(self, self.base_model_prefix, self) is not self: | ||
| base_model = getattr(self, self.base_model_prefix, self) | ||
| base_model.set_input_embeddings(value) | ||
| # 2) Nested embeddings (e.g., self.embeddings.patch_embedding for vision models) | ||
| elif hasattr(self, "embeddings") and hasattr(self.embeddings, name): | ||
| setattr(self.embeddings, name, value) | ||
| # 3) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration` | ||
| elif hasattr(self, "model") and hasattr(self.model, name): | ||
| setattr(self.model, name, value) | ||
| # 4) recurse once into the registered *base* model (e.g. for encoder/decoder) | ||
| elif hasattr(self, "base_model") and self.base_model is not self: | ||
| self.base_model.set_input_embeddings(value) | ||
| else: | ||
| raise NotImplementedError( | ||
| f"`set_input_embeddings` not auto‑handled for {self.__class__.__name__}; please override in the subclass." | ||
|
|
@@ -2043,14 +2041,18 @@ def make_inputs_require_grads(module, input, output): | |
|
|
||
| hooks = [] | ||
| seen_modules = set() | ||
| found_embeddings = False | ||
|
|
||
| for module in self.modules(): | ||
| if not (isinstance(module, PreTrainedModel) and hasattr(module, "get_input_embeddings")): | ||
| continue | ||
|
|
||
| input_embeddings = module.get_input_embeddings() | ||
| try: | ||
| input_embeddings = module.get_input_embeddings() | ||
| except NotImplementedError: | ||
| continue | ||
|
|
||
| if input_embeddings is None: | ||
| if input_embeddings is None or not hasattr(input_embeddings, "register_forward_hook"): | ||
| continue | ||
|
|
||
| embedding_id = id(input_embeddings) | ||
|
|
@@ -2059,11 +2061,18 @@ def make_inputs_require_grads(module, input, output): | |
|
|
||
| seen_modules.add(embedding_id) | ||
| hooks.append(input_embeddings.register_forward_hook(make_inputs_require_grads)) | ||
| found_embeddings = True | ||
|
|
||
| self._require_grads_hooks = hooks | ||
| if hooks: | ||
| # for BC | ||
| self._require_grads_hook = hooks[0] | ||
| if not found_embeddings: | ||
| logger.warning_once( | ||
| f"{self.__class__.__name__} does not expose input embeddings. Gradients cannot flow back to the token " | ||
| "embeddings when using adapters or gradient checkpointing. Override `get_input_embeddings` to fully " | ||
| "support those features, or set `_input_embed_layer` to the attribute name that holds the embeddings." | ||
| ) | ||
|
|
||
| def disable_input_require_grads(self): | ||
| """ | ||
|
|
@@ -3000,7 +3009,10 @@ def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): | |
| "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model." | ||
| ) | ||
|
|
||
| if getattr(self, "_hf_peft_config_loaded", False): | ||
| needs_embedding_grads = self.main_input_name == "input_ids" | ||
| # we use that also to detect whether or not we have to raise if embeddings are missing (the submodel might not have embeddings at all) | ||
| enable_input_grads = needs_embedding_grads or getattr(self, "_hf_peft_config_loaded", False) | ||
| if enable_input_grads: | ||
|
Comment on lines
+3012
to
+3015
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hmm, for my understanding, why do we always need to enable grads when doing GC training with text models?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we don't always, but we do with reentrant checkpointing. IIUC it's not to actualy use these gradients, it's that torch.utils.checkpoint needs at least one input and one output to actually have gradients, else the checkpointed part will not have a gradient. |
||
| # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True | ||
| # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334 | ||
| # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no simple way around this unfortunately
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oke, I think with the warning below, it is more explicit