-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding LoKrModel Class to paddle.peft library #9269
Merged
Merged
Changes from all commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
21f222b
passing pre-commit
WhuanY 8054d8e
removing tp and pp logic for single gpu training
WhuanY 0e5a844
add disable_lokr attribute in lokr_layer
WhuanY 93f62e7
refine comments
WhuanY 67aba6c
add lokr tests and modified layer bug
WhuanY 3e2703d
add lokrtests
WhuanY aeaa619
add lokrtests
WhuanY 7860fed
add lokr_argument.json
WhuanY 98000cd
add integration test, fix bugs based on tests.
WhuanY 215beaa
refactor lora_dim to lokr_dim
WhuanY 4723293
no inference
WhuanY 0d72948
add more tests
WhuanY f74a545
resolve merge conflict
WhuanY d339509
Merge branch 'develop' into LoKrModel
WhuanY 056341f
add more randtests
WhuanY ec91282
pass isort check(maybe)
WhuanY File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
{ | ||
"model_name_or_path": "meta-llama/Meta-Llama-3-8B", | ||
"dataset_name_or_path": "./data", | ||
"output_dir": "./checkpoints/lokr_ckpts", | ||
"lokr": true, | ||
"per_device_train_batch_size": 4, | ||
"gradient_accumulation_steps": 4, | ||
"num_train_epochs": 1, | ||
"learning_rate": 2e-05, | ||
"lr_scheduler_type": "linear", | ||
"attention_probs_dropout_prob": 0, | ||
"hidden_dropout_prob": 0, | ||
"warmup_steps": 30, | ||
"logging_steps": 1, | ||
"evaluation_strategy": "no", | ||
"save_strategy": "steps", | ||
"save_steps": 500, | ||
"src_length": 512, | ||
"max_length": 512, | ||
"bf16": true, | ||
"do_train": true, | ||
"do_eval": false, | ||
"disable_tqdm": false, | ||
"load_best_model_at_end": false, | ||
"eval_with_do_generation": false, | ||
"metric_for_best_model": "accuracy", | ||
"recompute": false, | ||
"save_total_limit": 100, | ||
"fp16_opt_level": "O2", | ||
"sharding": "stage2", | ||
"zero_padding": false, | ||
"use_flash_attention": false, | ||
"unified_checkpoint": true | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import argparse | ||
import os | ||
|
||
import paddle | ||
|
||
from paddlenlp.peft import LoKrConfig, LoKrModel | ||
from paddlenlp.transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer | ||
from paddlenlp.utils.env import CONFIG_NAME | ||
|
||
|
||
def parse_arguments(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--model_name_or_path", default=None, help="The directory of pretrained model.") | ||
parser.add_argument("--lokr_path", default="", help="The directory of lokr parameters. Default to None") | ||
parser.add_argument( | ||
"--merge_lokr_model_path", | ||
default="", | ||
help="The directory of merged parameters. Default to None", | ||
) | ||
parser.add_argument("--device", type=str, default="gpu", help="Device") | ||
parser.add_argument( | ||
"--low_gpu_mem", type=bool, default=True, help="Whether to use low gpu memory. Default to False" | ||
) | ||
return parser.parse_args() | ||
|
||
|
||
def weight_process(name, lokr_config, state_dict): | ||
weight = state_dict.pop(name + ".weight") | ||
use_w1 = True if ((name + ".lokr_w1") in state_dict) else False | ||
use_w2 = True if ((name + ".lokr_w2") in state_dict) else False | ||
if use_w1: | ||
lokr_w1 = state_dict.pop(name + ".lokr_w1") | ||
else: | ||
lokr_w1_a = state_dict.pop(name + ".lokr_w1_a") | ||
lokr_w1_b = state_dict.pop(name + ".lokr_w1_b") | ||
if use_w2: | ||
lokr_w2 = state_dict.pop(name + ".lokr_w2") | ||
else: | ||
lokr_w2_a = state_dict.pop(name + ".lokr_w2_a") | ||
lokr_w2_b = state_dict.pop(name + ".lokr_w2_b") | ||
|
||
scaling = lokr_config.lokr_alpha / lokr_config.lokr_dim | ||
|
||
adapter_weight = ( | ||
scaling | ||
* paddle.kron(lokr_w1 if use_w1 else lokr_w1_a @ lokr_w1_b, lokr_w2 if use_w2 else lokr_w2_a @ lokr_w2_b).T | ||
) | ||
state_dict[name + ".weight"] = weight + adapter_weight | ||
|
||
|
||
def merge(): | ||
args = parse_arguments() | ||
paddle.set_device(args.device) | ||
|
||
lokr_config = LoKrConfig.from_pretrained(args.lokr_path) | ||
if lokr_config.base_model_name_or_path is None: | ||
if args.model_name_or_path is not None: | ||
raise ValueError("We can not find a valid model_name_or_path.") | ||
else: | ||
lokr_config.base_model_name_or_path = args.model_name_or_path | ||
|
||
if os.path.isfile(os.path.join(args.lokr_path, CONFIG_NAME)): | ||
config = AutoConfig.from_pretrained(args.lokr_path) | ||
elif args.model_name_or_path is not None: | ||
config = AutoConfig.from_pretrained(args.model_name_or_path) | ||
else: | ||
raise ValueError( | ||
f"We can not find config.json in lokr_path: {args.lokr_path} or find a valid model_name_or_path." | ||
) | ||
config.dtype = lokr_config.dtype | ||
if ( | ||
lokr_config.dtype == "bfloat16" or config.quantization_config.weight_quantize_algo in ["nf4", "fp4"] | ||
) and args.device == "cpu": | ||
raise ValueError("We can not apply bfloat16 or nf4/fp4 lokr merge on cpu.") | ||
|
||
# with device_guard() will cause SVD decomposition to fail | ||
model = AutoModelForCausalLM.from_pretrained( | ||
lokr_config.base_model_name_or_path, | ||
config=config, | ||
low_cpu_mem_usage=True, | ||
) | ||
model = LoKrModel.from_pretrained(model=model, lokr_path=args.lokr_path, lokr_config=lokr_config) | ||
|
||
model.eval() | ||
model_state_dict = model.model.state_dict() | ||
lokr_name_list = [] | ||
|
||
for key in model_state_dict.keys(): | ||
if "lokr" in key: | ||
lokr_name_list.append(key.split(".lokr")[0]) | ||
|
||
lokr_name_list = list(set(lokr_name_list)) | ||
for name in lokr_name_list: | ||
weight_process(name, lokr_config, model_state_dict) | ||
|
||
model.model.save_pretrained(args.merge_lokr_model_path, state_dict=model_state_dict) | ||
tokenizer = AutoTokenizer.from_pretrained(lokr_config.base_model_name_or_path) | ||
tokenizer.save_pretrained(args.merge_lokr_model_path) | ||
|
||
|
||
if __name__ == "__main__": | ||
merge() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
WhuanY marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from .lokr_config import LoKrConfig | ||
from .lokr_layers import LoKrLinear | ||
from .lokr_model import LoKrModel | ||
|
||
__all__ = ["LoKrConfig", "LoKrModel", "LoKrLinear"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import json | ||
import os | ||
from dataclasses import asdict, dataclass, field | ||
from typing import List, Optional, Union | ||
|
||
from ...utils.env import LOKR_CONFIG_NAME | ||
|
||
|
||
@dataclass | ||
class LoKrConfig: | ||
""" | ||
This is the configuration class to store the configuration of a [`LoKrModel`]. | ||
Convention of LoKrModel: W1 can be named as scaling matrix, W2 can be named as adapter matrix. | ||
Args: | ||
target_modules (`Union[List[str],str]`): The names of the modules to apply Lora to. | ||
trainable_modules (`List[str]`): The names of the modules to train when applying Lora. | ||
lokr_alpha (`float`): The alpha parameter for Lora scaling. | ||
merge_weights (`bool`): | ||
Whether to merge the weights of the Lora layers with the base transformer model in `eval` mode. | ||
""" | ||
|
||
base_model_name_or_path: Optional[str] = field( | ||
default=None, metadata={"help": "The name of the base model to use."} | ||
) | ||
target_modules: Optional[Union[List[str], str]] = field( | ||
default=None, | ||
metadata={ | ||
"help": "List of module names or regex expression of the module names to replace with LoKr." | ||
"For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' " | ||
}, | ||
) | ||
trainable_modules: Optional[List[str]] = field( | ||
default=None, | ||
metadata={ | ||
"help": "List of module names or regex expression of the module names to train when applying with LoKr." | ||
"For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' " | ||
}, | ||
) | ||
trainable_bias: Optional[str] = field( | ||
default=None, metadata={"help": "Define trainable bias parameters for the Lora model."} | ||
) | ||
lokr_dim: int = field(default=8, metadata={"help": "Lora dimention in LoKr dimension, for adapter matrix"}) | ||
factor: int = field(default=-1, metadata={"help": "Determine the decomposition size of LoKr matrices"}) | ||
decompose_both: bool = field( | ||
default=False, | ||
metadata={"help": "Determine whether to decomposed both Scaling Matrix and adapter matrix together"}, | ||
) | ||
lokr_alpha: float = field( | ||
default=0.0, metadata={"help": "Determine the scaling of adapter weight, follow lokr convention"} | ||
) | ||
merge_weight: bool = field( | ||
default=False, metadata={"help": "Merge weights of the original model and the Lokr model"} | ||
) | ||
tensor_parallel_degree: int = field(default=-1, metadata={"help": "-1 for not use tensor parallel"}) | ||
dtype: Optional[str] = field(default=None, metadata={"help": "The data type of tensor"}) | ||
|
||
@property | ||
def __dict__(self): | ||
return asdict(self) | ||
|
||
def to_dict(self): | ||
return self.__dict__ | ||
|
||
@property | ||
def scaling(self): | ||
if not (self.lokr_alpha or self.lokr_dim): | ||
return 1.0 | ||
return self.lokr_alpha / self.lokr_dim | ||
|
||
def save_pretrained(self, save_directory): | ||
r""" | ||
This method saves the configuration of your adapter model in a directory. | ||
Args: | ||
save_directory (`str`): | ||
The directory where the configuration will be saved. | ||
""" | ||
if os.path.isfile(save_directory): | ||
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") | ||
|
||
os.makedirs(save_directory, exist_ok=True) | ||
|
||
output_dict = self.__dict__ | ||
output_dict["scaling"] = self.scaling | ||
output_path = os.path.join(save_directory, LOKR_CONFIG_NAME) | ||
|
||
# save it | ||
with open(output_path, "w") as writer: | ||
writer.write(json.dumps(output_dict, indent=2, sort_keys=True)) | ||
|
||
@classmethod | ||
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): | ||
r""" | ||
This method loads the configuration of your adapter model from a directory. | ||
Args: | ||
pretrained_model_name_or_path (`str`): | ||
The directory or the hub-id where the configuration is saved. | ||
**kwargs: | ||
Additional keyword arguments passed along to the child class initialization. | ||
""" | ||
if os.path.isfile(os.path.join(pretrained_model_name_or_path, LOKR_CONFIG_NAME)): | ||
config_file = os.path.join(pretrained_model_name_or_path, LOKR_CONFIG_NAME) | ||
else: | ||
raise ValueError(f"Can't find lokr_config.json at '{pretrained_model_name_or_path}'") | ||
|
||
loaded_attributes = cls.from_json_file(config_file) | ||
loaded_attributes.pop("scaling", None) | ||
|
||
config = cls(**kwargs) | ||
|
||
for key, value in loaded_attributes.items(): | ||
if hasattr(config, key): | ||
setattr(config, key, value) | ||
|
||
return config | ||
|
||
@classmethod | ||
def from_json_file(cls, path_json_file): | ||
r""" | ||
Loads a configuration file from a json file. | ||
Args: | ||
path_json_file (`str`): | ||
The path to the json file. | ||
""" | ||
with open(path_json_file, "r") as file: | ||
json_object = json.load(file) | ||
|
||
return json_object |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
1.代码还未在https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/run_finetune.py 文件中添加相应的代码,没有看到lokr设为true时执行的逻辑。
2.添加后请相应同步文档,lokr运行方式以及对应新增参数的解释https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/docs/finetune.md
3.模仿lora添加llm单测 https://github.com/PaddlePaddle/PaddleNLP/blob/develop/tests/llm/test_lora.py
4.请参考vera和lora的脚本新增一个 merge_lokr_params.py https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/tools/merge_vera_params.py
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
再次感谢对开源代码的贡献,代码库中LoKr算法实现没有问题,补充完大模型应用样例即可合入PaddleNLP
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的,收到!争取这周末完成上述四点,到时候我在远程仓库提交
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
all done~
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
完成