Skip to content

Commit

Permalink
Fixed bug with ddp training.
Browse files Browse the repository at this point in the history
  • Loading branch information
ertkonuk committed Dec 14, 2023
1 parent 635b014 commit 628d84e
Showing 1 changed file with 22 additions and 20 deletions.
42 changes: 22 additions & 20 deletions nemo/collections/nlp/parts/peft_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,28 +50,30 @@
"all": "all",
}


def get_target_modules(lora_cfg):
target_modules = lora_cfg.get("target_modules", ["attention_qkv"])
if PEFT_MODULE_MAP["attention"] in target_modules:
target_modules.extend([PEFT_MODULE_MAP['qkv_module'], PEFT_MODULE_MAP['dense_module']])
target_modules.remove(PEFT_MODULE_MAP["attention"])
if PEFT_MODULE_MAP["mlp"] in target_modules:
target_modules.extend([PEFT_MODULE_MAP['hto4h_module'], PEFT_MODULE_MAP['4htoh_module']])
target_modules.remove(PEFT_MODULE_MAP["mlp"])
if PEFT_MODULE_MAP["all"] in target_modules:
target_modules.extend(
[
PEFT_MODULE_MAP['qkv_module'],
PEFT_MODULE_MAP['dense_module'],
PEFT_MODULE_MAP['hto4h_module'],
PEFT_MODULE_MAP['4htoh_module'],
]
)
target_modules.remove(PEFT_MODULE_MAP["all"])

return list(set(target_modules)) # only return unique modules/elements
original_target_modules = lora_cfg.get("target_modules", ["attention_qkv"])
target_modules = []

for module in original_target_modules:
if module == PEFT_MODULE_MAP["attention"]:
if PEFT_MODULE_MAP['qkv_module'] not in target_modules:
target_modules.append(PEFT_MODULE_MAP['qkv_module'])
if PEFT_MODULE_MAP['dense_module'] not in target_modules:
target_modules.append(PEFT_MODULE_MAP['dense_module'])
elif module == PEFT_MODULE_MAP["mlp"]:
if PEFT_MODULE_MAP['hto4h_module'] not in target_modules:
target_modules.append(PEFT_MODULE_MAP['hto4h_module'])
if PEFT_MODULE_MAP['4htoh_module'] not in target_modules:
target_modules.append(PEFT_MODULE_MAP['4htoh_module'])
elif module == PEFT_MODULE_MAP["all"]:
for sub_module in [PEFT_MODULE_MAP['qkv_module'], PEFT_MODULE_MAP['dense_module'], PEFT_MODULE_MAP['hto4h_module'], PEFT_MODULE_MAP['4htoh_module']]:
if sub_module not in target_modules:
target_modules.append(sub_module)
else:
if module not in target_modules:
target_modules.append(module)

return target_modules

class PEFTConfig:
# superclass for adapter name and config
Expand Down

0 comments on commit 628d84e

Please sign in to comment.