Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding LoKrModel Class to paddle.peft library #9269

Merged
merged 16 commits into from
Nov 27, 2024

Conversation

WhuanY
Copy link
Contributor

@WhuanY WhuanY commented Oct 15, 2024

PR types

New features

PR changes

Others

Description

Adding LoKrModel, LoKrLinear and LoKrConfig to support a new lora-like adapter. Current implementation only supports contains Linear Modules. Motivation and discussion on such PR issue is at: #9226

Please provide suggestions on the current implementation!

Copy link

paddle-bot bot commented Oct 15, 2024

Thanks for your contribution!

Copy link

codecov bot commented Oct 15, 2024

Codecov Report

Attention: Patch coverage is 80.61798% with 69 lines in your changes missing coverage. Please review.

Project coverage is 53.03%. Comparing base (f5ca96e) to head (ec91282).
Report is 4 commits behind head on develop.

Files with missing lines Patch % Lines
paddlenlp/peft/lokr/lokr_model.py 80.10% 38 Missing ⚠️
paddlenlp/peft/lokr/lokr_layers.py 71.28% 29 Missing ⚠️
paddlenlp/trainer/trainer.py 33.33% 2 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #9269      +/-   ##
===========================================
- Coverage    53.10%   53.03%   -0.08%     
===========================================
  Files          692      694       +2     
  Lines       110570   110254     -316     
===========================================
- Hits         58715    58470     -245     
+ Misses       51855    51784      -71     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@WhuanY
Copy link
Contributor Author

WhuanY commented Oct 16, 2024

@DesmonDay

按照建议我已经提交了只有Linear Layer的LoKr实现供你们整体查看。麻烦有时间审阅下并给出需要修改的意见。

@greycooker
Copy link
Contributor

好的收到,我们这边会尽快review


# This module is set to be in alignment with code design paradiam of ...utils.env

LOKR_WEIGHTS_NAME = "lokr_model_state.pdparams"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done this

@@ -0,0 +1,19 @@
# Copyright 2023-present the HuggingFace Inc. team.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

copyright注意修改

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done this

def add_lora_split_mapping(self, module_name, is_column=False):
self.lora_split_mapping[module_name] = is_column

def _get_tensor_parallel_mappings(self, config, is_split=True):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

没有使用tensor_parallel 和pipeline parallel先把没用到的相关的逻辑删掉

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done this

self.quantized = True
if lokr_module is None:
raise ValueError("LoKr strategy only supports paddle.nn.Linear right now")
if getattr(lokr_module, "quant_weight", None) is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

quant相关逻辑没用到的也先删掉

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done this

@WhuanY
Copy link
Contributor Author

WhuanY commented Oct 17, 2024

辛苦了!第一次参与开源。可能错误较多,我会尽早修改问题,重新提交,供你们审阅。 Have a good day!

@lugimzzz lugimzzz closed this Oct 22, 2024
@lugimzzz lugimzzz reopened this Oct 22, 2024
@lugimzzz
Copy link
Contributor

lugimzzz commented Oct 22, 2024

辛苦了!第一次参与开源。可能错误较多,我会尽早修改问题,重新提交,供你们审阅。 Have a good day!

感谢对PaddleNLP的贡献,我们非常欢迎社区开发者参与到PaddleNLP的开发中来。我会在重新提交代码后尽快进行review,期待提交的代码能早日合入到项目中!

@lugimzzz
Copy link
Contributor

可以再次review请告知我,我会尽快开始review

@WhuanY
Copy link
Contributor Author

WhuanY commented Oct 30, 2024

可以再次review请告知我,我会尽快开始review

按照要求我已经

  1. 去掉了暂时不涉及的并行逻辑;
  2. 增加了disable_lokr参数和相应办法
  3. 增加了test/peft/lokr_model.py,并通过了基本的测试;
  4. 根据test过程中发现的bug更新了部分LoKrLinear,包括重置初始化方式、更正前向传播Bug。

目前我可以想到的接下来可以做的是:

  1. 在unified_checkpoint中支持LoKrModel
  2. 增加该适配器的合并参数脚本
    如有问题和改进方向请说明,辛苦了!

"model_name_or_path": "meta-llama/Meta-Llama-3-8B",
"dataset_name_or_path": "./data",
"output_dir": "./checkpoints/lokr_ckpts",
"lokr": true,
Copy link
Contributor

@lugimzzz lugimzzz Nov 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1.代码还未在https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/run_finetune.py 文件中添加相应的代码,没有看到lokr设为true时执行的逻辑。
2.添加后请相应同步文档,lokr运行方式以及对应新增参数的解释https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/docs/finetune.md
3.模仿lora添加llm单测 https://github.com/PaddlePaddle/PaddleNLP/blob/develop/tests/llm/test_lora.py
4.请参考vera和lora的脚本新增一个 merge_lokr_params.py https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/tools/merge_vera_params.py

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

再次感谢对开源代码的贡献,代码库中LoKr算法实现没有问题,补充完大模型应用样例即可合入PaddleNLP

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的,收到!争取这周末完成上述四点,到时候我在远程仓库提交

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

all done~

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

完成

@WhuanY
Copy link
Contributor Author

WhuanY commented Nov 16, 2024

可以再次review请告知我,我会尽快开始review

你好,应该可以开始review了,最近项目私下做了测试,通过功能跨平台精度对齐的任务证明已经没有算法错误

@lugimzzz
Copy link
Contributor

关注一下单测覆盖率,PaddleNLP-CI报错看起来是网络问题,我rerun了。这两个问题解决就可以合入了
image

@lugimzzz
Copy link
Contributor

lugimzzz commented Nov 22, 2024

需要解决一下冲突和单测覆盖率,即可合入 @WhuanY

@WhuanY
Copy link
Contributor Author

WhuanY commented Nov 27, 2024

你好,冲突和单测问题已经解决~看看还有什么需要修正的吗? @lugimzzz

Copy link
Contributor

@lugimzzz lugimzzz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!感谢对飞桨开源框架贡献❤️!

@lugimzzz lugimzzz merged commit 3ef14dc into PaddlePaddle:develop Nov 27, 2024
10 of 12 checks passed
@WhuanY
Copy link
Contributor Author

WhuanY commented Nov 27, 2024

My pleasure❤️😄

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants