Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding LoKrModel Class to paddle.peft library #9269

Merged
merged 16 commits into from
Nov 27, 2024
34 changes: 34 additions & 0 deletions llm/config/llama/lokr_argument.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
{
"model_name_or_path": "meta-llama/Meta-Llama-3-8B",
"dataset_name_or_path": "./data",
"output_dir": "./checkpoints/lokr_ckpts",
"lokr": true,
Copy link
Contributor

@lugimzzz lugimzzz Nov 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1.代码还未在https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/run_finetune.py 文件中添加相应的代码,没有看到lokr设为true时执行的逻辑。
2.添加后请相应同步文档,lokr运行方式以及对应新增参数的解释https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/docs/finetune.md
3.模仿lora添加llm单测 https://github.com/PaddlePaddle/PaddleNLP/blob/develop/tests/llm/test_lora.py
4.请参考vera和lora的脚本新增一个 merge_lokr_params.py https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/tools/merge_vera_params.py

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

再次感谢对开源代码的贡献,代码库中LoKr算法实现没有问题,补充完大模型应用样例即可合入PaddleNLP

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的,收到!争取这周末完成上述四点,到时候我在远程仓库提交

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

all done~

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

完成

"per_device_train_batch_size": 4,
"gradient_accumulation_steps": 4,
"num_train_epochs": 1,
"learning_rate": 2e-05,
"lr_scheduler_type": "linear",
"attention_probs_dropout_prob": 0,
"hidden_dropout_prob": 0,
"warmup_steps": 30,
"logging_steps": 1,
"evaluation_strategy": "no",
"save_strategy": "steps",
"save_steps": 500,
"src_length": 512,
"max_length": 512,
"bf16": true,
"do_train": true,
"do_eval": false,
"disable_tqdm": false,
"load_best_model_at_end": false,
"eval_with_do_generation": false,
"metric_for_best_model": "accuracy",
"recompute": false,
"save_total_limit": 100,
"fp16_opt_level": "O2",
"sharding": "stage2",
"zero_padding": false,
"use_flash_attention": false,
"unified_checkpoint": true
}
19 changes: 19 additions & 0 deletions paddlenlp/peft/lokr/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .lokr_config import LoKrConfig
from .lokr_layers import LoKrLinear
from .lokr_model import LoKrModel

__all__ = ["LoKrConfig", "LoKrModel", "LoKrLinear"]
141 changes: 141 additions & 0 deletions paddlenlp/peft/lokr/lokr_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from dataclasses import asdict, dataclass, field
from typing import List, Optional, Union

from ...utils.env import LOKR_CONFIG_NAME


@dataclass
class LoKrConfig:
"""
This is the configuration class to store the configuration of a [`LoKrModel`].
Convention of LoKrModel: W1 can be named as scaling matrix, W2 can be named as adapter matrix.
Args:
target_modules (`Union[List[str],str]`): The names of the modules to apply Lora to.
trainable_modules (`List[str]`): The names of the modules to train when applying Lora.
lokr_alpha (`float`): The alpha parameter for Lora scaling.
merge_weights (`bool`):
Whether to merge the weights of the Lora layers with the base transformer model in `eval` mode.
"""

base_model_name_or_path: Optional[str] = field(
default=None, metadata={"help": "The name of the base model to use."}
)
target_modules: Optional[Union[List[str], str]] = field(
default=None,
metadata={
"help": "List of module names or regex expression of the module names to replace with Lora and Lora Variant."
"For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' "
},
)
trainable_modules: Optional[List[str]] = field(
default=None,
metadata={
"help": "List of module names or regex expression of the module names to train when applying with Lora and Lora Variant."
"For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' "
},
)
trainable_bias: Optional[str] = field(
default=None, metadata={"help": "Define trainable bias parameters for the Lora model."}
)
lora_dim: int = field(default=8, metadata={"help": "Lora dimention in LoKr dimension, for adapter matrix"})
factor: int = field(default=-1, metadata={"help": "Determine the decomposition size of LoKr matrices"})
decompose_both: bool = field(
default=False,
metadata={"help": "Determine whether to decomposed both Scaling Matrix and adapter matrix together"},
)
lokr_alpha: float = field(
default=0.0, metadata={"help": "Determine the scaling of adapter weight, follow lokr convention"}
)
merge_weight: bool = field(
default=False, metadata={"help": "Merge weights of the original model and the Lokr model"}
)
tensor_parallel_degree: int = field(default=-1, metadata={"help": "-1 for not use tensor parallel"})
dtype: Optional[str] = field(default=None, metadata={"help": "The data type of tensor"})

@property
def __dict__(self):
return asdict(self)

def to_dict(self):
return self.__dict__

@property
def scaling(self):
if not (self.lokr_alpha or self.lora_dim):
return 1.0

Check warning on line 81 in paddlenlp/peft/lokr/lokr_config.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/peft/lokr/lokr_config.py#L81

Added line #L81 was not covered by tests
return self.lokr_alpha / self.lora_dim

def save_pretrained(self, save_directory):
r"""
This method saves the configuration of your adapter model in a directory.
Args:
save_directory (`str`):
The directory where the configuration will be saved.
"""
if os.path.isfile(save_directory):
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")

Check warning on line 92 in paddlenlp/peft/lokr/lokr_config.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/peft/lokr/lokr_config.py#L92

Added line #L92 was not covered by tests

os.makedirs(save_directory, exist_ok=True)

output_dict = self.__dict__
output_dict["scaling"] = self.scaling
output_path = os.path.join(save_directory, LOKR_CONFIG_NAME)

# save it
with open(output_path, "w") as writer:
writer.write(json.dumps(output_dict, indent=2, sort_keys=True))

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
r"""
This method loads the configuration of your adapter model from a directory.
Args:
pretrained_model_name_or_path (`str`):
The directory or the hub-id where the configuration is saved.
**kwargs:
Additional keyword arguments passed along to the child class initialization.
"""
if os.path.isfile(os.path.join(pretrained_model_name_or_path, LOKR_CONFIG_NAME)):
config_file = os.path.join(pretrained_model_name_or_path, LOKR_CONFIG_NAME)
else:
raise ValueError(f"Can't find lokr_config.json at '{pretrained_model_name_or_path}'")

Check warning on line 117 in paddlenlp/peft/lokr/lokr_config.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/peft/lokr/lokr_config.py#L117

Added line #L117 was not covered by tests

loaded_attributes = cls.from_json_file(config_file)
loaded_attributes.pop("scaling", None)

config = cls(**kwargs)

for key, value in loaded_attributes.items():
if hasattr(config, key):
setattr(config, key, value)

return config

@classmethod
def from_json_file(cls, path_json_file):
r"""
Loads a configuration file from a json file.
Args:
path_json_file (`str`):
The path to the json file.
"""
with open(path_json_file, "r") as file:
json_object = json.load(file)

return json_object
Loading
Loading