From 628d84ea3adeb41956b832fb9b1e7c1b73aab553 Mon Sep 17 00:00:00 2001 From: Tugrul Konuk Date: Wed, 13 Dec 2023 22:30:51 -0600 Subject: [PATCH] Fixed bug with ddp training. --- nemo/collections/nlp/parts/peft_config.py | 42 ++++++++++++----------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/nemo/collections/nlp/parts/peft_config.py b/nemo/collections/nlp/parts/peft_config.py index d097f780a97d..abd6aec4e201 100644 --- a/nemo/collections/nlp/parts/peft_config.py +++ b/nemo/collections/nlp/parts/peft_config.py @@ -50,28 +50,30 @@ "all": "all", } - def get_target_modules(lora_cfg): - target_modules = lora_cfg.get("target_modules", ["attention_qkv"]) - if PEFT_MODULE_MAP["attention"] in target_modules: - target_modules.extend([PEFT_MODULE_MAP['qkv_module'], PEFT_MODULE_MAP['dense_module']]) - target_modules.remove(PEFT_MODULE_MAP["attention"]) - if PEFT_MODULE_MAP["mlp"] in target_modules: - target_modules.extend([PEFT_MODULE_MAP['hto4h_module'], PEFT_MODULE_MAP['4htoh_module']]) - target_modules.remove(PEFT_MODULE_MAP["mlp"]) - if PEFT_MODULE_MAP["all"] in target_modules: - target_modules.extend( - [ - PEFT_MODULE_MAP['qkv_module'], - PEFT_MODULE_MAP['dense_module'], - PEFT_MODULE_MAP['hto4h_module'], - PEFT_MODULE_MAP['4htoh_module'], - ] - ) - target_modules.remove(PEFT_MODULE_MAP["all"]) - - return list(set(target_modules)) # only return unique modules/elements + original_target_modules = lora_cfg.get("target_modules", ["attention_qkv"]) + target_modules = [] + + for module in original_target_modules: + if module == PEFT_MODULE_MAP["attention"]: + if PEFT_MODULE_MAP['qkv_module'] not in target_modules: + target_modules.append(PEFT_MODULE_MAP['qkv_module']) + if PEFT_MODULE_MAP['dense_module'] not in target_modules: + target_modules.append(PEFT_MODULE_MAP['dense_module']) + elif module == PEFT_MODULE_MAP["mlp"]: + if PEFT_MODULE_MAP['hto4h_module'] not in target_modules: + target_modules.append(PEFT_MODULE_MAP['hto4h_module']) + if PEFT_MODULE_MAP['4htoh_module'] not in target_modules: + target_modules.append(PEFT_MODULE_MAP['4htoh_module']) + elif module == PEFT_MODULE_MAP["all"]: + for sub_module in [PEFT_MODULE_MAP['qkv_module'], PEFT_MODULE_MAP['dense_module'], PEFT_MODULE_MAP['hto4h_module'], PEFT_MODULE_MAP['4htoh_module']]: + if sub_module not in target_modules: + target_modules.append(sub_module) + else: + if module not in target_modules: + target_modules.append(module) + return target_modules class PEFTConfig: # superclass for adapter name and config