
Commit 499161e

LinearAdapter: propagate args to _init_adapter (#11902)
* propagate defaults

  Signed-off-by: Alexandros Koumparoulis <[email protected]>

* switch dropout default to 0.0

  Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Apply isort and black reformatting

  Signed-off-by: akoumpa <[email protected]>

---------

Signed-off-by: Alexandros Koumparoulis <[email protected]>
Signed-off-by: akoumpa <[email protected]>
Co-authored-by: akoumpa <[email protected]>
1 parent 0075ed0 commit 499161e
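
What the change means in practice, as a minimal sketch (the wrapped layer and the dim/alpha values below are made-up examples; LinearAdapter, its signature, and the import path follow from the file touched by this diff):

# Illustrative sketch only; assumes NeMo is installed and importable.
import torch.nn as nn

from nemo.collections.llm.peft.lora import LinearAdapter

base = nn.Linear(256, 256)

# Before this commit, __init__ called LinearAdapter._init_adapter(self) with no
# arguments, so constructor kwargs were effectively ignored and the adapter was
# built with _init_adapter's defaults (dim=8, alpha=32, dropout=0.1).
# After this commit, the kwargs are forwarded, so the adapter really uses
# dim=16 and alpha=64 here, and dropout falls back to the new default of 0.0.
adapted = LinearAdapter(base, dim=16, alpha=64)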

File tree

1 file changed: +15 -7 lines changed


Diff for: nemo/collections/llm/peft/lora.py

@@ -52,7 +52,7 @@ class LinearAdapter(nn.Linear):
         orig_linear (nn.Module): the linear module to augment.
         dim (int): lora's dim in_features -> dim -> out_features.
         alpha (int): lora's scaling alpha.
-        dropout (float): dropout prob (default: 0.1).
+        dropout (float): dropout prob (default: 0.0).
         dropout_position (str): where to apply dropout rel. to lora (choices= ['pre', 'post'], default=post)
         lora_A_init_method (str): init method for lora_A (choices= ['xavier', 'uniform'])
         lora_dtype (torch.dtype): weight's dtype, by default will use orig_linear's but if they
@@ -64,7 +64,7 @@ def __init__(
         orig_linear,
         dim=8,
         alpha=32,
-        dropout=0.1,
+        dropout=0.0,
         dropout_position='post',
         lora_A_init_method='xavier',
         lora_dtype=None,
@@ -82,14 +82,22 @@ def __init__(
         if orig_linear.bias is not None:
             self.bias.data.copy_(orig_linear.bias.data)
         # initialize the adapte
-        LinearAdapter._init_adapter(self)
+        LinearAdapter._init_adapter(
+            self,
+            dim=dim,
+            alpha=alpha,
+            dropout=dropout,
+            dropout_position=dropout_position,
+            lora_A_init_method=lora_A_init_method,
+            lora_dtype=lora_dtype,
+        )

     @staticmethod
     def _init_adapter(
         obj,
         dim=8,
         alpha=32,
-        dropout=0.1,
+        dropout=0.0,
         dropout_position='post',
         lora_A_init_method='xavier',
         lora_dtype=None,
@@ -101,7 +109,7 @@ def _init_adapter(
            obj (LinearAdapter | nn.Module): input module to adapt.
            dim (int): lora's dim in_features -> dim -> out_features.
            alpha (int): lora's scaling alpha.
-           dropout (float): dropout prob (default: 0.1).
+           dropout (float): dropout prob (default: 0.0).
            dropout_position (str): where to apply dropout rel. to lora (choices= ['pre', 'post'], default=post)
            lora_A_init_method (str): init method for lora_A (choices= ['xavier', 'uniform'])
            lora_dtype (torch.dtype): weight's dtype, by default will use orig_linear's but if they
@@ -155,7 +163,7 @@ def patch_linear_module(
     orig_linear,
     dim=8,
     alpha=32,
-    dropout=0.1,
+    dropout=0.0,
     dropout_position='post',
     lora_A_init_method='xavier',
     lora_dtype=None,
@@ -175,7 +183,7 @@ def patch_linear_module(
         orig_linear (nn.Linear): the module we add adapter to.
         dim (int, optional): Lora dim. Defaults to 8.
         alpha (int, optional): Lora alpha scale. Defaults to 32.
-        dropout (float, optional): dropout prob. Defaults to 0.1.
+        dropout (float, optional): dropout prob. Defaults to 0.0.
         dropout_position (str, optional): location to apply dropout wrt lora.
            Defaults to 'post' (choices: 'pre', 'post').
         lora_A_init_method (str, optional): lora_a init method. Defaults to 'xavier'.
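
The same default change applies to patch_linear_module (last two hunks above). A minimal usage sketch, assuming only what the signature and docstring in this diff show; whether the function patches the module in place or returns a new one is not visible here, so treating the return value as the patched module is an assumption:

# Sketch under the assumptions stated above; the layer sizes are arbitrary examples.
import torch.nn as nn

from nemo.collections.llm.peft.lora import patch_linear_module

layer = nn.Linear(128, 64)

# With this commit, omitting `dropout` means no adapter dropout (0.0),
# where the previous default was 0.1.
patched = patch_linear_module(layer, dim=8, alpha=32, lora_A_init_method='xavier')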

0 commit comments
