tune specific params in the base model #7745

Merged · 10 commits · Jan 16, 2024
@@ -111,6 +111,9 @@ model:

ia3_tuning:
layer_selection: null # selects in which layers to add ia3 adapters. e.g. [1,12] will add ia3 adapters to layers 1 (lowest) and 12. null will apply adapters to all layers

selective_tuning:
tunable_base_param_names: ["self_attention", "word_embeddings"] # TODO: add regex support @adithyare

data:
train_ds:
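
Note: the mixin change further down in this diff matches these entries as plain substrings of the full dotted parameter names (regex support is still a TODO). A rough, self-contained sketch of that behavior with made-up megatron-style names, purely for illustration:

# Hypothetical parameter names, loosely megatron-style; for illustration only.
param_names = [
    "model.language_model.embedding.word_embeddings.weight",
    "model.language_model.encoder.layers.0.self_attention.query_key_value.weight",
    "model.language_model.encoder.layers.0.mlp.dense_h_to_4h.weight",
]
tunable_base_param_names = ["self_attention", "word_embeddings"]

# Same test used in add_adapter() below: ".{name}." must appear inside the dotted parameter name.
tunable = {
    n for n in param_names
    if any(f".{tpn}." in n for tpn in tunable_base_param_names)
}
# -> the word_embeddings and self_attention weights match; the MLP weight stays frozen.
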
24 changes: 21 additions & 3 deletions nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py
@@ -63,6 +63,7 @@ class NLPAdapterModelMixin:

def __init__(self, *args, **kwargs):
self.use_peft = False
self.tunable_base_param_keys = set()
self.setup_complete = False
self.use_ptuning_only = False
super().__init__(*args, **kwargs)
@@ -192,10 +193,21 @@ def add_adapter(self, peft_cfgs: Union[PEFTConfig, List[PEFTConfig]]):

logging.info(f"After adding PEFT params:\n{self.summarize()}")
self.adapter_keys = self._get_all_keys() - self.base_keys
self.tunable_base_param_keys = set()

for cfg in peft_cfgs:
if cfg.weight_tying:
self.tie_weights(cfg)

if cfg.tunable_base_param_names:
for n, p in self.named_parameters():
for tpn in cfg.tunable_base_param_names:
if (
f".{tpn}." in n
): # TODO: simplistic param name matching, should support regex-like syntax @adithyare
self.tunable_base_param_keys.add(n)
p.requires_grad = True  # set requires_grad=True so these params are included in setup_optimizer_param_groups

self.use_peft = True

def _get_config_and_state_dict_from_nemo(self, filepath, map_location):
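
After add_adapter() has run, the selected base weights are trainable alongside the adapter weights. A rough sanity-check sketch, assuming `model` is a NeMo model carrying this mixin (the variable name is an assumption, not part of this diff):

# Sketch: inspect which tensors are trainable after add_adapter().
# `model` is assumed to be an already-built NeMo model that has run add_adapter(); not defined here.
trainable = {n for n, p in model.named_parameters() if p.requires_grad}
frozen = {n for n, p in model.named_parameters() if not p.requires_grad}
print(f"trainable: {len(trainable)} tensors, frozen: {len(frozen)} tensors")

# Every selected base parameter should be in the trainable set, next to the unfrozen adapter params.
assert model.tunable_base_param_keys <= trainable
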
@@ -239,6 +251,12 @@ def setup_optimizer_param_groups(self):
module.set_enabled_adapters(enabled=True)
module.unfreeze_enabled_adapters() # selectively unfreeze the adapter modules.
opt_params += [p for p in module.parameters() if p.requires_grad]

for name, param in self.named_parameters():
if name in self.tunable_base_param_keys:
param.requires_grad = True
opt_params += [param]

self._optimizer_param_groups = ({"params": opt_params},)
logging.info(f"Optimizer groups set:\n{self.summarize()}")
else:
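
The loop above folds the matched base parameters into the same single param group as the adapter weights. A minimal, self-contained sketch of what that group shape looks like to a torch optimizer (the stand-in tensor is illustrative only):

import torch

# Sketch: the single param group built above is what a torch optimizer consumes.
opt_params = [torch.nn.Parameter(torch.zeros(2, 2))]  # stands in for adapter + selected base params
optimizer_param_groups = ({"params": opt_params},)    # same shape as self._optimizer_param_groups
optimizer = torch.optim.AdamW(optimizer_param_groups, lr=1e-4)
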
@@ -282,7 +300,7 @@ def load_adapters(
), "Inferring peft scheme is only supported for .nemo checkpoints. Please supply the `peft_cfgs` argument."
peft_cfgs = [PEFT_CONFIG_MAP[conf.peft.peft_scheme](conf)]
self.add_adapter(peft_cfgs)
assert set(state_dict.keys()) == self.adapter_keys
assert set(state_dict.keys()) == self.adapter_keys.union(self.tunable_base_param_keys)
super().load_state_dict(state_dict, strict=False)

def tie_weights(self, peft_cfg):
@@ -328,7 +346,7 @@ def get_peft_state_dict(self):
"""
state_dict = self.model.state_dict(prefix=self.model_prefix)
peft_state_dict = {}
for k in self.adapter_keys:
for k in self.adapter_keys.union(self.tunable_base_param_keys):
# state_dict keys need to be in non-O2 format and will be corrected in PEFTSaveRestoreConnector if O2=True
new_k = k.replace("model.module.", "model.", 1)
peft_state_dict[new_k] = state_dict[k]
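
So the saved PEFT weights now include the selected base parameters alongside the adapter weights, with the megatron-O2 "model.module." prefix normalized back to "model.". A self-contained sketch of that key handling with toy tensors (the key names are illustrative, not real NeMo keys):

import torch

# Toy stand-ins for the real model state; key names are illustrative only.
full_state_dict = {
    "model.module.encoder.layers.0.self_attention.query_key_value.weight": torch.zeros(2, 2),
    "model.module.encoder.layers.0.adapter_layer.lora_kqv_adapter.linear_in.weight": torch.zeros(2, 2),
    "model.module.encoder.layers.0.mlp.dense_h_to_4h.weight": torch.zeros(2, 2),
}
adapter_keys = {"model.module.encoder.layers.0.adapter_layer.lora_kqv_adapter.linear_in.weight"}
tunable_base_param_keys = {"model.module.encoder.layers.0.self_attention.query_key_value.weight"}

# Same idea as get_peft_state_dict(): save the union of both key sets,
# rewriting the O2-style "model.module." prefix back to "model.".
peft_state_dict = {
    k.replace("model.module.", "model.", 1): full_state_dict[k]
    for k in adapter_keys | tunable_base_param_keys
}
# load_adapters()/load_state_dict() later assert that a checkpoint's keys equal exactly this union.
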
@@ -360,7 +378,7 @@ def load_state_dict(self, state_dict, strict: bool = True):
# setting strict=False will ignore the missing keys (which are not being updated anyway)
# explicitly check if state_dict.keys matches all the expected self.adapter_keys since we don't have the
# safety in strict=True anymore.
assert set(state_dict.keys()) == self.adapter_keys
assert set(state_dict.keys()) == self.adapter_keys.union(self.tunable_base_param_keys)
super().load_state_dict(state_dict, strict=False)
else:
super().load_state_dict(state_dict, strict=True)
8 changes: 8 additions & 0 deletions nemo/collections/nlp/parts/peft_config.py
@@ -45,11 +45,18 @@ def __init__(self, peft_cfg: DictConfig, name_key_to_cfg: Dict):

self.layer_selection = peft_cfg.get("layer_selection", None)
self.weight_tying = peft_cfg.get("weight_tying", False)
self.tunable_base_param_names = peft_cfg.get("tunable_base_param_names", [])

def get_config_dict(self):
return self.name_key_to_cfg


class SelectivePEFTConfig(PEFTConfig):
def __init__(self, cfg):
selective_cfg = cfg.peft.selective_tuning
super().__init__(selective_cfg, name_key_to_cfg={})


class LoraPEFTConfig(PEFTConfig):
def __init__(self, cfg):
lora_cfg = cfg.peft.lora_tuning
@@ -195,6 +202,7 @@ def __init__(self, cfg):
"ia3": IA3PEFTConfig,
"ptuning": PtuningPEFTConfig,
"lora": LoraPEFTConfig,
"selective": SelectivePEFTConfig,
'none': None,
None: None,
}