@@ -52,7 +52,7 @@ class LinearAdapter(nn.Linear):
         orig_linear (nn.Module): the linear module to augment.
         dim (int): lora's dim in_features -> dim -> out_features.
         alpha (int): lora's scaling alpha.
-        dropout (float): dropout prob (default: 0.1).
+        dropout (float): dropout prob (default: 0.0).
         dropout_position (str): where to apply dropout rel. to lora (choices= ['pre', 'post'], default=post)
         lora_A_init_method (str): init method for lora_A (choices= ['xavier', 'uniform'])
         lora_dtype (torch.dtype): weight's dtype, by default will use orig_linear's but if they
@@ -64,7 +64,7 @@ def __init__(
         orig_linear,
         dim=8,
         alpha=32,
-        dropout=0.1,
+        dropout=0.0,
         dropout_position='post',
         lora_A_init_method='xavier',
         lora_dtype=None,
@@ -82,14 +82,22 @@ def __init__(
         if orig_linear.bias is not None:
             self.bias.data.copy_(orig_linear.bias.data)
         # initialize the adapter
-        LinearAdapter._init_adapter(self)
+        LinearAdapter._init_adapter(
+            self,
+            dim=dim,
+            alpha=alpha,
+            dropout=dropout,
+            dropout_position=dropout_position,
+            lora_A_init_method=lora_A_init_method,
+            lora_dtype=lora_dtype,
+        )
 
     @staticmethod
     def _init_adapter(
         obj,
         dim=8,
         alpha=32,
-        dropout=0.1,
+        dropout=0.0,
         dropout_position='post',
         lora_A_init_method='xavier',
         lora_dtype=None,
@@ -101,7 +109,7 @@ def _init_adapter(
         obj (LinearAdapter | nn.Module): input module to adapt.
         dim (int): lora's dim in_features -> dim -> out_features.
         alpha (int): lora's scaling alpha.
-        dropout (float): dropout prob (default: 0.1).
+        dropout (float): dropout prob (default: 0.0).
         dropout_position (str): where to apply dropout rel. to lora (choices= ['pre', 'post'], default=post)
         lora_A_init_method (str): init method for lora_A (choices= ['xavier', 'uniform'])
         lora_dtype (torch.dtype): weight's dtype, by default will use orig_linear's but if they
@@ -155,7 +163,7 @@ def patch_linear_module(
     orig_linear,
     dim=8,
     alpha=32,
-    dropout=0.1,
+    dropout=0.0,
     dropout_position='post',
     lora_A_init_method='xavier',
     lora_dtype=None,
@@ -175,7 +183,7 @@ def patch_linear_module(
         orig_linear (nn.Linear): the module we add adapter to.
         dim (int, optional): Lora dim. Defaults to 8.
         alpha (int, optional): Lora alpha scale. Defaults to 32.
-        dropout (float, optional): dropout prob. Defaults to 0.1.
+        dropout (float, optional): dropout prob. Defaults to 0.0.
         dropout_position (str, optional): location to apply dropout wrt lora.
             Defaults to 'post' (choices: 'pre', 'post').
         lora_A_init_method (str, optional): lora_a init method. Defaults to 'xavier'.
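
The functional fix in this diff is that LinearAdapter.__init__ now forwards its adapter hyperparameters to _init_adapter instead of calling it with no arguments; previously, user-supplied values for dim, alpha, dropout, dropout_position, lora_A_init_method, and lora_dtype were silently ignored and the helper's defaults were used. A minimal usage sketch, assuming LinearAdapter is importable (the import path below is a placeholder) and relying only on the constructor signature shown above:

import torch
import torch.nn as nn

from lora_module import LinearAdapter  # placeholder path, not part of the diff

base = nn.Linear(128, 64)

# These keyword arguments now reach _init_adapter; before the fix, the call
# LinearAdapter._init_adapter(self) built the adapter with its own defaults.
adapter = LinearAdapter(base, dim=16, alpha=64, dropout=0.0)

x = torch.randn(2, 128)
out = adapter(x)
print(out.shape)  # torch.Size([2, 64]) -- the adapted layer keeps the original output size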
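
The other change is the default dropout dropping from 0.1 to 0.0 in LinearAdapter, _init_adapter, and patch_linear_module, so the LoRA path is deterministic unless dropout is explicitly requested. The sketch below only demonstrates standard torch.nn.Dropout behaviour at the two values; it does not exercise the adapter itself:

import torch
import torch.nn as nn

x = torch.randn(4, 8)

# p=0.0 is an identity even in train() mode, so the new default leaves the
# adapter path unchanged during training.
drop_new = nn.Dropout(p=0.0).train()
assert torch.equal(drop_new(x), x)

# p=0.1 (the old default) zeroes roughly 10% of activations in train() mode and
# rescales the rest by 1 / (1 - 0.1), so repeated passes differ.
drop_old = nn.Dropout(p=0.1).train()
print(torch.equal(drop_old(x), x))  # almost always False while training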