merge peft params (PaddlePaddle#6678)
Co-authored-by: DesmonDay <[email protected]>
lugimzzz and DesmonDay authored Aug 10, 2023
1 parent f59274a commit 8398a5f
Showing 4 changed files with 19 additions and 11 deletions.
4 changes: 2 additions & 2 deletions docs/peft.md
@@ -130,7 +130,7 @@ key function:
--save_directory
Path of the directory to save to.
--merge_tensor_parallel
- Whether to merge tensor-parallel (multi-GPU) parameters.
+ Whether to merge tensor-parallel (multi-GPU) parameters; defaults to True.

If merge_tensor_parallel is True and the tensor parallel degree in the model config is greater than 1, the trainable state_dict is retrieved and the tensor-parallel state_dict is merged via the _merge_trainable_tensor_parallel method; in that case only the main process performs the save.

@@ -253,7 +253,7 @@ key function
--save_directory
Path of the directory to save to.
--merge_tensor_parallel
- Whether to merge tensor-parallel (multi-GPU) parameters.
+ Whether to merge tensor-parallel (multi-GPU) parameters; defaults to True.

If merge_tensor_parallel is True and the tensor parallel degree in the model config is greater than 1, the trainable state_dict is retrieved and the tensor-parallel state_dict is merged via the _merge_trainable_tensor_parallel method; in that case only the main process performs the save.

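With this commit, LoRAModel.save_pretrained and PrefixModelForCausalLM.save_pretrained merge tensor-parallel shards by default. A minimal usage sketch, assuming an already-built LoRA model bound to the name lora_model (the variable name and paths below are illustrative, not from the commit):

# merge_tensor_parallel now defaults to True, so shards are merged into a single checkpoint.
lora_model.save_pretrained("./checkpoints/lora")
# Pass merge_tensor_parallel=False explicitly to keep per-rank (sharded) weights instead.
lora_model.save_pretrained("./checkpoints/lora_sharded", merge_tensor_parallel=False)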
4 changes: 2 additions & 2 deletions paddlenlp/peft/lora/lora_model.py
@@ -183,7 +183,7 @@ def _convert_tensor_parallel(self, lora_state_dict):
lora_state_dict[name] = action(tensor)
return lora_state_dict

-    def save_pretrained(self, save_directory: str, merge_tensor_parallel: bool = False, **kwargs):
+    def save_pretrained(self, save_directory: str, merge_tensor_parallel: bool = True, **kwargs):
variant = kwargs.get("variant", None)
is_main_process = kwargs.get("is_main_process", paddle.distributed.get_rank() == 0)

@@ -214,7 +214,7 @@ def save_pretrained(self, save_directory: str, merge_tensor_parallel: bool = Fal
# save lora config
if is_main_process:
self.lora_config.save_pretrained(save_directory)
-            self.lora_config.tensor_parallel_degree = self.model.config.tensor_parallel_degree
+        self.lora_config.tensor_parallel_degree = self.model.config.tensor_parallel_degree

def _find_and_replace_module(self, model, module_name, lora_config, enable_lora):
parent_module = model
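The indentation change above moves the restore of lora_config.tensor_parallel_degree out of the is_main_process block. A minimal sketch of the resulting flow, assuming (as the restore itself implies) that save_pretrained overwrites the in-memory tensor_parallel_degree earlier while merging shards:

if is_main_process:
    self.lora_config.save_pretrained(save_directory)  # only the main rank writes lora_config to disk
# the restore now runs on every rank, so all workers leave save_pretrained with a consistent config
self.lora_config.tensor_parallel_degree = self.model.config.tensor_parallel_degree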
8 changes: 6 additions & 2 deletions paddlenlp/peft/prefix/prefix_model.py
@@ -306,7 +306,7 @@ def from_pretrained(

return prefix_model

-    def save_pretrained(self, save_directory: str, merge_tensor_parallel: bool = False, **kwargs):
+    def save_pretrained(self, save_directory: str, merge_tensor_parallel: bool = True, **kwargs):
variant = kwargs.get("variant", None)
is_main_process = kwargs.get("is_main_process", paddle.distributed.get_rank() == 0)

@@ -352,9 +352,13 @@ def save_pretrained(self, save_directory: str, merge_tensor_parallel: bool = Fal
# save prefix config & past key values
if is_main_process:
self.prefix_config.save_pretrained(save_directory)
-            self.prefix_config.tensor_parallel_degree = self.model.config.tensor_parallel_degree
np.save(os.path.join(save_directory, PAST_KEY_VALUES_FILE_NAME), past_key_values)

+        if self.model.base_model_prefix == "chatglm2":
+            self.prefix_config.tensor_parallel_degree = -1
+        else:
+            self.prefix_config.tensor_parallel_degree = self.model.config.tensor_parallel_degree

def set_state_dict(self, state_dict):
self.prefix_encoder.set_state_dict(state_dict)
logger.info("Load prefix weight successfully")
14 changes: 9 additions & 5 deletions paddlenlp/trainer/trainer.py
@@ -1775,11 +1775,15 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_

merge_tensor_parallel = merge_tensor_parallel and self.args.use_hybrid_parallel

-        if (
-            not isinstance(self.model, PretrainedModel)
-            and not isinstance(self.model, LoRAModel)
-            and not isinstance(self.model, PrefixModelForCausalLM)
-        ):
+        if isinstance(self.model, LoRAModel) or isinstance(self.model, PrefixModelForCausalLM):
+            # lugimzzz: Force merge_tensor_parallel to True for LoRA & Prefix Model until there is an option to merge params during training.
+            self.model.save_pretrained(
+                output_dir,
+                variant=self.args.weight_name_suffix,
+                merge_tensor_parallel=True,
+                is_main_process=self.args.should_save,
+            )
+        elif not isinstance(self.model, PretrainedModel):
if isinstance(unwrap_model(self.model), PretrainedModel):
unwrap_model(self.model).save_pretrained(
output_dir,
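A condensed sketch of the dispatch in Trainer._save after this hunk (surrounding error handling and the remaining branches of the method are omitted, and argument lists are abbreviated):

if isinstance(self.model, (LoRAModel, PrefixModelForCausalLM)):
    # PEFT checkpoints always merge tensor-parallel shards when saved from the trainer
    self.model.save_pretrained(
        output_dir,
        variant=self.args.weight_name_suffix,
        merge_tensor_parallel=True,
        is_main_process=self.args.should_save,
    )
elif not isinstance(self.model, PretrainedModel):
    # wrapped models: unwrap and delegate when the inner model is a PretrainedModel
    if isinstance(unwrap_model(self.model), PretrainedModel):
        unwrap_model(self.model).save_pretrained(output_dir, ...)  # remaining kwargs as in the diff above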
