Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implements Vera #763

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/adapters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
"SeqBnInvConfig",
"StaticAdapterFusionConfig",
"UniPELTConfig",
"VeraConfig",
],
"context": [
"AdapterSetup",
Expand Down Expand Up @@ -181,6 +182,7 @@
SeqBnInvConfig,
StaticAdapterFusionConfig,
UniPELTConfig,
VeraConfig,
)
from .context import AdapterSetup, ForwardContext
from .heads import (
Expand Down
34 changes: 33 additions & 1 deletion src/adapters/configuration/adapter_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,11 +478,20 @@ class LoRAConfig(AdapterConfig):
(addition of decomposed matrix, as in LoRA) or "scale" (element-wise multiplication of vector, as in
(IA)^3). "scale" can only be used together with r=1. Defaults to "add".
init_weights (:obj:`str`, optional): Initialization method for the weights of the LoRA modules.
Currently, this can be either "lora" (default) or "bert".
Currently, this can be either "lora" (default) or "bert", or "vera".
use_gating (:obj:`bool`, optional):
Place a trainable gating module besides the added parameter module to control module activation. This is
e.g. used for UniPELT. Defaults to False. Note that modules with use_gating=True cannot be merged using
`merge_adapter()`.
d (:obj:`bool` or :obj:`float`, optional):
The value of d used in the VeraConfig. Defaults to None. Places a trainable
scaling parameter `d` before the decomposition matrix A to allow scaling of the
internal weights.

b (:obj:`bool` or :obj:`float`, optional):
The value of b used in the VeraConfig. Defaults to None. Places a trainable
scaling parameter `b` before the decomposition matrix B to allow scaling of the
internal weights.
"""

architecture: Optional[str] = "lora"
Expand All @@ -499,6 +508,8 @@ class LoRAConfig(AdapterConfig):
composition_mode: str = "add"
init_weights: str = "lora"
use_gating: bool = False
d: Union[bool, float] = None
b: Union[bool, float] = None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could we name these "vera_b" and "vera_d", to make more obvious what these are related to?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also why can these be bools? ie what happens when I set d=True, b=True?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, i think this is a typo based on a previous idea I had which i scraped later.. thanks



@dataclass(eq=False)
Expand All @@ -523,6 +534,27 @@ class IA3Config(LoRAConfig):
use_gating: bool = False


@dataclass(eq=False)
class VeraConfig(LoRAConfig):
"""
Lora Config that applies vector-based random matrix adaptation. It adds
trainable matrices 'd' and 'b' while keeping the original LoRA matrices
frozen, random, and shared across layers. See more through their paper:
https://arxiv.org/pdf/2106.09685. Note that `r` will still be supplied
since we are still initializing decomposition matrices A and B.
The `composition_mode` parameter should also be set to `add`.
Copy link
Member

@calpt calpt Dec 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the paper link still needs updating :)

"""

selfattn_lora: bool = True
intermediate_lora: bool = False
output_lora: bool = False

r: int = 8
d: Union[bool, float] = 0.1
b: Union[bool, float] = 0.0
init_weights: str = "vera"


@dataclass(eq=False)
class ReftConfig(AdapterConfig):
"""
Expand Down
122 changes: 120 additions & 2 deletions src/adapters/methods/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from ..composition import Average, BatchSplit, Parallel, Stack
from ..configuration import LoRAConfig, ModelAdaptersConfig
from ..context import ForwardContext
from .adapter_layer_base import AdapterLayerBase, ComposableAdapterLayerBase
from .utils import dequantize_bnb_weight

Expand Down Expand Up @@ -68,6 +69,9 @@ def __init__(
elif config.init_weights == "ia3":
nn.init.ones_(self.lora_A)
nn.init.ones_(self.lora_B)
elif config.init_weights == "vera":
nn.init.kaiming_uniform_(self.lora_A)
nn.init.kaiming_uniform_(self.lora_B)
else:
raise ValueError("Unknown init_weights type: {}".format(config.init_weights))

Expand All @@ -90,6 +94,7 @@ def com_inv(self, weights: torch.Tensor, added: torch.Tensor) -> torch.Tensor:
return weights - added * self.scaling

def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tensor):
print("triggered")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to remove?

