LinearAdapter: propagate args to _init_adapter (#11902)
* propagate defaults

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* switch dropout default to 0.0

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Apply isort and black reformatting

Signed-off-by: akoumpa <[email protected]>

---------

Signed-off-by: Alexandros Koumparoulis <[email protected]>
Signed-off-by: akoumpa <[email protected]>
Co-authored-by: akoumpa <[email protected]>
Signed-off-by: Abhinav Garg <[email protected]>
2 people authored and abhinavg4 committed Jan 30, 2025
1 parent 4502816 commit f291eb6
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions nemo/collections/llm/peft/lora.py
@@ -52,7 +52,7 @@ class LinearAdapter(nn.Linear):
         orig_linear (nn.Module): the linear module to augment.
         dim (int): lora's dim in_features -> dim -> out_features.
         alpha (int): lora's scaling alpha.
-        dropout (float): dropout prob (default: 0.1).
+        dropout (float): dropout prob (default: 0.0).
         dropout_position (str): where to apply dropout rel. to lora (choices= ['pre', 'post'], default=post)
         lora_A_init_method (str): init method for lora_A (choices= ['xavier', 'uniform'])
         lora_dtype (torch.dtype): weight's dtype, by default will use orig_linear's but if they
Expand All @@ -64,7 +64,7 @@ def __init__(
orig_linear,
dim=8,
alpha=32,
dropout=0.1,
dropout=0.0,
dropout_position='post',
lora_A_init_method='xavier',
lora_dtype=None,
Expand All @@ -82,14 +82,22 @@ def __init__(
if orig_linear.bias is not None:
self.bias.data.copy_(orig_linear.bias.data)
# initialize the adapte
LinearAdapter._init_adapter(self)
LinearAdapter._init_adapter(
self,
dim=dim,
alpha=alpha,
dropout=dropout,
dropout_position=dropout_position,
lora_A_init_method=lora_A_init_method,
lora_dtype=lora_dtype,
)

@staticmethod
def _init_adapter(
obj,
dim=8,
alpha=32,
dropout=0.1,
dropout=0.0,
dropout_position='post',
lora_A_init_method='xavier',
lora_dtype=None,
Expand All @@ -101,7 +109,7 @@ def _init_adapter(
obj (LinearAdapter | nn.Module): input module to adapt.
dim (int): lora's dim in_features -> dim -> out_features.
alpha (int): lora's scaling alpha.
dropout (float): dropout prob (default: 0.1).
dropout (float): dropout prob (default: 0.0).
dropout_position (str): where to apply dropout rel. to lora (choices= ['pre', 'post'], default=post)
lora_A_init_method (str): init method for lora_A (choices= ['xavier', 'uniform'])
lora_dtype (torch.dtype): weight's dtype, by default will use orig_linear's but if they
@@ -155,7 +163,7 @@ def patch_linear_module(
     orig_linear,
     dim=8,
     alpha=32,
-    dropout=0.1,
+    dropout=0.0,
     dropout_position='post',
     lora_A_init_method='xavier',
     lora_dtype=None,
Expand All @@ -175,7 +183,7 @@ def patch_linear_module(
orig_linear (nn.Linear): the module we add adapter to.
dim (int, optional): Lora dim. Defaults to 8.
alpha (int, optional): Lora alpha scale. Defaults to 32.
dropout (float, optional): dropout prob. Defaults to 0.1.
dropout (float, optional): dropout prob. Defaults to 0.0.
dropout_position (str, optional): location to apply dropout wrt lora.
Defaults to 'post' (choices: 'pre', 'post').
lora_A_init_method (str, optional): lora_a init method. Defaults to 'xavier'.
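
For context, a minimal usage sketch of the updated constructor, assuming a plain torch.nn.Linear as the wrapped module; the layer sizes and argument values below are illustrative and not taken from the commit:

import torch
from torch import nn

from nemo.collections.llm.peft.lora import LinearAdapter

# Wrap an existing projection. With this commit, every keyword argument
# passed here is forwarded to LinearAdapter._init_adapter instead of being
# replaced by that method's own defaults.
base = nn.Linear(1024, 1024)  # illustrative shape
adapted = LinearAdapter(
    base,
    dim=16,                       # LoRA bottleneck dimension (illustrative)
    alpha=64,                     # LoRA scaling factor (illustrative)
    dropout=0.0,                  # new default; previously 0.1
    dropout_position='post',
    lora_A_init_method='xavier',
    lora_dtype=None,              # keep orig_linear's dtype
)

x = torch.randn(2, 1024)
y = adapted(x)                    # base Linear output plus the LoRA update
print(y.shape)                    # torch.Size([2, 1024])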
