Add lora for mlp and unsloth #15132
New file:

@@ -0,0 +1,61 @@
from typing import Dict

import torch

from safetensors.torch import load_file
from torchtune.models.convert_weights import get_mapped_key

_UNSLOTH_TO_META = {
    "base_model.model.model.layers.{}.mlp.down_proj.lora_A.weight": "layers.{}.feed_forward.w2.lora_a.weight",
    "base_model.model.model.layers.{}.mlp.down_proj.lora_B.weight": "layers.{}.feed_forward.w2.lora_b.weight",
    "base_model.model.model.layers.{}.mlp.gate_proj.lora_A.weight": "layers.{}.feed_forward.w1.lora_a.weight",
    "base_model.model.model.layers.{}.mlp.gate_proj.lora_B.weight": "layers.{}.feed_forward.w1.lora_b.weight",
    "base_model.model.model.layers.{}.mlp.up_proj.lora_A.weight": "layers.{}.feed_forward.w3.lora_a.weight",
    "base_model.model.model.layers.{}.mlp.up_proj.lora_B.weight": "layers.{}.feed_forward.w3.lora_b.weight",
    "base_model.model.model.layers.{}.self_attn.k_proj.lora_A.weight": "layers.{}.attention.wk.lora_a.weight",
    "base_model.model.model.layers.{}.self_attn.k_proj.lora_B.weight": "layers.{}.attention.wk.lora_b.weight",
    "base_model.model.model.layers.{}.self_attn.o_proj.lora_A.weight": "layers.{}.attention.wo.lora_a.weight",
    "base_model.model.model.layers.{}.self_attn.o_proj.lora_B.weight": "layers.{}.attention.wo.lora_b.weight",
    "base_model.model.model.layers.{}.self_attn.q_proj.lora_A.weight": "layers.{}.attention.wq.lora_a.weight",
    "base_model.model.model.layers.{}.self_attn.q_proj.lora_B.weight": "layers.{}.attention.wq.lora_b.weight",
    "base_model.model.model.layers.{}.self_attn.v_proj.lora_A.weight": "layers.{}.attention.wv.lora_a.weight",
    "base_model.model.model.layers.{}.self_attn.v_proj.lora_B.weight": "layers.{}.attention.wv.lora_b.weight",
}


def unsloth_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Convert a state dict from unsloth format to Meta's format. This function
    doesn't handle any sharding or splitting of state dicts. It follows the
    state_dict IN -> state_dict OUT pattern.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict in unsloth format.

    Returns:
        Dict[str, torch.Tensor]: State dict in Meta's format.
    """
    converted_state_dict = {}

    for key, value in state_dict.items():
        try:
            new_key = get_mapped_key(key, _UNSLOTH_TO_META)
        except Exception as e:
            raise ValueError(f"Key {key} not found in mapping") from e

        converted_state_dict[new_key] = value
    return converted_state_dict


def load_and_convert_unsloth_to_meta(checkpoint_path: str) -> Dict[str, torch.Tensor]:
    """
    Load a checkpoint file and convert it to Meta's format.

    Args:
        checkpoint_path (str): Path to the checkpoint file.

    Returns:
        Dict[str, torch.Tensor]: State dict in Meta's format.
    """
    state_dict = load_file(checkpoint_path)
    return unsloth_to_meta(state_dict)

Review comment on unsloth_to_meta: I feel like the file name is okay since this function is specifically named unsloth; it actually follows the pattern for other models.
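A quick usage sketch of the converter above; "adapter_model.safetensors" is an assumed unsloth/PEFT-style adapter export, not a file referenced by this PR:

# Hypothetical usage of the converter defined above.
meta_state_dict = load_and_convert_unsloth_to_meta("adapter_model.safetensors")

# A key such as
#   base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight
# comes out as
#   layers.0.attention.wq.lora_a.weight
for name, tensor in list(meta_state_dict.items())[:4]:
    print(name, tuple(tensor.shape))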
Changes to the feed-forward module:

@@ -1,4 +1,7 @@
import torch.nn.functional as F

from executorch.examples.models.llama.lora import LoRALinear
from executorch.examples.models.llama.model_args import ModelArgs
from torch import nn

@@ -11,3 +14,55 @@ def __init__(self, dim: int, hidden_dim: int):

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class LoRAFeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, args: ModelArgs):
        super().__init__()

        if args.r is None or args.lora_alpha is None:
            raise ValueError(
                "LoRA rank and alpha must be specified for LoRAFeedForward."
            )

        self.w1 = (
            LoRALinear(
                in_dim=dim,
                out_dim=hidden_dim,
                rank=args.r,
                alpha=args.lora_alpha,
                dropout=0.0,
                use_bias=False,
            )
            if "gate_proj" in args.target_modules
            else nn.Linear(dim, hidden_dim, bias=False)
        )

        self.w2 = (
            LoRALinear(
                in_dim=hidden_dim,
                out_dim=dim,
                rank=args.r,
                alpha=args.lora_alpha,
                dropout=0.0,
                use_bias=False,
            )
            if "down_proj" in args.target_modules
            else nn.Linear(hidden_dim, dim, bias=False)
        )

        self.w3 = (
            LoRALinear(
                in_dim=dim,
                out_dim=hidden_dim,
                rank=args.r,
                alpha=args.lora_alpha,
                dropout=0.0,
                use_bias=False,
            )
            if "up_proj" in args.target_modules
            else nn.Linear(dim, hidden_dim, bias=False)
        )

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

Review comments on LoRAFeedForward.__init__:
- Validate that args.r and args.lora_alpha are specified.
- Can we inherit from FeedForward instead and just overwrite the constructor?
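For illustration only, a minimal sketch of what the reviewer's subclassing suggestion could look like, assuming FeedForward's constructor builds the three plain nn.Linear projections w1, w2, and w3 used by forward(); this is a hypothetical variant, not the approach the PR takes:

class LoRAFeedForwardSubclass(FeedForward):
    """Hypothetical variant that reuses FeedForward and only swaps projections."""

    def __init__(self, dim: int, hidden_dim: int, args: ModelArgs):
        # Assumed to build plain nn.Linear w1/w2/w3, as in the existing module.
        super().__init__(dim, hidden_dim)

        if args.r is None or args.lora_alpha is None:
            raise ValueError(
                "LoRA rank and alpha must be specified for LoRAFeedForward."
            )

        def lora(in_dim: int, out_dim: int) -> LoRALinear:
            return LoRALinear(
                in_dim=in_dim,
                out_dim=out_dim,
                rank=args.r,
                alpha=args.lora_alpha,
                dropout=0.0,
                use_bias=False,
            )

        # Replace only the targeted projections; forward() is inherited unchanged.
        if "gate_proj" in args.target_modules:
            self.w1 = lora(dim, hidden_dim)
        if "down_proj" in args.target_modules:
            self.w2 = lora(hidden_dim, dim)
        if "up_proj" in args.target_modules:
            self.w3 = lora(dim, hidden_dim)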
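Finally, a rough end-to-end sketch of how the converter and the new module are meant to line up. The dimensions, LoRA settings, and adapter path are illustrative assumptions, and ModelArgs is assumed to accept the r, lora_alpha, and target_modules fields used above as constructor arguments:

import torch

# Hypothetical configuration, not defaults from this PR.
args = ModelArgs(
    dim=2048,
    r=8,
    lora_alpha=16,
    target_modules=["gate_proj", "down_proj", "up_proj"],
)
ffn = LoRAFeedForward(dim=args.dim, hidden_dim=5632, args=args)

# With all three MLP projections targeted, w1/w2/w3 are LoRALinear modules, so
# the state dict should expose keys like "w1.lora_a.weight", matching the
# "layers.{i}.feed_forward.w1.lora_a.weight" names produced by the converter.
print([k for k in ffn.state_dict() if "lora" in k])

# Converted unsloth adapter weights ("adapter_model.safetensors" is an assumed
# export name) can then be loaded into a model whose MLPs are LoRAFeedForward.
meta_sd = load_and_convert_unsloth_to_meta("adapter_model.safetensors")
print(sorted(meta_sd)[:4])

x = torch.randn(2, 16, args.dim)
print(ffn(x).shape)  # expected: torch.Size([2, 16, 2048])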