diff --git a/src/transformers/integrations/peft.py b/src/transformers/integrations/peft.py index f52b5cf1d58e..998af1491cff 100644 --- a/src/transformers/integrations/peft.py +++ b/src/transformers/integrations/peft.py @@ -15,9 +15,9 @@ import inspect import json import os -import re from typing import Any, Literal +from ..core_model_loading import WeightRenaming, build_glob_alternation, repl from ..utils import ( CONFIG_NAME, cached_file, @@ -302,12 +302,10 @@ def load_adapter( else: new_key = key - if key_mapping: # TODO dynamic weight loader for adapters - for pattern, replacement in key_mapping.items(): - new_key, n_replace = re.subn(pattern, replacement, new_key) - # Early exit of the loop - if n_replace > 0: - break + if key_mapping: + renamings = [entry for entry in key_mapping if isinstance(entry, WeightRenaming)] + rename_alt, _, rename_by_group = build_glob_alternation(renamings) + new_key = rename_alt.sub(lambda m: repl(m, rename_by_group), new_key).replace("\\", "") # For hotswapping, we need the adapter name to be present in the state dict keys if hotswap: