Skip to content

Commit

Permalink
fix: apply & change @lenglaender solution in adapter-hub#759
Browse files Browse the repository at this point in the history
  • Loading branch information
akatief committed Nov 11, 2024
1 parent 0c9701e commit 17fb225
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/adapters/context.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import functools
import threading
from typing import ContextManager

from .composition import parse_composition, parse_heads_from_composition


class AdapterSetup:
class AdapterSetup(ContextManager):
"""
Represents an adapter setup of a model including active adapters and active heads. This class is intended to be
used as a context manager using the ``with`` statement. The setup defined by the ``AdapterSetup`` context will
Expand Down Expand Up @@ -67,7 +68,7 @@ def get_context_head_setup(cls):
return None


class ForwardContext:
class ForwardContext(ContextManager):
"""
Holds context information during a forward pass through a model. This class should be used via the
``ForwardContext.wrap()`` method.
Expand Down
62 changes: 62 additions & 0 deletions src/adapters/model_mixin.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import contextlib
import functools
import inspect
import logging
import os
from abc import ABC, abstractmethod
from collections import defaultdict
from copy import deepcopy
from functools import partial
from os.path import join
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

from adapters.configuration.adapter_config import ConfigUnion, LoRAConfig
from transformers import GenerationConfig
Expand Down Expand Up @@ -1450,6 +1454,64 @@ def save_pretrained(
del self.config.adapters


def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
"""
Activates gradient checkpointing for the current model.
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
activations".
We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of
the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
Args:
gradient_checkpointing_kwargs (dict, *optional*):
Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function.
"""
if not self.supports_gradient_checkpointing:
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")

if gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {"use_reentrant": False}

# >>> START AH Changes <<<
if "use_reentrant" not in gradient_checkpointing_kwargs:
# use_reentrant must be set.
gradient_checkpointing_kwargs["use_reentrant"] = False
else:
if gradient_checkpointing_kwargs["use_reentrant"]:
raise ValueError(
"Gradient checkpointing with use_reentrant=True is not supported. For gradient checkpointing, we need to set context_fn, which is only supported by PyTorch when use_reentrant is set to False."
)

def gradient_checkpointing_function(function, *args, **kwargs):
context = ForwardContext.get_context()
context_fn = lambda: (contextlib.nullcontext(), context)
return checkpoint(function, *args, context_fn=context_fn, **kwargs)

gradient_checkpointing_func = functools.partial(
gradient_checkpointing_function, **gradient_checkpointing_kwargs
)
# >>> END AH Changes <<<

# For old GC format (transformers < 4.35.0) for models that live on the Hub
# we will fall back to the overwritten `_set_gradient_checkpointing` method
_is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters

if not _is_using_old_format:
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
else:
self.apply(partial(self._set_gradient_checkpointing, value=True))
logger.warning(
"You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)."
"Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
)

if getattr(self, "_hf_peft_config_loaded", False):
# When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
# we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
# When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
# the gradients to make sure the gradient flows.
self.enable_input_require_grads()


@inherit_doc
class ModelBaseAdaptersMixin(ModelAdaptersMixin):
add_base_adapters = True
Expand Down

0 comments on commit 17fb225

Please sign in to comment.