diff --git a/llm/merge_lora_params.py b/llm/merge_lora_params.py
index 50ae4a797f34..f61cad314db8 100644
--- a/llm/merge_lora_params.py
+++ b/llm/merge_lora_params.py
@@ -127,8 +127,8 @@ def merge():
         config=config,
         low_cpu_mem_usage=True,
     )
-    lora_config.merge_weights = True
     model = LoRAModel.from_pretrained(model=model, lora_path=args.lora_path, lora_config=lora_config)
+    model.merge()
     model.eval()
     model_state_dict = model.model.state_dict()
     for key in list(model_state_dict):
diff --git a/llm/predictor.py b/llm/predictor.py
index 3e9f47d8025e..b5e8f1dcda37 100644
--- a/llm/predictor.py
+++ b/llm/predictor.py
@@ -49,8 +49,8 @@
     AutoConfig,
     AutoModelForCausalLM,
     AutoTokenizer,
-    ChatGLMv2Tokenizer,
     ChatGLMTokenizer,
+    ChatGLMv2Tokenizer,
     LlamaTokenizer,
     PretrainedModel,
     PretrainedTokenizer,
@@ -242,7 +242,8 @@ def _preprocess(self, source):
            padding=True,
            # when use chat_template, it should not add special tokens
            # chatglm2 prefix-tokens can not be tokenized into ids
-           add_special_tokens=self.tokenizer.chat_template is None or isinstance(self.tokenizer, (ChatGLMv2Tokenizer, ChatGLMTokenizer)),
+           add_special_tokens=self.tokenizer.chat_template is None
+           or isinstance(self.tokenizer, (ChatGLMv2Tokenizer, ChatGLMTokenizer)),
         )
         return tokenized_source
@@ -272,7 +273,6 @@ def __init__(
         if config.lora_path is not None:
             lora_config = LoRAConfig.from_pretrained(config.lora_path)
             dtype = lora_config.dtype
-            lora_config.merge_weights = True
         elif config.prefix_path is not None:
             prefix_config = PrefixConfig.from_pretrained(config.prefix_path)
             dtype = prefix_config.dtype
@@ -289,6 +289,7 @@
             tensor_parallel_degree=self.tensor_parallel_degree,
             tensor_parallel_rank=self.tensor_parallel_rank,
         )
+        self.model.merge()
 
         if config.lora_path is not None:
             self.model = LoRAModel.from_pretrained(
diff --git a/paddlenlp/peft/lora/__init__.py b/paddlenlp/peft/lora/__init__.py
index cd736b245355..f1f83b9cdb48 100644
--- a/paddlenlp/peft/lora/__init__.py
+++ b/paddlenlp/peft/lora/__init__.py
@@ -13,11 +13,5 @@
 # limitations under the License.
 
 from .lora_config import LoRAConfig
-from .lora_layers import (
-    ColumnParallelLoRALinear,
-    ColumnParallelLoRAMergedLinear,
-    LoRALinear,
-    LoRAMergedLinear,
-    RowParallelLoRALinear,
-)
+from .lora_layers import ColumnParallelLoRALinear, LoRALinear, RowParallelLoRALinear
 from .lora_model import LoRAModel
diff --git a/paddlenlp/peft/lora/lora_config.py b/paddlenlp/peft/lora/lora_config.py
index 12e3b929ed7e..3d6aa50764b9 100644
--- a/paddlenlp/peft/lora/lora_config.py
+++ b/paddlenlp/peft/lora/lora_config.py
@@ -94,6 +94,11 @@ def __post_init__(self):
                 "We will automatically set `use_quick_lora` to `False` to avoid potential inconsistencies."
             )
             self.use_quick_lora = False
+        if self.merge_weights:
+            logger.error(
+                "'merge_weights' is deprecated and will be removed in a future version. "
+                "Please call model.merge() or model.unmerge() to merge or unmerge the LoRA weights into the base model."
+            )
 
     @property
     def scaling(self):
diff --git a/paddlenlp/peft/lora/lora_layers.py b/paddlenlp/peft/lora/lora_layers.py
index 73120060fe87..cb670669ce38 100644
--- a/paddlenlp/peft/lora/lora_layers.py
+++ b/paddlenlp/peft/lora/lora_layers.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import math
-from typing import List, Optional
+from typing import Optional
 
 import paddle
 import paddle.nn as nn
@@ -63,7 +63,6 @@ def __init__(
         r: int = 0,
         lora_alpha: int = 1,
         lora_dropout: float = 0.0,
-        merge_weights: bool = True,
         use_quick_lora: bool = False,
         rslora: bool = False,
         lora_plus_scale: float = 1.0,
@@ -82,7 +81,6 @@ def __init__(
             self.lora_dropout = lambda x: x
         # Mark the weight as unmerged
         self.merged = False
-        self.merge_weights = merge_weights
         self.pissa = pissa
 
         # Actual trainable parameters
@@ -113,6 +111,7 @@ def __init__(
         # Freezing the pre-trained weight matrix
         self.weight.stop_gradient = True
         self._use_quick_lora = use_quick_lora and lora_dropout == 0.0
+        self.disable_lora = False
 
     @property
     def use_quick_lora(self):
@@ -137,34 +136,30 @@ def pissa_init(self, rank):
         weight = res.astype(dtype)
         self.weight.set_value(weight)
 
-    def train(self):
-        super().train()
-        if self.merge_weights and self.merged:
-            # Make sure that the weights are not merged
-            new_weight = self.weight - self.lora_A @ self.lora_B * self.scaling
-            self.weight.set_value(new_weight)
-            self.merged = False
-
-    def eval(self):
-        super().eval()
-        if self.merge_weights and not self.merged:
-            # Merge the weights and mark it
+    def merge(self):
+        if not self.merged:
             new_weight = self.weight + self.lora_A @ self.lora_B * self.scaling
             self.weight.set_value(new_weight)
             self.merged = True
 
+    def unmerge(self):
+        if self.merged:
+            new_weight = self.weight - self.lora_A @ self.lora_B * self.scaling
+            self.weight.set_value(new_weight)
+            self.merged = False
+
     def forward(self, input: paddle.Tensor, *args, **kwargs):
         if not self.apply_pissa and self.pissa:
             self.pissa_init(self.r)
             self.apply_pissa = True
-
-        if self.use_quick_lora:
+        if self.disable_lora or self.merged:
+            result = F.linear(x=input, weight=self.weight, bias=self.bias, name=self.name)
+        elif self.use_quick_lora:
             # Use the quick lora implementation
             result = quick_lora(input, self.lora_A, self.lora_B, self.weight, self.bias, self.scaling)
         else:
             result = F.linear(x=input, weight=self.weight, bias=self.bias, name=self.name)
-            if not self.merged:
-                result += (self.lora_dropout(input) @ self.lora_A @ self.lora_B) * self.scaling
+            result += (self.lora_dropout(input) @ self.lora_A @ self.lora_B) * self.scaling
         return result
 
     def extra_repr(self):
@@ -182,7 +177,6 @@ def __init__(
         lora_dropout: float = 0.0,
         rslora: bool = False,
         lora_plus_scale: float = 1.0,
-        merge_weights: bool = True,
         use_quick_lora: bool = False,
         pissa: bool = False,
         **kwargs
@@ -203,7 +197,6 @@ def __init__(
             self.lora_dropout = lambda x: x
         # Mark the weight as unmerged
         self.merged = False
-        self.merge_weights = merge_weights
 
         # compatible
         self.name = self._name
@@ -238,23 +231,20 @@ def __init__(
         # Freezing the pre-trained weight matrix
         self.weight.stop_gradient = True
         self._use_quick_lora = use_quick_lora and lora_dropout == 0.0
+        self.disable_lora = False
 
     @property
     def use_quick_lora(self):
         return self._use_quick_lora and self.training and not self.merged
 
-    def train(self):
-        super().train()
-        if self.merge_weights and self.merged:
-            # Make sure that the weights are not merged
+    def unmerge(self):
+        if self.merged:
             new_weight = self.weight - self.lora_A @ self.lora_B * self.scaling
             self.weight.set_value(new_weight)
             self.merged = False
 
-    def eval(self):
-        super().eval()
-        if self.merge_weights and not self.merged:
-            # Merge the weights and mark it
+    def merge(self):
+        if not self.merged:
             new_weight = self.weight + self.lora_A @ self.lora_B * self.scaling
             self.weight.set_value(new_weight)
self.merged = True @@ -264,8 +254,20 @@ def forward(self, x: paddle.Tensor): input_mp = mp_ops._c_split(x, group=self.model_parallel_group) else: input_mp = x - - if self.use_quick_lora: + if self.disable_lora or self.merged: + # x @ W : [bz, in_f / ws] ===> [bz, out_f] + if MC2RowParallelCoreLinear is None: + result_mp = F.linear(x=input_mp, weight=self.weight, name=self.name) + output = mp_ops._mp_allreduce( + result_mp, + group=self.model_parallel_group, + use_calc_stream=True, + use_model_parallel=True, + ) + else: + output = MC2RowParallelCoreLinear.apply(input_mp, self.weight, self.model_parallel_group) + output = output + self.bias if self.bias is not None else output + elif self.use_quick_lora: # Use the quick lora implementation result_mp = quick_lora( input_mp, @@ -297,19 +299,18 @@ def forward(self, x: paddle.Tensor): else: output = MC2RowParallelCoreLinear.apply(input_mp, self.weight, self.model_parallel_group) - if not self.merged: - # x @ A: [bz, in_f/ ws] ===> [bz, r] - input_mp = self.lora_dropout(input_mp) @ self.lora_A - # all reduce to keep Lora B's gradient on different gpu consistent - input_dup = mp_ops._mp_allreduce( - input_mp, - group=self.model_parallel_group, - use_calc_stream=True, - use_model_parallel=True, - ) - # @ B: [bz, r] ===> [bz, out_f] - delta_mp = (input_dup @ self.lora_B) * self.scaling - output += delta_mp + # x @ A: [bz, in_f/ ws] ===> [bz, r] + input_mp = self.lora_dropout(input_mp) @ self.lora_A + # all reduce to keep Lora B's gradient on different gpu consistent + input_dup = mp_ops._mp_allreduce( + input_mp, + group=self.model_parallel_group, + use_calc_stream=True, + use_model_parallel=True, + ) + # @ B: [bz, r] ===> [bz, out_f] + delta_mp = (input_dup @ self.lora_B) * self.scaling + output += delta_mp output = output + self.bias if self.bias is not None else output return output @@ -328,7 +329,6 @@ def __init__( lora_dropout: float = 0.0, rslora: bool = False, lora_plus_scale: float = 1.0, - merge_weights: bool = True, use_quick_lora: bool = False, **kwargs ): @@ -344,7 +344,6 @@ def __init__( self.lora_dropout = lambda x: x # Mark the weight as unmerged self.merged = False - self.merge_weights = merge_weights # compatible self.name = self._name @@ -380,24 +379,21 @@ def __init__( # Freezing the pre-trained weight matrix self.weight.stop_gradient = True self._use_quick_lora = use_quick_lora and lora_dropout == 0.0 + self.disable_lora = False @property def use_quick_lora(self): # TODO(@gexiao): support qlora return False # self._use_quick_lora and self.training and not self.merged - def train(self): - super().train() - if self.merge_weights and self.merged: - # Make sure that the weights are not merged + def unmerge(self): + if self.merged: new_weight = self.weight - self.lora_A @ self.lora_B * self.scaling self.weight.set_value(new_weight) self.merged = False - def eval(self): - super().eval() - if self.merge_weights and not self.merged: - # Merge the weights and mark it + def merge(self): + if not self.merged: new_weight = self.weight + self.lora_A @ self.lora_B * self.scaling self.weight.set_value(new_weight) self.merged = True @@ -416,9 +412,10 @@ def forward(self, x: paddle.Tensor): output_ = MC2RowSeqParallelCoreLinear.apply(input_mp, self.weight, self.model_parallel_group) result_mp = output_ + self.bias if self.bias is not None else output_ - if not self.merged: + if not self.merged and not self.disable_lora: input_mp = self.lora_dropout(input_mp) - if MC2RowSeqParallelCoreLinear is None: + # TODO(@gexiao): temporary workaround for 
deterministic calculation + if True or MC2RowSeqParallelCoreLinear is None: input_mp = input_mp @ self.lora_A input_mp = ReduceScatterOp.apply(input_mp) else: @@ -442,7 +439,6 @@ def __init__( lora_dropout: float = 0.0, rslora: bool = False, lora_plus_scale: float = 1.0, - merge_weights: bool = True, lora_A_weight_attr: Optional[paddle.ParamAttr] = None, use_quick_lora: bool = False, pissa: bool = False, @@ -464,7 +460,6 @@ def __init__( self.lora_dropout = lambda x: x # Mark the weight as unmerged self.merged = False - self.merge_weights = merge_weights # compatible self.name = self._name @@ -497,29 +492,36 @@ def __init__( # Freezing the pre-trained weight matrix self.weight.stop_gradient = True self._use_quick_lora = use_quick_lora and lora_dropout == 0.0 + self.disable_lora = False @property def use_quick_lora(self): return self._use_quick_lora and self.training and not self.merged - def train(self): - super().train() - if self.merge_weights and self.merged: + def unmerge(self): + if self.merged: # Make sure that the weights are not merged new_weight = self.weight - self.lora_A @ self.lora_B * self.scaling self.weight.set_value(new_weight) self.merged = False - def eval(self): - super().eval() - if self.merge_weights and not self.merged: + def merge(self): + if not self.merged: # Merge the weights and mark it new_weight = self.weight + self.lora_A @ self.lora_B * self.scaling self.weight.set_value(new_weight) self.merged = True def forward(self, input: paddle.Tensor): - if self.use_quick_lora: + if self.disable_lora or self.merged: + if MC2ColumnParallelCoreLinear is None: + input_mp = mp_ops._c_identity(input, group=self.model_parallel_group) + result_mp = F.linear(x=input_mp, weight=self.weight, bias=self.bias, name=self.name) + else: + res_mp = MC2ColumnParallelCoreLinear.apply(input, self.weight, self.model_parallel_group) + result_mp = (res_mp + self.bias) if self.bias is not None else res_mp + + elif self.use_quick_lora: # Use the quick lora implementation input_mp = mp_ops._c_identity(input, group=self.model_parallel_group) if self.is_mp else input result_mp = quick_lora( @@ -541,15 +543,14 @@ def forward(self, input: paddle.Tensor): res_mp = MC2ColumnParallelCoreLinear.apply(input, self.weight, self.model_parallel_group) result_mp = (res_mp + self.bias) if self.bias is not None else res_mp - if not self.merged: - input_a = self.lora_dropout(input) @ self.lora_A - if MC2ColumnParallelCoreLinear is None: - input_a_mp = mp_ops._c_identity(input_a, group=self.model_parallel_group) - delta_mp = (input_a_mp @ self.lora_B) * self.scaling - else: - tmp = MC2ColumnParallelCoreLinear.apply(input_a, self.lora_B, self.model_parallel_group) - delta_mp = tmp * self.scaling - result_mp += delta_mp + input_a = self.lora_dropout(input) @ self.lora_A + if MC2ColumnParallelCoreLinear is None: + input_a_mp = mp_ops._c_identity(input_a, group=self.model_parallel_group) + delta_mp = (input_a_mp @ self.lora_B) * self.scaling + else: + tmp = MC2ColumnParallelCoreLinear.apply(input_a, self.lora_B, self.model_parallel_group) + delta_mp = tmp * self.scaling + result_mp += delta_mp if self.gather_output and self.is_mp: result = mp_ops._c_concat(result_mp, group=self.model_parallel_group) @@ -572,7 +573,6 @@ def __init__( lora_dropout: float = 0.0, rslora: bool = False, lora_plus_scale: float = 1.0, - merge_weights: bool = True, lora_A_weight_attr: Optional[paddle.ParamAttr] = None, use_quick_lora: bool = False, **kwargs @@ -589,7 +589,6 @@ def __init__( self.lora_dropout = lambda x: x # Mark the weight as 
unmerged self.merged = False - self.merge_weights = merge_weights # compatible self.name = self._name @@ -624,24 +623,21 @@ def __init__( # Freezing the pre-trained weight matrix self.weight.stop_gradient = True self._use_quick_lora = use_quick_lora and lora_dropout == 0.0 + self.disable_lora = False @property def use_quick_lora(self): # TODO(@gexiao): support qlora return False # self._use_quick_lora and self.training and not self.merged - def train(self): - super().train() - if self.merge_weights and self.merged: - # Make sure that the weights are not merged + def unmerge(self): + if self.merged: new_weight = self.weight - self.lora_A @ self.lora_B * self.scaling self.weight.set_value(new_weight) self.merged = False - def eval(self): - super().eval() - if self.merge_weights and not self.merged: - # Merge the weights and mark it + def merge(self): + if not self.merged: new_weight = self.weight + self.lora_A @ self.lora_B * self.scaling self.weight.set_value(new_weight) self.merged = True @@ -658,9 +654,10 @@ def forward(self, x: paddle.Tensor): if self.bias is not None: result_mp += self.bias - if not self.merged: + if not self.merged and not self.disable_lora: input_a = self.lora_dropout(x) @ self.lora_A - if MC2ColumnSeqParallelCoreLinear is None: + # TODO(@gexiao): temporary workaround for deterministic calculation + if True or MC2ColumnSeqParallelCoreLinear is None: input_a = AllGatherOp.apply(input_a) delta_mp = (input_a @ self.lora_B) * self.scaling else: @@ -679,356 +676,6 @@ def extra_repr(self): return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}" -class LoRAMergedLinear(nn.Linear): - # LoRA implemented in a dense layer with merged linear weights for q, k, v - def __init__( - self, - in_features: int, - out_features: int, - head_dim: int, - r: int = 0, - lora_alpha: int = 1, - lora_dropout: float = 0.0, - merge_weights: bool = True, - enable_lora: List[bool] = [False], - **kwargs - ): - nn.Linear.__init__(self, in_features, out_features, **kwargs) - assert ( - out_features % len(enable_lora) == 0 - ), f"The length of enable_lora must divide out_features: {out_features} % {len(enable_lora)} != 0" - if not isinstance(r, int) or r <= 0: - raise ValueError("Lora rank r should be a positive integer") - self.r = r - self.lora_alpha = lora_alpha - if isinstance(enable_lora, List) and all(isinstance(item, bool) for item in enable_lora): - self.enable_lora = enable_lora - else: - raise TypeError("enable_lora must be a list of bools") - - self.out_features = out_features - self.in_features = in_features - self.head_dim = head_dim - self.head_num = self.out_features // len(enable_lora) // self.head_dim - - # Optional dropout - if lora_dropout > 0.0 and any: - self.lora_dropout = nn.Dropout(p=lora_dropout) - else: - self.lora_dropout = lambda x: x - - # Mark the weight as unmerged - self.merged = False - self.merge_weights = merge_weights - - # Actual trainable parameters - if any(enable_lora): - self.lora_A = self.create_parameter( - shape=[in_features, r * sum(enable_lora)], - dtype=self._dtype, - is_bias=False, - default_initializer=nn.initializer.KaimingUniform( - negative_slope=math.sqrt(5), nonlinearity="leaky_relu" - ), - ) - # Make sure lora_B is split in column for ColumnParallelLoRAMergedLinear. 
- self.lora_B = self.create_parameter( - shape=[r, out_features // len(enable_lora) * sum(enable_lora)], - dtype=self._dtype, - is_bias=False, - default_initializer=nn.initializer.Constant(value=0.0), - ) - self.scaling = self.lora_alpha / self.r - - # Freezing the pre-trained weight matrix - self.weight.stop_gradient = True - - def zero_pad_and_reshape(self, x): - # if enable_lora is all true, then there is no need to zero pad - if all(self.enable_lora): - output = x - else: - split_output = paddle.split(x, sum(self.enable_lora), axis=-1) - for index in range(len(self.enable_lora)): - if self.enable_lora[index] is False: - split_output.insert(index, paddle.zeros_like(split_output[0])) - output = paddle.concat(split_output, axis=-1) - if output.dim() == 2: - rank, out_features = output.shape - reshape_output = ( - output.reshape([rank, len(self.enable_lora), self.head_num, self.head_dim]) - .transpose([0, 2, 1, 3]) - .reshape([rank, out_features]) - ) - else: - batch, seq_len, out_features = output.shape - reshape_output = ( - output.reshape([batch, seq_len, len(self.enable_lora), self.head_num, self.head_dim]) - .transpose([0, 1, 3, 2, 4]) - .reshape([batch, seq_len, out_features]) - ) - - return reshape_output - - def train(self): - super().train() - if self.merge_weights and self.merged: - # Make sure that the weights are not merged - if any(self.enable_lora): - reshape_lora_B = ( - self.lora_B.reshape([self.r, self.head_num, sum(self.enable_lora), self.head_dim]) - .transpose([0, 2, 1, 3]) - .reshape(self.lora_B.shape) - ) - delta_weight = ( - F.conv1d( - self.lora_A.T.unsqueeze(0), - reshape_lora_B.T.unsqueeze(-1), - groups=sum(self.enable_lora), - ) - .squeeze(0) - .T - ) - new_weight = self.weight - self.zero_pad_and_reshape(delta_weight * self.scaling) - self.weight.set_value(new_weight) - self.merged = False - - def eval(self): - super().eval() - if self.merge_weights and not self.merged: - # Merge the weights and mark it - if any(self.enable_lora): - reshape_lora_B = ( - self.lora_B.reshape([self.r, self.head_num, sum(self.enable_lora), self.head_dim]) - .transpose([0, 2, 1, 3]) - .reshape(self.lora_B.shape) - ) - delta_weight = ( - F.conv1d( - self.lora_A.T.unsqueeze(0), - reshape_lora_B.T.unsqueeze(-1), - groups=sum(self.enable_lora), - ) - .squeeze(0) - .T - ) - new_weight = self.weight + self.zero_pad_and_reshape(delta_weight * self.scaling) - self.weight.set_value(new_weight) - self.merged = True - - def forward(self, input: paddle.Tensor): - result = F.linear(x=input, weight=self.weight, bias=self.bias, name=self.name) - if any(self.enable_lora) and not self.merged: - input_a = self.lora_dropout(input) @ self.lora_A - if input_a.dim() == 3: - reshape_lora_B = ( - self.lora_B.reshape([self.r, self.head_num, sum(self.enable_lora), self.head_dim]) - .transpose([0, 2, 1, 3]) - .reshape(self.lora_B.shape) - ) - delta = ( - F.conv1d( - input_a.transpose([0, 2, 1]), - reshape_lora_B.T.unsqueeze(-1), - groups=sum(self.enable_lora), - ) - ).transpose([0, 2, 1]) - else: - raise NotImplementedError("LoRAMergedLinear only support 3D input features") - - result += self.zero_pad_and_reshape(delta * self.scaling) - return result - - def extra_repr(self): - name = f", name={self.name}" if self.name else "" - return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}" - - -class ColumnParallelLoRAMergedLinear(ColumnParallelLinear): - # LoRA implemented in a dense layer with merged linear weights for q, k, v - def __init__( - self, - in_features: 
int, - out_features: int, - head_dim: int, - r: int = 0, - lora_alpha: int = 1, - lora_dropout: float = 0.0, - merge_weights: bool = True, - enable_lora: List[bool] = [False], - lora_A_weight_attr: Optional[paddle.ParamAttr] = None, - **kwargs - ): - ColumnParallelLinear.__init__(self, in_features, out_features, **kwargs) - assert ( - self.output_size_per_partition % len(enable_lora) == 0 - ), f"The length of enable_lora must divide out_features: {self.output_size_per_partition} % {len(enable_lora)} != 0" - if not isinstance(r, int) or r <= 0: - raise ValueError("Lora rank r should be a positive integer") - self.r = r - self.lora_alpha = lora_alpha - if isinstance(enable_lora, List) and all(isinstance(item, bool) for item in enable_lora): - self.enable_lora = enable_lora - else: - raise TypeError("enable_lora must be a list of bools") - - self.out_features = out_features - self.in_features = in_features - self.head_dim = head_dim - self.head_num = self.output_size_per_partition // len(enable_lora) // self.head_dim - - # Optional dropout - if lora_dropout > 0.0 and any: - self.lora_dropout = nn.Dropout(p=lora_dropout) - else: - self.lora_dropout = lambda x: x - - # Mark the weight as unmerged - self.merged = False - self.merge_weights = merge_weights - - # compatible - self.name = self._name - - # Actual trainable parameters - if any(enable_lora): - self.lora_A = self.create_parameter( - shape=[in_features, r * sum(enable_lora)], - dtype=self._dtype, - is_bias=False, - attr=lora_A_weight_attr, - ) - self.lora_A.is_distributed = False - # Make sure lora_B is split in column the same as ColumnParallelLoRALinear. - self.lora_B = self.create_parameter( - shape=[r, self.output_size_per_partition // len(enable_lora) * sum(enable_lora)], - dtype=self._dtype, - is_bias=False, - default_initializer=nn.initializer.Constant(value=0.0), - ) - self.lora_B.is_distributed = True - self.lora_B.split_axis = 1 - self.scaling = self.lora_alpha / self.r - - # Freezing the pre-trained weight matrix - self.weight.stop_gradient = True - - def zero_pad_and_reshape(self, x): - # if enable_lora is all true, then there is no need to zero pad - if all(self.enable_lora): - output = x - else: - split_output = paddle.split(x, sum(self.enable_lora), axis=-1) - for index in range(len(self.enable_lora)): - if self.enable_lora[index] is False: - split_output.insert(index, paddle.zeros_like(split_output[0])) - output = paddle.concat(split_output, axis=-1) - if output.dim() == 2: - rank, out_features = output.shape - reshape_output = ( - output.reshape([rank, len(self.enable_lora), self.head_num, self.head_dim]) - .transpose([0, 2, 1, 3]) - .reshape([rank, out_features]) - ) - else: - batch, seq_len, out_features = output.shape - reshape_output = ( - output.reshape([batch, seq_len, len(self.enable_lora), self.head_num, self.head_dim]) - .transpose([0, 1, 3, 2, 4]) - .reshape([batch, seq_len, out_features]) - ) - - return reshape_output - - def train(self): - super().train() - if self.merge_weights and self.merged: - # Make sure that the weights are not merged - if any(self.enable_lora): - reshape_lora_B = ( - self.lora_B.reshape([self.r, self.head_num, sum(self.enable_lora), self.head_dim]) - .transpose([0, 2, 1, 3]) - .reshape(self.lora_B.shape) - ) - delta_weight = ( - F.conv1d( - self.lora_A.T.unsqueeze(0), - reshape_lora_B.T.unsqueeze(-1), - groups=sum(self.enable_lora), - ) - .squeeze(0) - .T - ) - new_weight = self.weight - self.zero_pad_and_reshape(delta_weight * self.scaling) - self.weight.set_value(new_weight) - 
self.merged = False - - def eval(self): - super().eval() - if self.merge_weights and not self.merged: - # Merge the weights and mark it - if any(self.enable_lora): - reshape_lora_B = ( - self.lora_B.reshape([self.r, self.head_num, sum(self.enable_lora), self.head_dim]) - .transpose([0, 2, 1, 3]) - .reshape(self.lora_B.shape) - ) - delta_weight = ( - F.conv1d( - self.lora_A.T.unsqueeze(0), - reshape_lora_B.T.unsqueeze(-1), - groups=sum(self.enable_lora), - ) - .squeeze(0) - .T - ) - new_weight = self.weight + self.zero_pad_and_reshape(delta_weight * self.scaling) - self.weight.set_value(new_weight) - self.merged = True - - def forward(self, input: paddle.Tensor): - # [batch_size, *, in_features] - input_mp = mp_ops._c_identity(input, group=self.model_parallel_group) - # [batch_size, *, out_features_per_partition] - result_mp = F.linear(x=input_mp, weight=self.weight, bias=self.bias, name=self.name) - if any(self.enable_lora) and not self.merged: - input_a = self.lora_dropout(input) @ self.lora_A - input_a_mp = mp_ops._c_identity(input_a, group=self.model_parallel_group) - if input_a.dim() == 3: - reshape_lora_B = ( - self.lora_B.reshape([self.r, self.head_num, sum(self.enable_lora), self.head_dim]) - .transpose([0, 2, 1, 3]) - .reshape(self.lora_B.shape) - ) - delta_mp = ( - F.conv1d( - input_a_mp.transpose([0, 2, 1]), - reshape_lora_B.T.unsqueeze(-1), - groups=sum(self.enable_lora), - ) - ).transpose([0, 2, 1]) - else: - raise NotImplementedError("LoRAMergedLinear only support 3D input features") - # [batch_size, *, out_features_per_partition] - result_mp += self.zero_pad_and_reshape(delta_mp * self.scaling) - - if self.gather_output and self.is_mp: - result_mp_list = paddle.split(result_mp, len(self.enable_lora), axis=-1) - result_list = [] - for result_mp in result_mp_list: - result_list.append(mp_ops._c_concat(result_mp, group=self.model_parallel_group)) - # [batch_size, *, out_features] - result = paddle.concat(result_list, axis=-1) - else: - result = result_mp - - return result - - def extra_repr(self): - name = f", name={self.name}" if self.name else "" - return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}" - - class LoRAConv2D(nn.Conv2D): # LoRA implemented in a dense layer def __init__( @@ -1039,7 +686,6 @@ def __init__( r: int = 0, lora_alpha: int = 1, lora_dropout: float = 0.0, - merge_weights: bool = True, **kwargs ): nn.Conv2D.__init__(self, in_channels, out_channels, kernel_size, **kwargs) @@ -1054,7 +700,6 @@ def __init__( self.lora_dropout = lambda x: x # Mark the weight as unmerged self.merged = False - self.merge_weights = merge_weights # Actual trainable parameters lora_A = nn.Conv2D( @@ -1084,10 +729,10 @@ def __init__( self.weight.stop_gradient = True if self.bias is not None: self.bias.stop_gradient = True + self.disable_lora = False - def train(self): - super().train() - if self.merge_weights and self.merged: + def unmerge(self): + if self.merged: weight_A = self.lora_A.cast(dtype=self.weight.dtype) weight_B = self.lora_B.cast(dtype=self.weight.dtype) if self.weight.shape[2:4] == [1, 1]: @@ -1109,9 +754,8 @@ def train(self): self.weight.set_value(new_weight) self.merged = False - def eval(self): - super().eval() - if self.merge_weights and not self.merged: + def merge(self): + if not self.merged: weight_A = self.lora_A.cast(dtype=self.weight.dtype) weight_B = self.lora_B.cast(dtype=self.weight.dtype) if self.weight.shape[2:4] == [1, 1]: @@ -1136,7 +780,7 @@ def eval(self): def forward(self, input: paddle.Tensor, *args, 
**kwargs): previous_dtype = input.dtype result = super().forward(input) - if not self.merged: + if not self.merged and not self.disable_lora: result += ( self.lora_B_forward(self.lora_A_forward(self.lora_dropout(input.cast(dtype=self.lora_A.dtype)))) * self.scaling diff --git a/paddlenlp/peft/lora/lora_model.py b/paddlenlp/peft/lora/lora_model.py index ebadf39a6a55..bf69760ae69a 100644 --- a/paddlenlp/peft/lora/lora_model.py +++ b/paddlenlp/peft/lora/lora_model.py @@ -32,20 +32,6 @@ RowParallelLinear, ) -from ...transformers.conversion_utils import ConversionMixin -from ...transformers.model_utils import ( - PretrainedModel, - _add_variant, - _load_state_dict_into_model, - dtype_guard, - load_state_dict, -) -from ...transformers.utils import get_checkpoint_shard_files, weight_name_suffix -from ...utils.distributed import distributed_gather -from ...utils.env import LORA_WEIGHTS_NAME, SAFE_PEFT_WEIGHTS_INDEX_NAME -from ...utils.log import logger -from .lora_config import LoRAConfig - try: from paddle.distributed.fleet.utils.sequence_parallel_utils import ( ColumnSequenceParallelLinear, @@ -61,17 +47,36 @@ class RowSequenceParallelLinear: pass +from ...transformers.conversion_utils import ConversionMixin +from ...transformers.model_utils import ( + PretrainedModel, + _add_variant, + _load_state_dict_into_model, + dtype_guard, + load_state_dict, +) +from ...transformers.utils import get_checkpoint_shard_files, weight_name_suffix +from ...utils.distributed import distributed_gather +from ...utils.env import LORA_WEIGHTS_NAME, SAFE_PEFT_WEIGHTS_INDEX_NAME +from ...utils.log import logger +from .lora_config import LoRAConfig from .lora_layers import ( ColumnParallelLoRALinear, - ColumnParallelLoRAMergedLinear, ColumnSequenceParallelLoRALinear, LoRAConv2D, LoRALinear, - LoRAMergedLinear, RowParallelLoRALinear, RowSequenceParallelLoRALinear, ) +AVALIABLE_LAYERS = [ + ColumnParallelLoRALinear, + ColumnSequenceParallelLoRALinear, + LoRAConv2D, + LoRALinear, + RowParallelLoRALinear, + RowSequenceParallelLoRALinear, +] try: from ...quantization.quantization_linear import ( ColumnParallelQuantizationLinear, @@ -83,6 +88,12 @@ class RowSequenceParallelLinear: QuantizationLoRALinear, RowParallelQuantizationLoRALinear, ) + + AVALIABLE_LAYERS += [ + ColumnParallelQuantizationLoRALinear, + QuantizationLoRALinear, + RowParallelQuantizationLoRALinear, + ] except: QuantizationLinear = None ColumnParallelQuantizationLinear = None @@ -96,10 +107,8 @@ class LoRAModel(nn.Layer): # TODO:lugimzzz support restore in following PR restore_layer_map: Dict[nn.Layer, nn.Layer] = { LoRALinear: nn.Linear, - LoRAMergedLinear: nn.Linear, LoRAConv2D: nn.Conv2D, # ColumnParallelLoRALinear: ColumnParallelLinear, - # ColumnParallelLoRAMergedLinear: ColumnParallelLinear, # RowParallelLoRALinear: RowParallelLinear, # QuantizationLoRALinear: QuantizationLinear, } @@ -391,222 +400,177 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora) parent_module = getattr(parent_module, name) module = getattr(parent_module, attribute_chain[-1]) lora_module = None - if enable_lora is None: - if isinstance(module, nn.Linear): - lora_module = LoRALinear( - in_features=module.weight.shape[0], - out_features=module.weight.shape[1], - r=lora_config.r, - lora_alpha=lora_config.lora_alpha, - lora_dropout=lora_config.lora_dropout, - merge_weights=lora_config.merge_weights, - rslora=lora_config.rslora, - lora_plus_scale=lora_config.lora_plus_scale, - pissa=lora_config.pissa, - bias_attr=False if module.bias is None else 
None, - use_quick_lora=lora_config.use_quick_lora, - ) - if isinstance(module, nn.Conv2D): - lora_module = LoRAConv2D( - in_channels=module._in_channels, - out_channels=module._out_channels, - kernel_size=module._kernel_size, - stride=module._stride, - padding=module._padding, - dilation=module._dilation, - groups=module._groups, - padding_mode=module._padding_mode, - data_format=module._data_format, - r=lora_config.r, - lora_alpha=lora_config.lora_alpha, - lora_dropout=lora_config.lora_dropout, - merge_weights=lora_config.merge_weights, - bias_attr=module._bias_attr, - ) - elif isinstance(module, ColumnParallelLinear): - # recover the original output_features - output_features = module.weight.shape[1] * module.world_size - lora_module = ColumnParallelLoRALinear( - in_features=module.weight.shape[0], - out_features=output_features, - gather_output=module.gather_output, - has_bias=module.bias is not None, - r=lora_config.r, - lora_alpha=lora_config.lora_alpha, - lora_dropout=lora_config.lora_dropout, - rslora=lora_config.rslora, - lora_plus_scale=lora_config.lora_plus_scale, - pissa=lora_config.pissa, - merge_weights=lora_config.merge_weights, - lora_A_weight_attr=paddle.ParamAttr( - initializer=nn.initializer.KaimingUniform( - negative_slope=math.sqrt(5), nonlinearity="leaky_relu" - ) - ), - use_quick_lora=lora_config.use_quick_lora, - ) - # Lora column parallel will spilt lora B matrix - self.add_lora_split_mapping(module_name + ".lora_B", is_column=True) - - # for lora qat - if self.lora_config.do_qat: - self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=True) - self.add_lora_split_mapping(module_name + ".activation_quanter._scale", is_column=False) - self.add_lora_split_mapping(module_name + ".activation_quanter.quanter._scale", is_column=False) - elif isinstance(module, RowParallelLinear): - # recover the original output_features - lora_module = RowParallelLoRALinear( - in_features=module.weight.shape[0] * module.world_size, - out_features=module.weight.shape[1], - has_bias=module.bias is not None, - input_is_parallel=module.input_is_parallel, - r=lora_config.r, - lora_alpha=lora_config.lora_alpha, - lora_dropout=lora_config.lora_dropout, - rslora=lora_config.rslora, - lora_plus_scale=lora_config.lora_plus_scale, - pissa=lora_config.pissa, - merge_weights=lora_config.merge_weights, - use_quick_lora=lora_config.use_quick_lora, - ) - # Lora column parallel will spilt lora A matrix - self.add_lora_split_mapping(module_name + ".lora_A", is_column=False) - - # for lora qat - if self.lora_config.do_qat: - self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=False) - self.add_lora_split_mapping(module_name + ".activation_quanter._scale", is_column=False) - self.add_lora_split_mapping(module_name + ".activation_quanter.quanter._scale", is_column=False) - elif isinstance(module, ColumnSequenceParallelLinear): - # recover the original output_features - output_features = module.weight.shape[1] * module.world_size - lora_module = ColumnSequenceParallelLoRALinear( - in_features=module.weight.shape[0], - out_features=output_features, - gather_output=module.gather_output, - has_bias=module.bias is not None, - r=lora_config.r, - lora_alpha=lora_config.lora_alpha, - lora_dropout=lora_config.lora_dropout, - rslora=lora_config.rslora, - lora_plus_scale=lora_config.lora_plus_scale, - merge_weights=lora_config.merge_weights, - lora_A_weight_attr=paddle.ParamAttr( - initializer=nn.initializer.KaimingUniform( - negative_slope=math.sqrt(5), 
nonlinearity="leaky_relu" - ) - ), - use_quick_lora=lora_config.use_quick_lora, - ) - # Lora column parallel will spilt lora B matrix - self.add_lora_split_mapping(module_name + ".lora_B", is_column=True) - - # for lora qat - if self.lora_config.do_qat: - self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=True) - self.add_lora_split_mapping(module_name + ".activation_quanter._scale", is_column=False) - self.add_lora_split_mapping(module_name + ".activation_quanter.quanter._scale", is_column=False) - elif isinstance(module, RowSequenceParallelLinear): - # recover the original output_features - lora_module = RowSequenceParallelLoRALinear( - in_features=module.weight.shape[0] * module.world_size, - out_features=module.weight.shape[1], - has_bias=module.bias is not None, - input_is_parallel=module.input_is_parallel, - r=lora_config.r, - lora_alpha=lora_config.lora_alpha, - lora_dropout=lora_config.lora_dropout, - rslora=lora_config.rslora, - lora_plus_scale=lora_config.lora_plus_scale, - merge_weights=lora_config.merge_weights, - use_quick_lora=lora_config.use_quick_lora, - ) - # Lora column parallel will spilt lora A matrix - self.add_lora_split_mapping(module_name + ".lora_A", is_column=False) - - # for lora qat - if self.lora_config.do_qat: - self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=False) - self.add_lora_split_mapping(module_name + ".activation_quanter._scale", is_column=False) - self.add_lora_split_mapping(module_name + ".activation_quanter.quanter._scale", is_column=False) - elif QuantizationLinear is not None and isinstance(module, QuantizationLinear): - lora_module = QuantizationLoRALinear( - in_features=module.in_features, - out_features=module.out_features, - quant_algo=module.quant_algo, - dtype=module._dtype, - bias_attr=False if module.bias is None else None, - block_size=module.block_size, - double_quant_block_size=module.double_quant_block_size, - double_quant=module.double_quant, - r=lora_config.r, - lora_alpha=lora_config.lora_alpha, - lora_dropout=lora_config.lora_dropout, - merge_weights=lora_config.merge_weights, - ) - self.quantized = True - elif ColumnParallelQuantizationLinear is not None and isinstance(module, ColumnParallelQuantizationLinear): - lora_module = ColumnParallelQuantizationLoRALinear( - in_features=module.in_features, - out_features=module.out_features, - quant_algo=module.quant_algo, - dtype=module._dtype, - bias_attr=False if module.bias is None else None, - gather_output=module.gather_output, - r=lora_config.r, - lora_alpha=lora_config.lora_alpha, - lora_dropout=lora_config.lora_dropout, - lora_A_weight_attr=paddle.ParamAttr( - initializer=nn.initializer.KaimingUniform( - negative_slope=math.sqrt(5), nonlinearity="leaky_relu" - ) - ), - ) - self.quantized = True - elif RowParallelQuantizationLinear is not None and isinstance(module, RowParallelQuantizationLinear): - lora_module = RowParallelQuantizationLoRALinear( - in_features=module.in_features, - out_features=module.out_features, - quant_algo=module.quant_algo, - dtype=module._dtype, - bias_attr=False if module.bias is None else None, - input_is_parallel=module.input_is_parallel, - r=lora_config.r, - lora_alpha=lora_config.lora_alpha, - lora_dropout=lora_config.lora_dropout, - ) - self.quantized = True - else: - if isinstance(module, nn.Linear): - lora_module = LoRAMergedLinear( - in_features=module.weight.shape[0], - out_features=module.weight.shape[1], - r=lora_config.r, - lora_alpha=lora_config.lora_alpha, - 
lora_dropout=lora_config.lora_dropout, - merge_weights=lora_config.merge_weights, - enable_lora=enable_lora, - head_dim=lora_config.head_dim, - ) - elif isinstance(module, ColumnParallelLinear): - # recover the original output_features - lora_module = ColumnParallelLoRAMergedLinear( - in_features=module.weight.shape[0], - out_features=module.weight.shape[1] * module.world_size, - gather_output=module.gather_output, - has_bias=module.bias is not None, - r=lora_config.r, - lora_alpha=lora_config.lora_alpha, - lora_dropout=lora_config.lora_dropout, - merge_weights=lora_config.merge_weights, - enable_lora=enable_lora, - head_dim=lora_config.head_dim, - lora_A_weight_attr=paddle.ParamAttr( - initializer=nn.initializer.KaimingUniform( - negative_slope=math.sqrt(5), nonlinearity="leaky_relu" - ) - ), - ) + if isinstance(module, nn.Linear): + lora_module = LoRALinear( + in_features=module.weight.shape[0], + out_features=module.weight.shape[1], + r=lora_config.r, + lora_alpha=lora_config.lora_alpha, + lora_dropout=lora_config.lora_dropout, + rslora=lora_config.rslora, + lora_plus_scale=lora_config.lora_plus_scale, + pissa=lora_config.pissa, + bias_attr=False if module.bias is None else None, + use_quick_lora=lora_config.use_quick_lora, + ) + if isinstance(module, nn.Conv2D): + lora_module = LoRAConv2D( + in_channels=module._in_channels, + out_channels=module._out_channels, + kernel_size=module._kernel_size, + stride=module._stride, + padding=module._padding, + dilation=module._dilation, + groups=module._groups, + padding_mode=module._padding_mode, + data_format=module._data_format, + r=lora_config.r, + lora_alpha=lora_config.lora_alpha, + lora_dropout=lora_config.lora_dropout, + bias_attr=module._bias_attr, + ) + elif isinstance(module, ColumnParallelLinear): + # recover the original output_features + output_features = module.weight.shape[1] * module.world_size + lora_module = ColumnParallelLoRALinear( + in_features=module.weight.shape[0], + out_features=output_features, + gather_output=module.gather_output, + has_bias=module.bias is not None, + r=lora_config.r, + lora_alpha=lora_config.lora_alpha, + lora_dropout=lora_config.lora_dropout, + rslora=lora_config.rslora, + lora_plus_scale=lora_config.lora_plus_scale, + pissa=lora_config.pissa, + lora_A_weight_attr=paddle.ParamAttr( + initializer=nn.initializer.KaimingUniform(negative_slope=math.sqrt(5), nonlinearity="leaky_relu") + ), + use_quick_lora=lora_config.use_quick_lora, + ) + # Lora column parallel will spilt lora B matrix + self.add_lora_split_mapping(module_name + ".lora_B", is_column=True) + + # for lora qat + if self.lora_config.do_qat: + self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=True) + self.add_lora_split_mapping(module_name + ".activation_quanter._scale", is_column=False) + self.add_lora_split_mapping(module_name + ".activation_quanter.quanter._scale", is_column=False) + elif isinstance(module, RowParallelLinear): + # recover the original output_features + lora_module = RowParallelLoRALinear( + in_features=module.weight.shape[0] * module.world_size, + out_features=module.weight.shape[1], + has_bias=module.bias is not None, + input_is_parallel=module.input_is_parallel, + r=lora_config.r, + lora_alpha=lora_config.lora_alpha, + lora_dropout=lora_config.lora_dropout, + rslora=lora_config.rslora, + lora_plus_scale=lora_config.lora_plus_scale, + pissa=lora_config.pissa, + use_quick_lora=lora_config.use_quick_lora, + ) + # Lora column parallel will spilt lora A matrix + 
self.add_lora_split_mapping(module_name + ".lora_A", is_column=False) + + # for lora qat + if self.lora_config.do_qat: + self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=False) + self.add_lora_split_mapping(module_name + ".activation_quanter._scale", is_column=False) + self.add_lora_split_mapping(module_name + ".activation_quanter.quanter._scale", is_column=False) + elif isinstance(module, ColumnSequenceParallelLinear): + # recover the original output_features + output_features = module.weight.shape[1] * module.world_size + lora_module = ColumnSequenceParallelLoRALinear( + in_features=module.weight.shape[0], + out_features=output_features, + gather_output=module.gather_output, + has_bias=module.bias is not None, + r=lora_config.r, + lora_alpha=lora_config.lora_alpha, + lora_dropout=lora_config.lora_dropout, + rslora=lora_config.rslora, + lora_plus_scale=lora_config.lora_plus_scale, + lora_A_weight_attr=paddle.ParamAttr( + initializer=nn.initializer.KaimingUniform(negative_slope=math.sqrt(5), nonlinearity="leaky_relu") + ), + use_quick_lora=lora_config.use_quick_lora, + ) + # Lora column parallel will spilt lora B matrix + self.add_lora_split_mapping(module_name + ".lora_B", is_column=True) + + # for lora qat + if self.lora_config.do_qat: + self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=True) + self.add_lora_split_mapping(module_name + ".activation_quanter._scale", is_column=False) + self.add_lora_split_mapping(module_name + ".activation_quanter.quanter._scale", is_column=False) + elif isinstance(module, RowSequenceParallelLinear): + # recover the original output_features + lora_module = RowSequenceParallelLoRALinear( + in_features=module.weight.shape[0] * module.world_size, + out_features=module.weight.shape[1], + has_bias=module.bias is not None, + input_is_parallel=module.input_is_parallel, + r=lora_config.r, + lora_alpha=lora_config.lora_alpha, + lora_dropout=lora_config.lora_dropout, + rslora=lora_config.rslora, + lora_plus_scale=lora_config.lora_plus_scale, + use_quick_lora=lora_config.use_quick_lora, + ) + # Lora column parallel will spilt lora A matrix + self.add_lora_split_mapping(module_name + ".lora_A", is_column=False) + + # for lora qat + if self.lora_config.do_qat: + self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=False) + self.add_lora_split_mapping(module_name + ".activation_quanter._scale", is_column=False) + self.add_lora_split_mapping(module_name + ".activation_quanter.quanter._scale", is_column=False) + elif QuantizationLinear is not None and isinstance(module, QuantizationLinear): + lora_module = QuantizationLoRALinear( + in_features=module.in_features, + out_features=module.out_features, + quant_algo=module.quant_algo, + dtype=module._dtype, + bias_attr=False if module.bias is None else None, + block_size=module.block_size, + double_quant_block_size=module.double_quant_block_size, + double_quant=module.double_quant, + r=lora_config.r, + lora_alpha=lora_config.lora_alpha, + lora_dropout=lora_config.lora_dropout, + ) + self.quantized = True + elif ColumnParallelQuantizationLinear is not None and isinstance(module, ColumnParallelQuantizationLinear): + lora_module = ColumnParallelQuantizationLoRALinear( + in_features=module.in_features, + out_features=module.out_features, + quant_algo=module.quant_algo, + dtype=module._dtype, + bias_attr=False if module.bias is None else None, + gather_output=module.gather_output, + r=lora_config.r, + lora_alpha=lora_config.lora_alpha, + 
lora_dropout=lora_config.lora_dropout, + lora_A_weight_attr=paddle.ParamAttr( + initializer=nn.initializer.KaimingUniform(negative_slope=math.sqrt(5), nonlinearity="leaky_relu") + ), + ) + self.quantized = True + elif RowParallelQuantizationLinear is not None and isinstance(module, RowParallelQuantizationLinear): + lora_module = RowParallelQuantizationLoRALinear( + in_features=module.in_features, + out_features=module.out_features, + quant_algo=module.quant_algo, + dtype=module._dtype, + bias_attr=False if module.bias is None else None, + input_is_parallel=module.input_is_parallel, + r=lora_config.r, + lora_alpha=lora_config.lora_alpha, + lora_dropout=lora_config.lora_dropout, + ) + self.quantized = True if lora_module is None: raise ValueError( f"LoRA strategy only supports paddle.nn.Linear or paddle.distributed.fleet.meta_parallel.ColumnParallelLinear or paddlenlp.transformers.sequence_utils. {module}({module_name} {type(module).__name__}) is not supported。" @@ -669,8 +633,6 @@ def mark_only_lora_as_trainable(self) -> None: or isinstance(layer, RowParallelLoRALinear) or isinstance(layer, ColumnSequenceParallelLoRALinear) or isinstance(layer, RowSequenceParallelLoRALinear) - or isinstance(layer, LoRAMergedLinear) - or isinstance(layer, ColumnParallelLoRAMergedLinear) or (QuantizationLoRALinear is not None and isinstance(layer, QuantizationLoRALinear)) or ( ColumnParallelQuantizationLoRALinear is not None @@ -748,17 +710,14 @@ def get_lora_model(self, model: Union[PretrainedModel, nn.Layer], lora_config: L def restore_original_model(self): # make sure W and lora weights are not merged before we restore the original model - if self.lora_config.merge_weights: - self.train() for layer_name, layer in self.model.named_sublayers(): - if isinstance(layer, LoRALinear) or isinstance(layer, LoRAMergedLinear): + if isinstance(layer, LoRALinear): self._find_and_restore_module(layer_name) elif ( isinstance(layer, ColumnParallelLoRALinear) or isinstance(layer, ColumnSequenceParallelLoRALinear) or isinstance(layer, LoRAConv2D) - or isinstance(layer, ColumnParallelLoRAMergedLinear) or isinstance(layer, RowParallelLoRALinear) or isinstance(layer, RowSequenceParallelLoRALinear) or (QuantizationLoRALinear is not None and isinstance(layer, QuantizationLoRALinear)) @@ -849,3 +808,23 @@ def save_to_aistudio( ) else: logger.info(f"{filename}: {res['message']}") + + def disable_lora(self): + for _, layer in self.model.named_sublayers(): + if any(isinstance(layer, lora_layer) for lora_layer in AVALIABLE_LAYERS): + layer.disable_lora = True + + def enable_lora(self): + for _, layer in self.model.named_sublayers(): + if any(isinstance(layer, lora_layer) for lora_layer in AVALIABLE_LAYERS): + layer.disable_lora = False + + def merge(self): + for _, layer in self.model.named_sublayers(): + if any(isinstance(layer, lora_layer) for lora_layer in AVALIABLE_LAYERS): + layer.merge() + + def unmerge(self): + for _, layer in self.model.named_sublayers(): + if any(isinstance(layer, lora_layer) for lora_layer in AVALIABLE_LAYERS): + layer.unmerge() diff --git a/paddlenlp/peft/lora/lora_quant_layers.py b/paddlenlp/peft/lora/lora_quant_layers.py index 462014dda9d8..6f4e7b2b703a 100644 --- a/paddlenlp/peft/lora/lora_quant_layers.py +++ b/paddlenlp/peft/lora/lora_quant_layers.py @@ -43,7 +43,6 @@ def __init__(self, layer: nn.Layer, q_config): # Mark the weight as unmerged self.merged = False - self.merge_weights = layer.merge_weights # For FakeQuant @@ -53,10 +52,11 @@ def __init__(self, layer: nn.Layer, q_config): 
             self.weight_quanter = q_config.weight._instance(layer)
         if q_config.activation is not None:
             self.activation_quanter = q_config.activation._instance(layer)
+        self.disable_lora = False
 
     def forward(self, input):
-        if self.merge_weights and self.merged:
+        if self.merged or self.disable_lora:
             weight = self.weight
         else:
             weight = self.weight + self.lora_A @ self.lora_B * self.scaling
@@ -71,17 +71,15 @@ def _linear_forward(self, input, weight):
         out = F.linear(x=input, weight=weight, bias=self.bias, name=self.name)
         return out
 
-    def train(self):
-        super().train()
-        if self.merge_weights and self.merged:
+    def unmerge(self):
+        if self.merged:
             # Make sure that the weights are not merged
             new_weight = self.weight - self.lora_A @ self.lora_B * self.scaling
             self.weight.set_value(new_weight)
             self.merged = False
 
-    def eval(self):
-        super().eval()
-        if self.merge_weights and not self.merged:
+    def merge(self):
+        if not self.merged:
             # Merge the weights and mark it
             new_weight = self.weight + self.lora_A @ self.lora_B * self.scaling
             self.weight.set_value(new_weight)
@@ -122,7 +120,6 @@ def __init__(self, layer: nn.Layer, q_config):
 
         # Mark the weight as unmerged
         self.merged = False
-        self.merge_weights = layer.merge_weights
 
         # For FakeQuant
         self.weight_quanter = None
@@ -131,10 +128,11 @@
             self.weight_quanter = q_config.weight._instance(layer)
         if q_config.activation is not None:
             self.activation_quanter = q_config.activation._instance(layer)
+        self.disable_lora = False
 
     def forward(self, input):
-        if self.merge_weights and self.merged:
+        if self.merged or self.disable_lora:
             weight = self.weight
         else:
             weight = (
@@ -160,17 +158,15 @@ def _linear_forward(self, input, weight):
             result = result_mp
         return result
 
-    def train(self):
-        super().train()
-        if self.merge_weights and self.merged:
+    def unmerge(self):
+        if self.merged:
             # Make sure that the weights are not merged
             new_weight = self.weight - self.lora_A @ self.lora_B * self.scaling
             self.weight.set_value(new_weight)
             self.merged = False
 
-    def eval(self):
-        super().eval()
-        if self.merge_weights and not self.merged:
+    def merge(self):
+        if not self.merged:
             # Merge the weights and mark it
             new_weight = self.weight + self.lora_A @ self.lora_B * self.scaling
             self.weight.set_value(new_weight)
@@ -211,7 +207,6 @@ def __init__(self, layer: nn.Layer, q_config):
 
         # Mark the weight as unmerged
         self.merged = False
-        self.merge_weights = layer.merge_weights
 
         # For FakeQuant
         self.weight_quanter = None
@@ -220,10 +215,11 @@
             self.weight_quanter = q_config.weight._instance(layer)
         if q_config.activation is not None:
             self.activation_quanter = q_config.activation._instance(layer)
+        self.disable_lora = False
 
     def forward(self, input):
-        if self.merge_weights and self.merged:
+        if self.merged or self.disable_lora:
             weight = self.weight
         else:
             weight = (
@@ -255,17 +251,15 @@ def _linear_forward(self, input, weight):
         output = output + self.bias if self.bias is not None else output
         return output
 
-    def train(self):
-        super().train()
-        if self.merge_weights and self.merged:
+    def unmerge(self):
+        if self.merged:
             # Make sure that the weights are not merged
             new_weight = self.weight - self.lora_A @ self.lora_B * self.scaling
             self.weight.set_value(new_weight)
             self.merged = False
 
-    def eval(self):
-        super().eval()
-        if self.merge_weights and not self.merged:
+    def merge(self):
+        if not self.merged:
             # Merge the weights and mark it
             new_weight = self.weight + self.lora_A @ self.lora_B * self.scaling
self.weight.set_value(new_weight) diff --git a/paddlenlp/peft/lora/lora_quantization_layers.py b/paddlenlp/peft/lora/lora_quantization_layers.py index 8c91a5fbbd85..8ff597633f1b 100644 --- a/paddlenlp/peft/lora/lora_quantization_layers.py +++ b/paddlenlp/peft/lora/lora_quantization_layers.py @@ -17,9 +17,9 @@ import paddle from paddle import nn from paddle.distributed.fleet.layers.mpu import mp_ops -from paddle.nn.quant import weight_dequantize, weight_only_linear +from paddle.nn.quant import weight_dequantize, weight_only_linear, weight_quantize -from ...quantization.qlora import qlora_weight_dequantize +from ...quantization.qlora import qlora_weight_dequantize, qlora_weight_quantize from ...quantization.quantization_linear import ( ColumnParallelQuantizationLinear, QuantizationLinear, @@ -56,10 +56,8 @@ def __init__( r: int = 0, lora_alpha: int = 1, lora_dropout: float = 0.0, - merge_weights: bool = True, ): - QuantizationLinear.__init__( - self, + super().__init__( in_features, out_features, quant_algo, @@ -76,6 +74,7 @@ def __init__( quant_scale_attr, llm_int8_threshold, ) + if not isinstance(r, int) or r <= 0: raise ValueError("Lora rank r should be a positive integer") if self.quant_algo == "llm.int8": @@ -86,7 +85,6 @@ def __init__( self.lora_alpha = lora_alpha # Mark the weight as unmerged self.merged = False - self.merge_weights = merge_weights # Optional dropout if lora_dropout > 0.0: self.lora_dropout = nn.Dropout(p=lora_dropout) @@ -108,15 +106,11 @@ def __init__( ) self.weight = None self.scaling = self.lora_alpha / self.r + self.disable_lora = False - def init_float_weight(self): - self.weight = self.create_parameter( - shape=[self.in_features, self.out_features], - dtype=self._dtype, - is_bias=False, - ) + def dequantize_weight(self): if self.quant_algo in ["fp4", "nf4"]: - qdq_weight = ( + new_weight = ( qlora_weight_dequantize( quant_weight=self.quant_weight, quant_algo=self.quant_algo, @@ -128,40 +122,61 @@ def init_float_weight(self): double_quant_block_size=self.double_quant_block_size, ) .cast(self._dtype) - .reshape(self.weight.shape) + .reshape([self.in_features, self.out_features]) + ) + elif self.quant_algo in ["weight_only_int8"]: + new_weight = weight_dequantize(self.quant_weight, self.quant_scale, self.quant_algo, self._dtype) + else: + raise NotImplementedError(f"{self.quant_algo} not yet support lora merge strategy.") + return new_weight + + def quantize_weight(self, new_weight): + if self.quant_algo in ["fp4", "nf4"]: + print("self.quant_weight", self.quant_weight) + quant_weight, quant_state = qlora_weight_quantize( + weight=new_weight, + quant_algo=self.quant_algo, + double_quant=self.double_quant, + block_size=self.block_size, + double_quant_block_size=self.double_quant_block_size, + return_dict=False, ) + print("quant_weight", quant_weight) + self.quant_weight.set_value(quant_weight) + if self.double_quant: + qquant_scale, double_quant_scale, quant_sacle_offset = quant_state + self.qquant_scale.set_value(qquant_scale) + self.double_quant_scale.set_value(double_quant_scale) + self.quant_sacle_offset.set_value(quant_sacle_offset) + else: + quant_scale = quant_state + self.quant_scale.set_value(quant_scale) elif self.quant_algo in ["weight_only_int8"]: - qdq_weight = weight_dequantize(self.quant_weight, self.quant_scale, self.quant_algo, self._dtype) + quant_weight, quant_scale = weight_quantize(new_weight, self.quant_algo) + self.quant_weight.set_value(quant_weight) + self.quant_scale.set_value(quant_scale) else: raise 
NotImplementedError(f"{self.quant_algo} not yet support lora merge strategy.") - self.weight.set_value(qdq_weight) - def train(self): - super().train() - if self.merge_weights and self.merged: + def unmerge(self): + if self.merged: # Make sure that the weights are not merged - new_weight = self.weight - self.lora_A @ self.lora_B * self.scaling - self.weight.set_value(new_weight) + new_weight = self.dequantize_weight() + new_weight -= self.lora_A @ self.lora_B * self.scaling + self.quantize_weight(new_weight) self.merged = False - def eval(self): - super().eval() - if self.merge_weights and not self.merged: - if self.weight is None: - self.init_float_weight() + def merge(self): + if not self.merged: # Merge the weights and mark it - new_weight = self.weight + self.lora_A @ self.lora_B * self.scaling - self.weight.set_value(new_weight) + new_weight = self.dequantize_weight() + new_weight += self.lora_A @ self.lora_B * self.scaling + self.quantize_weight(new_weight) self.merged = True def forward(self, x: paddle.Tensor): - if self.merge_weights: - if self.weight is None: - self.init_float_weight() - result = paddle.nn.functional.linear(x, self.weight, self.bias) - else: - result = super().forward(x) - if not self.merged: + result = super().forward(x) + if not self.merged and not self.disable_lora: result += (self.lora_dropout(x) @ self.lora_A @ self.lora_B) * self.scaling return result @@ -235,14 +250,19 @@ def __init__( self.lora_B.is_distributed = True self.lora_B.split_axis = 1 self.scaling = self.lora_alpha / self.r + self.disable_lora = False + # Mark the weight as unmerged + self.merged = False def forward(self, x): + result_mp = super().forward(x) - input_a = self.lora_dropout(x) @ self.lora_A - input_a_mp = mp_ops._c_identity(input_a, group=self.model_parallel_group) - delta_mp = (input_a_mp @ self.lora_B) * self.scaling - result_mp += delta_mp + if not self.disable_lora or not self.merged: + input_a = self.lora_dropout(x) @ self.lora_A + input_a_mp = mp_ops._c_identity(input_a, group=self.model_parallel_group) + delta_mp = (input_a_mp @ self.lora_B) * self.scaling + result_mp += delta_mp if self.gather_output and self.is_mp: result = mp_ops._c_concat(result_mp, group=self.model_parallel_group) @@ -250,6 +270,70 @@ def forward(self, x): result = result_mp return result + def dequantize_weight(self): + if self.quant_algo in ["fp4", "nf4"]: + new_weight = ( + qlora_weight_dequantize( + quant_weight=self.quant_weight, + quant_algo=self.quant_algo, + state=(self.qquant_scale, self.double_quant_scale, self.quant_scale_offset) + if self.double_quant + else self.quant_scale, + double_quant=self.double_quant, + block_size=self.block_size, + double_quant_block_size=self.double_quant_block_size, + ) + .cast(self._dtype) + .reshape([self.in_features, self.out_features]) + ) + elif self.quant_algo in ["weight_only_int8"]: + new_weight = weight_dequantize(self.quant_weight, self.quant_scale, self.quant_algo, self._dtype) + else: + raise NotImplementedError(f"{self.quant_algo} not yet support lora merge strategy.") + return new_weight + + def quantize_weight(self, new_weight): + if self.quant_algo in ["fp4", "nf4"]: + quant_weight, quant_state = qlora_weight_quantize( + weight=new_weight, + quant_algo=self.quant_algo, + double_quant=self.double_quant, + block_size=self.block_size, + double_quant_block_size=self.double_quant_block_size, + return_dict=False, + ) + self.quant_weight.set_value(quant_weight) + if self.double_quant: + qquant_scale, double_quant_scale, quant_sacle_offset = quant_state + 
+                self.qquant_scale.set_value(qquant_scale)
+                self.double_quant_scale.set_value(double_quant_scale)
+                self.quant_sacle_offset.set_value(quant_sacle_offset)
+            else:
+                quant_scale = quant_state
+                self.quant_scale.set_value(quant_scale)
+        elif self.quant_algo in ["weight_only_int8"]:
+            quant_weight, quant_scale = weight_quantize(new_weight, self.quant_algo)
+            self.quant_weight.set_value(quant_weight)
+            self.quant_scale.set_value(quant_scale)
+        else:
+            raise NotImplementedError(f"{self.quant_algo} not yet support lora merge strategy.")
+
+    def unmerge(self):
+        if self.merged:
+            # Make sure that the weights are not merged
+            new_weight = self.dequantize_weight()
+            new_weight -= self.lora_A @ self.lora_B * self.scaling
+            self.quantize_weight(new_weight)
+            self.merged = False
+
+    def merge(self):
+        if not self.merged:
+            # Merge the weights and mark it
+            new_weight = self.dequantize_weight()
+            new_weight += self.lora_A @ self.lora_B * self.scaling
+            self.quantize_weight(new_weight)
+            self.merged = True
+
 
 class RowParallelQuantizationLoRALinear(RowParallelQuantizationLinear):
     """
@@ -320,6 +404,8 @@ def __init__(
         self.lora_A.split_axis = 0
         self.lora_B.is_distributed = False
         self.scaling = self.lora_alpha / self.r
+        self.disable_lora = False
+        self.merged = False
 
     def forward(self, x: paddle.Tensor):
         if not self.input_is_parallel:
@@ -337,18 +423,82 @@ def forward(self, x: paddle.Tensor):
             use_calc_stream=True,
             use_model_parallel=True,
         )
-
-        # x @ A: [bz, in_f/ ws] ===> [bz, r]
-        input_mp = self.lora_dropout(input_mp) @ self.lora_A
-        # all reduce to keep Lora B's gradient on different gpu consistent
-        input_dup = mp_ops._mp_allreduce(
-            input_mp,
-            group=self.model_parallel_group,
-            use_calc_stream=True,
-            use_model_parallel=True,
-        )
-        # @ B: [bz, r] ===> [bz, out_f]
-        delta_mp = (input_dup @ self.lora_B) * self.scaling
-        output += delta_mp
+        if not self.disable_lora and not self.merged:
+            # x @ A: [bz, in_f/ ws] ===> [bz, r]
+            input_mp = self.lora_dropout(input_mp) @ self.lora_A
+            # all reduce to keep Lora B's gradient on different gpu consistent
+            input_dup = mp_ops._mp_allreduce(
+                input_mp,
+                group=self.model_parallel_group,
+                use_calc_stream=True,
+                use_model_parallel=True,
+            )
+            # @ B: [bz, r] ===> [bz, out_f]
+            delta_mp = (input_dup @ self.lora_B) * self.scaling
+            output += delta_mp
         output = output + self.bias if self.bias is not None else output
         return output
+
+    def dequantize_weight(self):
+        if self.quant_algo in ["fp4", "nf4"]:
+            new_weight = (
+                qlora_weight_dequantize(
+                    quant_weight=self.quant_weight,
+                    quant_algo=self.quant_algo,
+                    state=(self.qquant_scale, self.double_quant_scale, self.quant_scale_offset)
+                    if self.double_quant
+                    else self.quant_scale,
+                    double_quant=self.double_quant,
+                    block_size=self.block_size,
+                    double_quant_block_size=self.double_quant_block_size,
+                )
+                .cast(self._dtype)
+                .reshape([self.in_features, self.out_features])
+            )
+        elif self.quant_algo in ["weight_only_int8"]:
+            new_weight = weight_dequantize(self.quant_weight, self.quant_scale, self.quant_algo, self._dtype)
+        else:
+            raise NotImplementedError(f"{self.quant_algo} not yet support lora merge strategy.")
+        return new_weight
+
+    def quantize_weight(self, new_weight):
+        if self.quant_algo in ["fp4", "nf4"]:
+            quant_weight, quant_state = qlora_weight_quantize(
+                weight=new_weight,
+                quant_algo=self.quant_algo,
+                double_quant=self.double_quant,
+                block_size=self.block_size,
+                double_quant_block_size=self.double_quant_block_size,
+                return_dict=False,
+            )
+            self.quant_weight.set_value(quant_weight)
+            if self.double_quant:
+                qquant_scale, double_quant_scale, quant_sacle_offset = quant_state
+                self.qquant_scale.set_value(qquant_scale)
+                self.double_quant_scale.set_value(double_quant_scale)
+                self.quant_sacle_offset.set_value(quant_sacle_offset)
+            else:
+                quant_scale = quant_state
+                self.quant_scale.set_value(quant_scale)
+        elif self.quant_algo in ["weight_only_int8"]:
+            quant_weight, quant_scale = weight_quantize(new_weight, self.quant_algo)
+            self.quant_weight.set_value(quant_weight)
+            self.quant_scale.set_value(quant_scale)
+        else:
+            raise NotImplementedError(f"{self.quant_algo} not yet support lora merge strategy.")
+
+    def unmerge(self):
+        if self.merged:
+            # Make sure that the weights are not merged
+            new_weight = self.dequantize_weight()
+            new_weight -= self.lora_A @ self.lora_B * self.scaling
+            self.quantize_weight(new_weight)
+            self.merged = False
+
+    def merge(self):
+        if not self.merged:
+            # Merge the weights and mark it
+            new_weight = self.dequantize_weight()
+            new_weight += self.lora_A @ self.lora_B * self.scaling
+            self.quantize_weight(new_weight)
+            self.merged = True
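For reference, the arithmetic behind the new merge()/unmerge() methods is the plain LoRA identity; the short standalone sketch below (shapes and names are illustrative, not taken from the patch) shows why the round trip is exact for float weights, while for the quantized layers above it goes through dequantize, update, requantize and is therefore only approximate.

# Minimal sketch of the merge/unmerge arithmetic (illustrative shapes, not part of the patch).
import paddle

in_features, out_features, r, lora_alpha = 16, 8, 4, 8
scaling = lora_alpha / r

weight = paddle.randn([in_features, out_features], dtype="float32")  # frozen base weight
lora_A = paddle.randn([in_features, r], dtype="float32")
lora_B = paddle.randn([r, out_features], dtype="float32")

merged = weight + lora_A @ lora_B * scaling      # merge(): fold the low-rank update into W
restored = merged - lora_A @ lora_B * scaling    # unmerge(): subtract the same update

print(paddle.allclose(restored, weight))  # True for float weights; for quantized layers the
                                          # dequantize -> update -> requantize round trip adds
                                          # quantization round-off, so recovery is approximate.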
diff --git a/tests/peft/test_lora.py b/tests/peft/test_lora.py
index f982c23defeb..3f2b22e5b359 100644
--- a/tests/peft/test_lora.py
+++ b/tests/peft/test_lora.py
@@ -22,7 +22,7 @@
 import paddle
 from parameterized import parameterized
 
-from paddlenlp.peft.lora import LoRAConfig, LoRALinear, LoRAMergedLinear, LoRAModel
+from paddlenlp.peft.lora import LoRAConfig, LoRALinear, LoRAModel
 from paddlenlp.transformers import AutoModel, BertModel
 
 
@@ -80,71 +80,12 @@ def test_load_regular_linear(self):
             self.assertTrue(paddle.allclose(lora_layer_r4(x), regular_linear(x)))
 
 
-class TestLoRAMergedLayer(unittest.TestCase):
-    def test_forward(self):
-        lora_layer = LoRAMergedLinear(
-            in_features=16, out_features=8, r=4, lora_dropout=0.1, lora_alpha=8, enable_lora=[True, False], head_dim=2
-        )
-        x = paddle.randn([2, 4, 16], "float32")
-        output = lora_layer(x)
-        self.assertFalse(lora_layer.lora_A.stop_gradient)
-        self.assertFalse(lora_layer.lora_B.stop_gradient)
-        self.assertTrue(lora_layer.weight.stop_gradient)
-        self.assertFalse(lora_layer.bias.stop_gradient)
-        self.assertEqual(output.shape, [2, 4, 8])
-
-    def test_train_eval(self):
-        x = paddle.randn([2, 4, 16], "float32")
-        lora_layer = LoRAMergedLinear(
-            in_features=16, out_features=8, r=4, lora_alpha=8, enable_lora=[True, False], head_dim=2
-        )
-        lora_layer.train()
-        train_result = lora_layer(x)
-        train_weight = copy.deepcopy(lora_layer.weight)  # deep copy since this is a pointer
-        lora_layer.eval()
-        eval_result = lora_layer(x)
-        eval_weight = lora_layer.weight
-        self.assertTrue(paddle.allclose(train_result, eval_result))
-        self.assertTrue(paddle.allclose(train_weight, eval_weight))
-
-    def test_save_load(self):
-        with TemporaryDirectory() as tempdir:
-            lora_layer = LoRAMergedLinear(
-                in_features=16, out_features=8, r=4, lora_alpha=8, enable_lora=[True, False], head_dim=2
-            )
-            weights_path = os.path.join(tempdir, "model.pdparams")
-            paddle.save(lora_layer.state_dict(), weights_path)
-            new_lora_layer = LoRAMergedLinear(
-                in_features=16, out_features=8, r=4, lora_alpha=8, enable_lora=[True, False], head_dim=2
-            )
-            state_dict = paddle.load(weights_path)
-            new_lora_layer.set_dict(state_dict)
-            x = paddle.randn([2, 4, 16], "float32")
-            self.assertTrue(paddle.allclose(new_lora_layer(x), lora_layer(x)))
-
-    def test_load_regular_linear(self):
-        with TemporaryDirectory() as tempdir:
-            regular_linear = paddle.nn.Linear(in_features=16, out_features=8)
-            weights_path = os.path.join(tempdir, "model.pdparams")
-            paddle.save(regular_linear.state_dict(), weights_path)
-            state_dict = paddle.load(weights_path)
-            # should be identical to regular linear
-            lora_layer_r8 = LoRAMergedLinear(in_features=16, out_features=8, r=8, head_dim=2)
-            lora_layer_r4 = LoRAMergedLinear(in_features=16, out_features=8, r=4, head_dim=2)
-            lora_layer_r8.set_dict(state_dict)
-            lora_layer_r4.set_dict(state_dict)
-            x = paddle.randn([2, 4, 16], "float32")
-            self.assertTrue(paddle.allclose(lora_layer_r8(x), regular_linear(x)))
-            self.assertTrue(paddle.allclose(lora_layer_r4(x), regular_linear(x)))
-
-
 class TestLoraModel(unittest.TestCase):
     def test_lora_model_restore(self):
         lora_config = LoRAConfig(
             target_modules=[".*q_proj.*", ".*v_proj.*"],
             r=4,
             lora_alpha=8,
-            merge_weights=True,
             enable_lora_list=[None, [True, False]],
             head_dim=2,
         )
@@ -167,7 +108,6 @@ def test_lora_model_constructor(self, bias):
             target_modules=[".*q_proj.*", ".*v_proj.*"],
             r=4,
             lora_alpha=8,
-            merge_weights=True,
             enable_lora_list=[None, [True, False]],
             trainable_bias=bias,
             head_dim=2,
@@ -207,7 +147,6 @@ def test_lora_model_save_load(self):
             target_modules=[".*q_proj.*", ".*v_proj.*"],
             r=4,
             lora_alpha=8,
-            merge_weights=True,
         )
         model = AutoModel.from_pretrained("__internal_testing__/tiny-random-bert")
         lora_model = LoRAModel(model, lora_config)
@@ -230,7 +169,6 @@ def test_lora_module_raise_exception(self):
             target_modules=[".*norm1.*"],
             r=4,
             lora_alpha=8,
-            merge_weights=True,
             enable_lora_list=None,
         )
         model = AutoModel.from_pretrained("__internal_testing__/tiny-random-bert")
diff --git a/tests/peft/test_quant_lora.py b/tests/peft/test_quant_lora.py
index 902ce14a6192..3219fb6b96cd 100644
--- a/tests/peft/test_quant_lora.py
+++ b/tests/peft/test_quant_lora.py
@@ -48,7 +48,6 @@ def test_forward_no_quant(self):
             out_features=8,
             r=4,
             lora_alpha=8,
-            merge_weights=True,
         )
         quant_lora_layer = QuantedLoRALinear(
             layer=lora_layer, q_config=SingleLayerConfig(weight=None, activation=None)
@@ -82,17 +81,17 @@ def test_save_load(self):
             self.assertTrue(paddle.allclose(new_quant_lora_layer(x), quant_lora_layer(x)))
 
     def test_merge_weights(self):
-        lora_layer = LoRALinear(in_features=16, out_features=8, r=4, lora_alpha=8, merge_weights=True)
+        lora_layer = LoRALinear(in_features=16, out_features=8, r=4, lora_alpha=8)
         quant_lora_layer = QuantedLoRALinear(
             layer=lora_layer, q_config=SingleLayerConfig(weight=None, activation=None)
         )
         x = paddle.randn([2, 4, 16], "float32")
-        quant_lora_layer.train()
-        train_output = lora_layer(x)
-        quant_lora_layer.eval()
-        eval_output = lora_layer(x)
-        self.assertTrue(paddle.allclose(train_output, eval_output))
+        quant_lora_layer.merge()
+        merge_output = lora_layer(x)
+        quant_lora_layer.unmerge()
+        unmerge_output = lora_layer(x)
+        self.assertTrue(paddle.allclose(merge_output, unmerge_output))
 
 
 class TestQuantedLoRAModel(unittest.TestCase):
@@ -140,8 +139,11 @@ def test_forward_no_quant(self):
         qat = QAT(q_config)
         self.lora_model.train()
         quant_lora_model = qat.quantize(self.lora_model, inplace=False)
+        quant_lora_model.merge()
+        self.lora_model.merge()
         quant_lora_model.eval()
         self.lora_model.eval()
+
         input_ids = paddle.to_tensor(np.random.randint(100, 200, [1, 5]))
         original_model_outputs = self.lora_model(input_ids)[0]
         quant_model_outputs = quant_lora_model(input_ids)[0]
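Taken as a whole, these changes replace the implicit merge-on-eval behaviour of `merge_weights` with explicit `merge()`/`unmerge()` calls. A minimal usage sketch, mirroring the fixtures used in the test suite (the config values below are illustrative, not part of the patch):

# Rough sketch of the new explicit merge/unmerge flow (illustrative, not part of the patch).
from paddlenlp.peft.lora import LoRAConfig, LoRAModel
from paddlenlp.transformers import AutoModel

model = AutoModel.from_pretrained("__internal_testing__/tiny-random-bert")
lora_config = LoRAConfig(target_modules=[".*q_proj.*", ".*v_proj.*"], r=4, lora_alpha=8)
lora_model = LoRAModel(model, lora_config)

# Previously, LoRAConfig(merge_weights=True) merged weights implicitly inside eval().
# The caller now controls merging explicitly:
lora_model.merge()    # fold lora_A @ lora_B * scaling into the frozen base weights
lora_model.eval()     # eval() no longer touches the weights
# ... run inference on the merged model ...
lora_model.unmerge()  # restore the unmerged base weights before resuming training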