Merged (changes from 11 commits)
6 changes: 6 additions & 0 deletions paddlenlp/peft/lora/lora_layers.py
@@ -65,6 +65,8 @@ def __init__(
pissa: bool = False,
lora_use_mixer: bool = False,
use_mora: bool = False,
mp_moe: bool = False,
is_distributed: bool = False,
**kwargs
):
nn.Linear.__init__(self, in_features, out_features, **kwargs)
@@ -143,6 +145,10 @@ def __init__(
self.weight.stop_gradient = True
self._use_quick_lora = use_quick_lora and lora_dropout == 0.0
self.disable_lora = False
if mp_moe or is_distributed:
for p in self.parameters():
p.is_distributed = is_distributed
Contributor Author: This is for EP (expert parallelism): is_distributed marks that these parameters should not be synchronized at the start of training, and mp_moe is used for UC. (A sketch of how these flags might be consumed follows this file's diff.)
p.mp_moe = mp_moe

def pissa_init(self, rank):
weight = self.weight
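For context, here is a minimal, hypothetical sketch of how a trainer-side sync step might consume the two new flags. It is not part of this PR: the helper name broadcast_non_distributed_params and its src_rank/group arguments are assumptions for illustration. Parameters tagged is_distributed=True (expert-parallel weights) are skipped in the initial broadcast, while mp_moe is only a marker for other components to read.

# Hypothetical sketch, not from this PR: skip the initial parameter broadcast
# for expert-parallel (EP) parameters tagged by the LoRA layer above.
import paddle.distributed as dist

def broadcast_non_distributed_params(model, src_rank=0, group=None):
    """Broadcast only the parameters that are not marked as distributed."""
    for param in model.parameters():
        if getattr(param, "is_distributed", False):
            # EP / MoE parameters legitimately differ across ranks: do not sync them.
            continue
        # mp_moe is left untouched here; it is only a tag that other components
        # (e.g. checkpoint handling) can inspect later.
        dist.broadcast(param, src=src_rank, group=group)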
163 changes: 33 additions & 130 deletions paddlenlp/peft/lora/lora_model.py
@@ -20,7 +20,7 @@
import tempfile
from collections import OrderedDict
from functools import partial
from typing import Dict, List, Union
from typing import Dict, Union

import aistudio_sdk
import numpy as np
@@ -97,38 +97,29 @@ def get_lora_layers():
LoRALinear = lora_layers["LoRALinear"]
RowParallelLoRALinear = lora_layers["RowParallelLoRALinear"]
RowSequenceParallelLoRALinear = lora_layers["RowSequenceParallelLoRALinear"]

from ...quantization.quantization_linear import (
ColumnParallelQuantizationLinear,
QuantizationLinear,
RowParallelQuantizationLinear,
)
from .lora_quantization_layers import (
ColumnParallelQuantizationLoRALinear,
QuantizationLoRALinear,
RowParallelQuantizationLoRALinear,
)

AVAILABLE_LAYERS = [
ColumnParallelLoRALinear,
ColumnSequenceParallelLoRALinear,
LoRAConv2D,
LoRALinear,
RowParallelLoRALinear,
RowSequenceParallelLoRALinear,
ColumnParallelQuantizationLoRALinear,
QuantizationLoRALinear,
RowParallelQuantizationLoRALinear,
]
try:
from ...quantization.quantization_linear import (
ColumnParallelQuantizationLinear,
QuantizationLinear,
RowParallelQuantizationLinear,
)
from .lora_quantization_layers import (
ColumnParallelQuantizationLoRALinear,
QuantizationLoRALinear,
RowParallelQuantizationLoRALinear,
)

AVAILABLE_LAYERS += [
ColumnParallelQuantizationLoRALinear,
QuantizationLoRALinear,
RowParallelQuantizationLoRALinear,
]
except:
QuantizationLinear = None
ColumnParallelQuantizationLinear = None
RowParallelQuantizationLinear = None
QuantizationLoRALinear = None
ColumnParallelQuantizationLoRALinear = None
RowParallelQuantizationLoRALinear = None


class LoRAModel(nn.Layer):
@@ -426,11 +417,6 @@ def save_pretrained(self, save_directory: str, merge_tensor_parallel: bool = Fal

if self.is_pipelinemodel:
self.model._single_to_pp_mapping = None
if self.quantized and merge_tensor_parallel and self.lora_config.tensor_parallel_degree > 1:
merge_tensor_parallel = False
logger.warning(
"Quantized strategy does not support merge_tensor_parallel. Set merge_tensor_parallel to False."
)
if self.is_pipelinemodel and merge_tensor_parallel and self.lora_config.tensor_parallel_degree > 1:
merge_tensor_parallel = False
logger.warning(
@@ -479,7 +465,7 @@ def save_pretrained(self, save_directory: str, merge_tensor_parallel: bool = Fal
model_config_to_save.tensor_parallel_degree = -1
model_config_to_save.save_pretrained(save_directory)

def _find_and_replace_module(self, model, module_name, lora_config, enable_lora):
def _find_and_replace_module(self, model, module_name, lora_config):
parent_module = model
attribute_chain = module_name.split(".")
for name in attribute_chain[:-1]:
@@ -500,14 +486,10 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
use_quick_lora=lora_config.use_quick_lora,
lora_use_mixer=lora_config.lora_use_mixer,
use_mora=lora_config.use_mora,
mp_moe=getattr(module.weight, "mp_moe", False),
is_distributed=getattr(module.weight, "is_distributed", False),
)
# Hack for mp group moe, need to find a better solution.
if getattr(module.weight, "mp_moe", False):
lora_module.lora_A.mp_moe = True
lora_module.lora_B.mp_moe = True
lora_module.lora_A.is_distributed = True
lora_module.lora_B.is_distributed = True
if isinstance(module, nn.Conv2D):
elif isinstance(module, nn.Conv2D):
lora_module = LoRAConv2D(
in_channels=module._in_channels,
out_channels=module._out_channels,
@@ -621,68 +603,20 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=False)
self.add_lora_split_mapping(module_name + ".activation_quanter._scale", is_column=False)
self.add_lora_split_mapping(module_name + ".activation_quanter.quanter._scale", is_column=False)
elif QuantizationLinear is not None and isinstance(module, QuantizationLinear):
lora_module = QuantizationLoRALinear(
in_features=module.in_features,
out_features=module.out_features,
quant_algo=module.quant_algo,
dtype=module._dtype,
bias_attr=False if module.bias is None else None,
block_size=module.block_size,
double_quant_block_size=module.double_quant_block_size,
double_quant=module.double_quant,
r=lora_config.r,
lora_alpha=lora_config.lora_alpha,
lora_dropout=lora_config.lora_dropout,
)
self.quantized = True
elif ColumnParallelQuantizationLinear is not None and isinstance(module, ColumnParallelQuantizationLinear):
lora_module = ColumnParallelQuantizationLoRALinear(
in_features=module.in_features,
out_features=module.out_features,
quant_algo=module.quant_algo,
dtype=module._dtype,
bias_attr=False if module.bias is None else None,
gather_output=module.gather_output,
r=lora_config.r,
lora_alpha=lora_config.lora_alpha,
lora_dropout=lora_config.lora_dropout,
lora_A_weight_attr=paddle.ParamAttr(
initializer=nn.initializer.KaimingUniform(negative_slope=math.sqrt(5), nonlinearity="leaky_relu")
),
)
self.quantized = True
elif RowParallelQuantizationLinear is not None and isinstance(module, RowParallelQuantizationLinear):
lora_module = RowParallelQuantizationLoRALinear(
in_features=module.in_features,
out_features=module.out_features,
quant_algo=module.quant_algo,
dtype=module._dtype,
bias_attr=False if module.bias is None else None,
input_is_parallel=module.input_is_parallel,
r=lora_config.r,
lora_alpha=lora_config.lora_alpha,
lora_dropout=lora_config.lora_dropout,
)
self.quantized = True
elif isinstance(module, QuantizationLinear):
lora_module = QuantizationLoRALinear(module, lora_config)
elif isinstance(module, ColumnParallelQuantizationLinear):
lora_module = ColumnParallelQuantizationLoRALinear(module, lora_config)
elif isinstance(module, RowParallelQuantizationLinear):
lora_module = RowParallelQuantizationLoRALinear(module, lora_config)
if lora_module is None:
raise ValueError(
f"LoRA strategy only supports paddle.nn.Linear or paddle.distributed.fleet.meta_parallel.ColumnParallelLinear or paddlenlp.transformers.sequence_utils. {module}({module_name} {type(module).__name__}) is not supported。"
)
if getattr(lora_module, "quant_weight", None) is not None:
lora_module.quant_weight = module.quant_weight
if getattr(lora_module, "quant_scale", None) is not None:
lora_module.quant_scale = module.quant_scale
if getattr(lora_module, "qquant_scale", None) is not None:
lora_module.qquant_scale = module.qquant_scale
if getattr(lora_module, "double_quant_scale", None) is not None:
lora_module.double_quant_scale = module.double_quant_scale
if getattr(lora_module, "quant_sacle_offset", None) is not None:
lora_module.quant_sacle_offset = module.quant_sacle_offset
else:
if getattr(lora_module, "weight", None) is not None:
lora_module.weight = module.weight
if module.bias is not None:
lora_module.bias = module.bias
if module.bias is not None:
lora_module.bias = module.bias
setattr(parent_module, attribute_chain[-1], lora_module)

def _find_and_restore_module(self, module_name):
@@ -768,45 +702,14 @@ def get_lora_model(self, model: Union[PretrainedModel, nn.Layer], lora_config: L

if lora_config.target_modules is None:
return model
elif isinstance(lora_config.target_modules, str):
target_modules = [lora_config.target_modules]
if lora_config.enable_lora_list is None or (
isinstance(lora_config.enable_lora_list, List)
and all(isinstance(item, bool) for item in lora_config.enable_lora_list)
):
enable_lora_list = [lora_config.enable_lora_list]
else:
raise TypeError(
f"Invalid `enable_lora_list` value: {lora_config.enable_lora_list}. Since `target_modules` is `str`, `enable_lora_list` must be `None` or `List[bool]`"
)
else:
target_modules = lora_config.target_modules
if lora_config.enable_lora_list is None:
enable_lora_list = [None for _ in range(len(target_modules))]
elif isinstance(lora_config.enable_lora_list, List):
enable_lora_list = lora_config.enable_lora_list
if len(enable_lora_list) != len(target_modules):
raise TypeError(
f"Invalid lora_config.enable_lora_list value: {lora_config.enable_lora_list}. Since lora_config.target_modules is `List[str]`, `enable_lora_list` should have the same length as `target_modules`"
)
for enable_lora in enable_lora_list:
if not (
enable_lora is None
or (isinstance(enable_lora, List) and all(isinstance(item, bool) for item in enable_lora))
):
raise TypeError(
f"Invalid `enable_lora_list` value: {lora_config.enable_lora_list}. Since `target_modules` is `List[str]`, `enable_lora_list` must be `None` or `List[Optional[List[bool]]]`"
)
else:
raise TypeError(
f"Invalid `enable_lora_list` value: {lora_config.enable_lora_list}. Since `target_modules` is `List[str]`, `enable_lora_list` must be `None` or `List[Optional[List[bool]]]`"
)
if isinstance(lora_config.target_modules, str):
lora_config.target_modules = [lora_config.target_modules]

for target_module, enable_lora in zip(target_modules, enable_lora_list):
for target_module in lora_config.target_modules:
for i in model.named_sublayers():
module_name = i[0]
if re.fullmatch(target_module, module_name):
self._find_and_replace_module(model, module_name, lora_config, enable_lora)
self._find_and_replace_module(model, module_name, lora_config)
return model

def restore_original_model(self):