[Bug fix] fix LoRA unsync parameters issue (PaddlePaddle#6048)
* fix styles

* remove extra print
sijunhe authored May 30, 2023
1 parent 84a9548 commit 699f73a
Showing 1 changed file with 12 additions and 4 deletions: paddlenlp/layers/lora.py
@@ -150,12 +150,15 @@ def __init__(
             is_bias=False,
             attr=lora_A_weight_attr,
         )
+        self.lora_A.is_distributed = False
         self.lora_B = self.create_parameter(
             shape=[r, self.output_size_per_partition],
             dtype=self._dtype,
             is_bias=False,
             default_initializer=nn.initializer.Constant(value=0.0),
         )
+        self.lora_B.is_distributed = True
+        self.lora_B.split_axis = 1
         self.scaling = self.lora_alpha / self.r

         # Freezing the pre-trained weight matrix
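
Editor's note (context, not part of the commit): in Paddle's tensor-parallel handling, is_distributed = False marks lora_A as replicated on every model-parallel rank, while is_distributed = True together with split_axis = 1 records that each rank holds only a column shard of lora_B. Flagging the shards this way keeps them out of cross-rank parameter broadcast and tells checkpoint tooling which axis to concatenate when merging tensor-parallel weights. A minimal sketch of that merge, with hypothetical shard arrays:

# Illustrative sketch only (not PaddleNLP API): merging column-sharded
# lora_B weights from two model-parallel ranks back into one matrix.
import numpy as np

r, out_per_rank = 8, 64
lora_B_rank0 = np.random.rand(r, out_per_rank).astype("float32")  # shard held by rank 0
lora_B_rank1 = np.random.rand(r, out_per_rank).astype("float32")  # shard held by rank 1

# split_axis = 1 means shards are concatenated along the output (column) axis.
lora_B_full = np.concatenate([lora_B_rank0, lora_B_rank1], axis=1)
assert lora_B_full.shape == (r, 2 * out_per_rank)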
@@ -184,8 +187,9 @@ def forward(self, input: paddle.Tensor):
         result_mp = F.linear(x=input_mp, weight=self.weight, bias=self.bias, name=self.name)

         if self.r > 0 and not self.merged:
-            input_a = self.lora_dropout(input_mp) @ self.lora_A
-            delta_mp = (input_a @ self.lora_B) * self.scaling
+            input_a = self.lora_dropout(input) @ self.lora_A
+            input_a_mp = mp_ops._c_identity(input_a, group=self.model_parallel_group)
+            delta_mp = (input_a_mp @ self.lora_B) * self.scaling
             result_mp += delta_mp

         if self.gather_output and self.is_mp:
@@ -378,13 +382,16 @@ def __init__(
             is_bias=False,
             attr=lora_A_weight_attr,
         )
+        self.lora_A.is_distributed = False
         # Make sure lora_B is split in column the same as ColumnParallelLoRALinear.
         self.lora_B = self.create_parameter(
             shape=[r, self.output_size_per_partition // len(enable_lora) * sum(enable_lora)],
             dtype=self._dtype,
             is_bias=False,
             default_initializer=nn.initializer.Constant(value=0.0),
         )
+        self.lora_B.is_distributed = True
+        self.lora_B.split_axis = 1
         self.scaling = self.lora_alpha / self.r

         # Freezing the pre-trained weight matrix
@@ -444,11 +451,12 @@ def forward(self, input: paddle.Tensor):
         # [batch_size, *, out_features_per_partition]
         result_mp = F.linear(x=input_mp, weight=self.weight, bias=self.bias, name=self.name)
         if self.r > 0 and any(self.enable_lora) and not self.merged:
-            input_a = self.lora_dropout(input_mp) @ self.lora_A
+            input_a = self.lora_dropout(input) @ self.lora_A
+            input_a_mp = mp_ops._c_identity(input_a, group=self.model_parallel_group)
             if input_a.dim() == 3:
                 delta_mp = (
                     F.conv1d(
-                        input_a.transpose([0, 2, 1]),
+                        input_a_mp.transpose([0, 2, 1]),
                         self.lora_B.T.unsqueeze(-1),
                         groups=sum(self.enable_lora),
                     )
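
The last two hunks apply the same gradient-sync fix to the merged (multi-projection) LoRA variant, where the per-group update is computed with a grouped 1x1 conv1d instead of a plain matmul. For readers unfamiliar with that trick, a self-contained sketch (hypothetical shapes, not the class's exact layout) showing that the grouped conv equals independent per-group matmuls:

# Hedged sketch: a grouped 1x1 conv1d applies an independent
# r -> out_per_group linear map to each enabled LoRA group.
import paddle
import paddle.nn.functional as F

batch, seq, r, groups, out_per_group = 2, 5, 4, 3, 8
x = paddle.randn([batch, seq, r * groups])      # stand-in for input_a_mp
w = paddle.randn([r, out_per_group * groups])   # stand-in for lora_B

# Grouped conv1d path, as in the diff: [batch, channels, seq] layout.
y_conv = F.conv1d(x.transpose([0, 2, 1]), w.T.unsqueeze(-1), groups=groups)
y_conv = y_conv.transpose([0, 2, 1])

# Equivalent per-group matmul on contiguous channel slices.
chunks = []
for g in range(groups):
    xg = x[:, :, g * r:(g + 1) * r]
    wg = w[:, g * out_per_group:(g + 1) * out_per_group]
    chunks.append(xg @ wg)
y_ref = paddle.concat(chunks, axis=-1)

assert paddle.allclose(y_conv, y_ref, atol=1e-5).item()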
