Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom names for AdapterFusion layers #774

Merged
merged 3 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/adapters/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,17 @@ def __init__(self, *stack_layers: List[Union[AdapterCompositionBlock, str]]):


class Fuse(AdapterCompositionBlock):
    """Composition block that combines several adapters with an AdapterFusion layer.

    Children may be adapter names (str) or nested composition blocks. An optional
    custom ``name`` identifies the fusion layer; if omitted, the name is derived
    from the fused children.
    """

    def __init__(self, *fuse_stacks: List[Union[AdapterCompositionBlock, str]], name: Optional[str] = None):
        super().__init__(*fuse_stacks)
        # Custom layer name; None means "derive from children" (see `name` property).
        self._name = name

    # TODO-V2 pull this up to all block classes?
    @property
    def name(self):
        """Return the custom name if one was given, else the comma-joined child names."""
        if self._name:
            return self._name
        else:
            # For nested blocks, the last adapter of the child block is significant.
            return ",".join([c if isinstance(c, str) else c.last() for c in self.children])


class Split(AdapterCompositionBlock):
Expand Down
30 changes: 22 additions & 8 deletions src/adapters/configuration/model_adapters_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import copy
import logging
from collections.abc import Collection, Mapping
from typing import List, Optional, Union
from typing import List, Optional, Tuple, Union

from .. import __version__
from ..composition import AdapterCompositionBlock
Expand All @@ -27,6 +27,7 @@ def __init__(self, **kwargs):

self.fusions: Mapping[str, str] = kwargs.pop("fusions", {})
self.fusion_config_map = kwargs.pop("fusion_config_map", {})
self.fusion_name_map = kwargs.pop("fusion_name_map", {})

# TODO-V2 Save this with config?
self.active_setup: Optional[AdapterCompositionBlock] = None
Expand Down Expand Up @@ -131,7 +132,7 @@ def add(self, adapter_name: str, config: Optional[Union[str, dict]] = None):
self.adapters[adapter_name] = config_name
logger.info(f"Adding adapter '{adapter_name}'.")

def get_fusion(self, fusion_name: Union[str, List[str]]) -> Optional[dict]:
def get_fusion(self, fusion_name: Union[str, List[str]]) -> Tuple[Optional[dict], Optional[list]]:
"""
Gets the config dictionary for a given AdapterFusion.

Expand All @@ -140,6 +141,7 @@ def get_fusion(self, fusion_name: Union[str, List[str]]) -> Optional[dict]:

Returns:
Optional[dict]: The AdapterFusion configuration.
Optional[list]: The names of the adapters to fuse.
"""
if isinstance(fusion_name, list):
fusion_name = ",".join(fusion_name)
Expand All @@ -149,20 +151,31 @@ def get_fusion(self, fusion_name: Union[str, List[str]]) -> Optional[dict]:
config = self.fusion_config_map.get(config_name, None)
else:
config = ADAPTERFUSION_CONFIG_MAP.get(config_name, None)

if fusion_name in self.fusion_name_map:
adapter_names = self.fusion_name_map[fusion_name]
else:
adapter_names = fusion_name.split(",")

return config, adapter_names
else:
config = None
return config
return None, None

def add_fusion(self, fusion_name: Union[str, List[str]], config: Optional[Union[str, dict]] = None):
def add_fusion(
self, adapter_names: List[str], config: Optional[Union[str, dict]] = None, fusion_name: Optional[str] = None
):
"""
Adds a new AdapterFusion.

Args:
fusion_name (Union[str, List[str]]): The name of the AdapterFusion or the adapters to fuse.
adapter_names (List[str]): The names of the adapters to fuse.
config (Optional[Union[str, dict]], optional): AdapterFusion config. Defaults to None.
fusion_name (Optional[str], optional): The name of the AdapterFusion. If not specified, will default to comma-separated adapter names.
"""
if isinstance(fusion_name, list):
fusion_name = ",".join(fusion_name)
if fusion_name is None:
fusion_name = ",".join(adapter_names)
else:
self.fusion_name_map[fusion_name] = adapter_names
if fusion_name in self.fusions:
raise ValueError(f"An AdapterFusion with the name '{fusion_name}' has already been added.")
if config is None:
Expand Down Expand Up @@ -218,6 +231,7 @@ def to_dict(self):
output_dict["fusion_config_map"][k] = v.to_dict()
else:
output_dict["fusion_config_map"][k] = copy.deepcopy(v)
output_dict["fusion_name_map"] = copy.deepcopy(self.fusion_name_map)
return output_dict