if hidden_states is None:
hidden_states = layer_input
hidden_states = self.lora_dropout(hidden_states) @ torch.t(self.lora_A) @ torch.t(self.lora_B)
Expand Down Expand Up @@ -131,7 +136,7 @@ def __init__(
# For compatibility with LoRA, allow all init_weights types here.
# Usually should be "ia3".
if config.init_weights == "lora":
logger.warning("(IA)^3 module initialized with LoRA zeo init. Ignore if this is intended.")
logger.warning("(IA)^3 module initialized with LoRA zero init. Ignore if this is intended.")
nn.init.zeros_(self.lora_B)
elif config.init_weights == "bert":
nn.init.normal_(self.lora_B, std=0.02)
Expand Down Expand Up @@ -174,6 +179,111 @@ def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tens
return hidden_states, gate


class Vera(nn.Module):
def __init__(
self,
lora_A_shape,
lora_B_shape,
config: LoRAConfig,
gating_heads: int = 1,
):
super().__init__()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should also add an assert for composition mode "add" here (same as in LoRA init), just to make sure

self.d = config.d
self.b = config.b
self.r = config.r
self.alpha = config.alpha
self.use_gating = config.use_gating

# Optional dropout
if config.dropout > 0.0:
self.lora_dropout = nn.Dropout(p=config.dropout)

self.lora_A_shape = lora_A_shape
self.lora_B_shape = lora_B_shape
self.d_shape = self.lora_A_shape[0]
self.b_shape = self.lora_B_shape[0]

# Actual trainable parameters
self.vera_D = nn.Parameter(torch.diag(torch.ones(self.d_shape) * self.d))
self.vera_B = nn.Parameter(torch.diag(torch.ones(self.b_shape) * self.b))
self.scaling = self.alpha / self.r

if self.use_gating:
self.gate = nn.Linear(lora_A_shape[-1], gating_heads)
nn.init.normal_(self.gate.weight, std=0.02)

@property
def delta_w(self) -> torch.Tensor:
parameters = ForwardContext.get_context().shared_parameters[self.name]
lora_A = parameters["lora_A"]
lora_B = parameters["lora_B"]
return self.vera_B @ lora_B @ self.vera_D @ lora_A

def com(self, weights: torch.Tensor, added: torch.Tensor, scaling=None) -> torch.Tensor:
"""Performs the composition operation between existing and injected weights."""
if scaling is None:
scaling = self.scaling
return weights + added * scaling

def com_inv(self, weights: torch.Tensor, added: torch.Tensor) -> torch.Tensor:
"""Inverts the composition operation between existing and injected weights."""
return weights - added * self.scaling

def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tensor):
parameters = ForwardContext.get_context().shared_parameters[self.name]
lora_A = parameters["lora_A"]
lora_B = parameters["lora_B"]

if hidden_states is None:
hidden_states = layer_input

if getattr(self, "lora_dropout"):
hidden_states = self.lora_dropout(hidden_states)

hidden_states = hidden_states @ self.vera_B @ lora_B @ self.vera_D @ lora_A
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't the order be reversed here? ie we matmul hidden states with lora_A -> vera_d -> lora_B -> vera_b, according to §3.1 (2) of the paper?


if self.use_gating:
gate = torch.sigmoid(self.gate(layer_input))
gate = torch.mean(gate, dim=1).unsqueeze(-1)
hidden_states = hidden_states * gate
else:
gate = None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as this is likely merged after #770, the same fix from there should be applied here


return hidden_states, gate

def set_vera_adapter_name(self, name):
self.name = name


def init_shared_Vera_parameters(model_config, adapter_config, device):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: ideally lower-case "v" in the middle of method names

hidden_size = model_config.hidden_size
r = adapter_config["r"]

parameters = nn.ParameterDict()

# initialize frozen, random tensors A, B
parameters["lora_A"] = torch.zeros(r, hidden_size).to(device)
parameters["lora_B"] = torch.zeros(hidden_size, r).to(device)