def __eq__(self, other):
Expand Down
12 changes: 9 additions & 3 deletions src/adapters/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,7 @@ def save_to_state_dict(self, name: str):
if name not in self.model.adapters_config.fusions:
raise ValueError(f"No AdapterFusion with name '{name}' available.")

adapter_fusion_config = self.model.adapters_config.get_fusion(name)
adapter_fusion_config, _ = self.model.adapters_config.get_fusion(name)

config_dict = build_full_config(
adapter_fusion_config,
Expand Down Expand Up @@ -676,13 +676,14 @@ def save(self, save_directory: str, name: str, meta_dict=None):
else:
assert isdir(save_directory), "Saving path should be a directory where the head can be saved."

adapter_fusion_config = self.model.adapters_config.get_fusion(name)
adapter_fusion_config, adapter_names = self.model.adapters_config.get_fusion(name)

# Save the adapter fusion configuration
config_dict = build_full_config(
adapter_fusion_config,
self.model.config,
name=name,
adapter_names=adapter_names,
model_name=self.model.model_name,
model_class=self.model.__class__.__name__,
)
Expand Down Expand Up @@ -746,9 +747,14 @@ def load(self, save_directory, load_as=None, loading_info=None, **kwargs):
config = self.weights_helper.load_weights_config(save_directory)

adapter_fusion_name = load_as or config["name"]
adapter_names = config.get("adapter_names", adapter_fusion_name)
if adapter_fusion_name not in self.model.adapters_config.fusions:
self.model.add_adapter_fusion(
adapter_fusion_name, config["config"], overwrite_ok=True, set_active=kwargs.pop("set_active", True)
adapter_names,
config["config"],
name=adapter_fusion_name,
overwrite_ok=True,
set_active=kwargs.pop("set_active", True),
)
else:
logger.warning("Overwriting existing adapter fusion module '{}'".format(adapter_fusion_name))
Expand Down
8 changes: 4 additions & 4 deletions src/adapters/methods/bottleneck.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,17 +96,17 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool:

def add_fusion_layer(self, adapter_names: Union[List, str]):
    """See BertModel.add_fusion_layer"""
    # Accept either a list of adapter names or an already comma-joined / custom fusion name.
    fusion_name = ",".join(adapter_names) if isinstance(adapter_names, list) else adapter_names
    # get_fusion resolves a custom fusion name back to the underlying adapter names.
    fusion_config, adapter_names = self.adapters_config.get_fusion(fusion_name)
    # Only add the layer if all fused adapters share this location key's config value.
    if self.adapters_config.common_config_value(adapter_names, self.location_key):
        dropout_prob = fusion_config.dropout_prob or getattr(self.model_config, "attention_probs_dropout_prob", 0)
        fusion = BertFusion(
            fusion_config,
            self.model_config.hidden_size,
            dropout_prob,
        )
        fusion.train(self.training)  # make sure training mode is consistent
        # Register under the (possibly custom) fusion name, not the adapter list.
        self.adapter_fusion_layer[fusion_name] = fusion

def delete_fusion_layer(self, adapter_names: Union[List, str]):
adapter_names = adapter_names if isinstance(adapter_names, str) else ",".join(adapter_names)
Expand Down Expand Up @@ -223,7 +223,7 @@ def compose_fuse(self, adapter_setup: Fuse, state: BottleneckState, lvl: int = 0
context = ForwardContext.get_context()

# config of _last_ fused adapter is significant
fusion_config = self.adapters_config.get_fusion(adapter_setup.name)
fusion_config, _ = self.adapters_config.get_fusion(adapter_setup.name)
last = adapter_setup.last()
last_adapter = self.adapters[last]
hidden_states, query, residual = last_adapter.pre_forward(
Expand Down
27 changes: 16 additions & 11 deletions src/adapters/model_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,7 @@ def add_adapter_fusion(
self,
adapter_names: Union[Fuse, list, str],
config=None,
name: str = None,
overwrite_ok: bool = False,
set_active: bool = False,
):
Expand All @@ -645,29 +646,33 @@ def add_adapter_fusion(
- a string identifying a pre-defined adapter fusion configuration
- a dictionary representing the adapter fusion configuration
- the path to a file containing the adapter fusion configuration
name (str, optional):
Name of the AdapterFusion layer. If not specified, the name is generated automatically from the fused adapter names.
overwrite_ok (bool, optional):
Overwrite an AdapterFusion layer with the same name if it exists. By default (False), an exception is
thrown.
set_active (bool, optional):
Activate the added AdapterFusion. By default (False), the AdapterFusion is added but not activated.
"""
if isinstance(adapter_names, Fuse):
if name is None:
name = adapter_names.name
adapter_names = adapter_names.children
elif isinstance(adapter_names, str):
adapter_names = adapter_names.split(",")
if name is None:
name = ",".join(adapter_names)

if isinstance(config, dict):
config = AdapterFusionConfig.from_dict(config) # ensure config is ok and up-to-date
# In case adapter already exists and we allow overwriting, explicitly delete the existing one first
if overwrite_ok and self.adapters_config.get_fusion(adapter_names) is not None:
self.delete_adapter_fusion(adapter_names)
self.adapters_config.add_fusion(adapter_names, config=config)
self.apply_to_adapter_layers(lambda i, layer: layer.add_fusion_layer(adapter_names))
self.apply_to_basemodel_childs(lambda i, child: child.add_fusion_layer(adapter_names))
if overwrite_ok and self.adapters_config.get_fusion(name)[0] is not None:
self.delete_adapter_fusion(name)
self.adapters_config.add_fusion(adapter_names, config=config, fusion_name=name)
self.apply_to_adapter_layers(lambda i, layer: layer.add_fusion_layer(name))
self.apply_to_basemodel_childs(lambda i, child: child.add_fusion_layer(name))
if set_active:
if not isinstance(adapter_names, list):
adapter_names = adapter_names.split(",")
self.set_active_adapters(Fuse(*adapter_names))
self.set_active_adapters(Fuse(*adapter_names, name=name))

def delete_adapter(self, adapter_name: str):
"""
Expand Down Expand Up @@ -700,7 +705,7 @@ def delete_adapter_fusion(self, adapter_names: Union[Fuse, list, str]):
adapter_names (Union[Fuse, list, str]): AdapterFusion layer to delete.
"""
if isinstance(adapter_names, Fuse):
adapter_fusion_name = ",".join(adapter_names.children)
adapter_fusion_name = adapter_names.name
elif isinstance(adapter_names, list):
adapter_fusion_name = ",".join(adapter_names)
elif isinstance(adapter_names, str):
Expand Down Expand Up @@ -766,7 +771,7 @@ def save_adapter_fusion(
ValueError: If the given AdapterFusion name is invalid.
"""
if isinstance(adapter_names, Fuse):
adapter_fusion_name = ",".join(adapter_names.children)
adapter_fusion_name = adapter_names.name
elif isinstance(adapter_names, list):
adapter_fusion_name = ",".join(adapter_names)
elif isinstance(adapter_names, str):
Expand Down Expand Up @@ -929,7 +934,7 @@ def save_all_adapter_fusions(
"""
os.makedirs(save_directory, exist_ok=True)
for name in self.adapters_config.fusions:
adapter_fusion_config = self.adapters_config.get_fusion(name)
adapter_fusion_config, _ = self.adapters_config.get_fusion(name)
h = get_adapter_config_hash(adapter_fusion_config)
save_path = join(save_directory, name)
if meta_dict:
Expand Down
83 changes: 83 additions & 0 deletions tests/test_adapter_fusion_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,3 +214,86 @@ def test_output_adapter_fusion_attentions(self):
self.assertEqual(len(per_layer_scores), 1)
for k, v in per_layer_scores.items():
self.assertEqual(self.default_input_samples_shape[0], v.shape[0], k)

def test_add_adapter_fusion_custom_name(self):
    """A custom-named fusion coexists with the default-named fusion over the same
    adapters, produces distinct outputs, and can be deleted independently."""
    config_name = "seq_bn"
    model = self.get_model()
    model.eval()

    name1 = f"{config_name}-1"
    name2 = f"{config_name}-2"
    model.add_adapter(name1, config=config_name)
    model.add_adapter(name2, config=config_name)

    # adapter is correctly added to config
    self.assertTrue(name1 in model.adapters_config)
    self.assertTrue(name2 in model.adapters_config)

    # add fusion with default name
    model.add_adapter_fusion([name1, name2])
    model.to(torch_device)

    # check forward pass
    input_data = self.get_input_samples(config=model.config)
    model.set_active_adapters(Fuse(name1, name2))
    fusion_default_ref_output = model(**input_data)

    # add fusion with custom name
    model.add_adapter_fusion([name1, name2], name="custom_name_fusion")
    model.to(torch_device)

    self.assertIn(f"{name1},{name2}", model.adapters_config.fusions)
    self.assertIn("custom_name_fusion", model.adapters_config.fusions)
    self.assertIn("custom_name_fusion", model.adapters_config.fusion_name_map)

    # check forward pass: custom fusion, default fusion, and no adapters must all differ
    model.set_active_adapters(Fuse(name1, name2, name="custom_name_fusion"))
    fusion_custom_output = model(**input_data)
    model.set_active_adapters(Fuse(name1, name2))
    fusion_default_output = model(**input_data)
    model.set_active_adapters(None)
    base_output = model(**input_data)

    self.assertFalse(torch.equal(fusion_default_ref_output[0], base_output[0]))
    # default fusion is unaffected by adding the custom one
    self.assertTrue(torch.equal(fusion_default_ref_output[0], fusion_default_output[0]))
    self.assertFalse(torch.equal(fusion_custom_output[0], fusion_default_output[0]))
    self.assertFalse(torch.equal(fusion_custom_output[0], base_output[0]))

    # delete only the custom fusion; the default-named one must survive
    model.delete_adapter_fusion(Fuse(name1, name2, name="custom_name_fusion"))

    self.assertIn(f"{name1},{name2}", model.adapters_config.fusions)
    self.assertNotIn("custom_name_fusion", model.adapters_config.fusions)

def test_load_adapter_fusion_custom_name(self):
    """Saving a custom-named fusion from one model and loading it into a second
    model reproduces both the fusion config and the forward-pass output."""
    source = self.get_model()
    source.eval()

    adapter_a = "name1"
    adapter_b = "name2"
    source.add_adapter(adapter_a)
    source.add_adapter(adapter_b)

    # Clone before the fusion exists so the target starts without it.
    target = copy.deepcopy(source)
    target.eval()

    source.add_adapter_fusion([adapter_a, adapter_b], name="custom_name_fusion")
    source.set_active_adapters(Fuse(adapter_a, adapter_b, name="custom_name_fusion"))

    with tempfile.TemporaryDirectory() as tmp:
        source.save_adapter_fusion(tmp, "custom_name_fusion")
        # also tests that set_active works
        target.load_adapter_fusion(tmp, set_active=True)

    # check if adapter was correctly loaded
    self.assertEqual(source.adapters_config.fusions.keys(), target.adapters_config.fusions.keys())

    # check equal output
    batch = self.get_input_samples(config=source.config)
    source.to(torch_device)
    target.to(torch_device)
    out_source = source(**batch)
    out_target = target(**batch)
    self.assertEqual(len(out_source), len(out_target))
    self.assertTrue(torch.equal(out_source[0], out_target[0]))
Loading