if adapter_config["init_weights"] == "lora":
# initialize A the same way as the default for nn.Linear and B to zero
nn.init.kaiming_uniform_(parameters["lora_A"], a=math.sqrt(5))
nn.init.zeros_(parameters["lora_B"])
elif adapter_config["init_weights"] == "bert":
nn.init.normal_(parameters["lora_A"], std=0.02)
nn.init.normal_(parameters["lora_B"], std=0.02)
elif adapter_config["init_weights"] == "ia3":
nn.init.ones_(parameters["lora_A"])
nn.init.ones_(parameters["lora_B"])
elif adapter_config["init_weights"] == "vera":
nn.init.kaiming_uniform_(parameters["lora_A"])
nn.init.kaiming_uniform_(parameters["lora_B"])
else:
raise ValueError("Unknown init_weights type: {}".format(adapter_config["init_weights"]))

return parameters


class LoRALayer(AdapterLayerBase):
adapter_modules_name = "loras"

Expand All @@ -199,6 +309,7 @@ def _get_lora_shapes(self, config: LoRAConfig):

def add_adapter(self, adapter_name: str, layer_idx: int) -> bool:
self.layer_idx = layer_idx

lora_config = self.adapters_config.match(
adapter_name,
config_type=LoRAConfig,
Expand All @@ -207,7 +318,10 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool:
)
if lora_config is not None and self._check_lora_location(lora_config):
if lora_config.composition_mode == "add":
lora_cls = LoRA
if isinstance(lora_config.d, float) or isinstance(lora_config.b, float):
lora_cls = Vera
else:
lora_cls = LoRA
elif lora_config.composition_mode == "scale":
lora_cls = IA3
else:
Expand All @@ -217,6 +331,10 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool:
lora_config,
gating_heads=self.get_n_heads(lora_config),
)
# if we're using Vera, then set the adapter name into the Vera object
if lora_cls == Vera:
lora.set_vera_adapter_name(name=adapter_name)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

feels a bit hacky to do this only for vera as the name is not specific to this type. what do you think of always passing the name directly to the __init__ method of each module class (for all LoRA, Vera, IA3) and setting self.name directly there?
that might be cleaner long-term as we might want to use the name in LoRA as well in the future.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've thought about that idea as well but opted for now to implement this idea first since right now lora and IA3 don't use self.name. I'll refactor it as you said. Thanks!


lora.train(self.training)
lora = lora.to(self.weight.device)
self.loras[adapter_name] = lora
Expand Down
14 changes: 11 additions & 3 deletions src/adapters/model_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch
from torch import nn

from adapters.configuration.adapter_config import ConfigUnion, LoRAConfig
from adapters.configuration.adapter_config import ConfigUnion, LoRAConfig, VeraConfig
from transformers import GenerationConfig
from transformers.modeling_outputs import ModelOutput
from transformers.utils import is_accelerate_available
Expand All @@ -22,7 +22,7 @@
from .loading import AdapterFusionLoader, AdapterLoader, PredictionHeadLoader, WeightsLoader
from .methods.adapter_layer_base import AdapterLayerBase
from .methods.bottleneck import BottleneckLayer
from .methods.lora import LoRALayer
from .methods.lora import LoRALayer, init_shared_Vera_parameters
from .methods.modeling import Adapter, GLOWCouplingBlock, NICECouplingBlock, init_shared_parameters
from .methods.prefix_tuning import PrefixTuningLayer, PrefixTuningPool
from .methods.prompt_tuning import PromptTuningLayer
Expand Down Expand Up @@ -610,13 +610,21 @@ def _add_adapter_weights(self, adapter_name: str):
)
else:
raise ValueError(
"The model has different hidden sizes {}. Sharing comapcter weights is only possible if"
"The model has different hidden sizes {}. Sharing compacter weights is only possible if"
" the hidden_sizes match.".format(hidden_sizes)
)
else:
self.base_model.shared_parameters[adapter_name] = init_shared_parameters(
adapter_config, self.config.hidden_size, self.device
)

# Vera Initialization
if self.adapters_config.match(adapter_name, VeraConfig):
adapter_config = self.adapters_config.match(adapter_name, VeraConfig)
self.base_model.shared_parameters[adapter_name] = init_shared_Vera_parameters(
self.config, adapter_config, self.device
)

# Prefix Tuning
for module in self.modules():
if isinstance(module, PrefixTuningPool):
Expand Down
Loading