From 4a7ba4dd04c83a997fd146c16c86a69b1dfa1329 Mon Sep 17 00:00:00 2001
From: lucylq
Date: Tue, 14 Oct 2025 15:21:14 -0700
Subject: [PATCH] Add LoRA for MLP modules and Unsloth adapter checkpoints

---
 .ci/scripts/test_llama_lora.sh                |  2 +-
 examples/models/llama/attention.py            |  7 ++-
 examples/models/llama/convert_weights.py      | 61 +++++++++++++++++++
 examples/models/llama/feed_forward.py         | 55 +++++++++++++++++
 examples/models/llama/install_requirements.sh |  3 +-
 examples/models/llama/llama_transformer.py    |  8 ++-
 examples/models/llama/model.py                | 40 ++++++++++--
 examples/models/llama/model_args.py           |  3 +-
 8 files changed, 167 insertions(+), 12 deletions(-)
 create mode 100644 examples/models/llama/convert_weights.py

diff --git a/.ci/scripts/test_llama_lora.sh b/.ci/scripts/test_llama_lora.sh
index 63325aa7778..73efe096f8f 100644
--- a/.ci/scripts/test_llama_lora.sh
+++ b/.ci/scripts/test_llama_lora.sh
@@ -55,7 +55,7 @@ cmake_build_llama_runner
 # Constants.
 RUNTIME_ARGS="--tokenizer_path=${DOWNLOADED_PATH}/tokenizer.model --temperature=0 --seq_len=20 --warmup=1"
 PROMPT="What happens if you eat watermelon seeds?"
-EXPECTED_PREFIX="What happens if you eat watermelon seeds? Watermelon seeds are a good source of vitamin C,"
+EXPECTED_PREFIX="What happens if you eat watermelon seeds? Watermelon seeds are a good source of vitamin C and"
 
 # Export LoRA PTE file.
 MODEL_NAME="llama_3_2_1B_lora"
diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py
index 0c0176269b3..7b7691a2304 100644
--- a/examples/models/llama/attention.py
+++ b/examples/models/llama/attention.py
@@ -409,14 +409,17 @@ def __init__(
         )
         self.wo = (
             LoRALinear(
-                in_dim=args.n_kv_heads * args.head_dim,
+                in_dim=args.n_heads * args.head_dim,
                 out_dim=args.dim,
                 rank=args.r,
                 alpha=args.lora_alpha,
                 dropout=0.0,
                 use_bias=args.attention_qkv_bias,
             )
-            if args.target_modules is not None and "output_proj" in args.target_modules
+            if args.target_modules is not None
+            and (
+                "output_proj" in args.target_modules or "o_proj" in args.target_modules
+            )
             else nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
         )
 
diff --git a/examples/models/llama/convert_weights.py b/examples/models/llama/convert_weights.py
new file mode 100644
index 00000000000..ab37f6a7218
--- /dev/null
+++ b/examples/models/llama/convert_weights.py
@@ -0,0 +1,61 @@
+from typing import Dict
+
+import torch
+
+from safetensors.torch import load_file
+from torchtune.models.convert_weights import get_mapped_key
+
+_UNSLOTH_TO_META = {
+    "base_model.model.model.layers.{}.mlp.down_proj.lora_A.weight": "layers.{}.feed_forward.w2.lora_a.weight",
+    "base_model.model.model.layers.{}.mlp.down_proj.lora_B.weight": "layers.{}.feed_forward.w2.lora_b.weight",
+    "base_model.model.model.layers.{}.mlp.gate_proj.lora_A.weight": "layers.{}.feed_forward.w1.lora_a.weight",
+    "base_model.model.model.layers.{}.mlp.gate_proj.lora_B.weight": "layers.{}.feed_forward.w1.lora_b.weight",
+    "base_model.model.model.layers.{}.mlp.up_proj.lora_A.weight": "layers.{}.feed_forward.w3.lora_a.weight",
+    "base_model.model.model.layers.{}.mlp.up_proj.lora_B.weight": "layers.{}.feed_forward.w3.lora_b.weight",
+    "base_model.model.model.layers.{}.self_attn.k_proj.lora_A.weight": "layers.{}.attention.wk.lora_a.weight",
+    "base_model.model.model.layers.{}.self_attn.k_proj.lora_B.weight": "layers.{}.attention.wk.lora_b.weight",
+    "base_model.model.model.layers.{}.self_attn.o_proj.lora_A.weight": "layers.{}.attention.wo.lora_a.weight",
+    "base_model.model.model.layers.{}.self_attn.o_proj.lora_B.weight": "layers.{}.attention.wo.lora_b.weight",
+    "base_model.model.model.layers.{}.self_attn.q_proj.lora_A.weight": "layers.{}.attention.wq.lora_a.weight",
+    "base_model.model.model.layers.{}.self_attn.q_proj.lora_B.weight": "layers.{}.attention.wq.lora_b.weight",
+    "base_model.model.model.layers.{}.self_attn.v_proj.lora_A.weight": "layers.{}.attention.wv.lora_a.weight",
+    "base_model.model.model.layers.{}.self_attn.v_proj.lora_B.weight": "layers.{}.attention.wv.lora_b.weight",
+}
+
+
+def unsloth_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+    """
+    Convert a state dict from unsloth format to Meta's format. This function
+    doesn't handle any sharding or splitting of state dicts. It follows the
+    state_dict IN -> state_dict OUT pattern.
+
+    Args:
+        state_dict (Dict[str, torch.Tensor]): State dict in unsloth format.
+
+    Returns:
+        Dict[str, torch.Tensor]: State dict in Meta's format.
+    """
+    converted_state_dict = {}
+
+    for key, value in state_dict.items():
+        try:
+            new_key = get_mapped_key(key, _UNSLOTH_TO_META)
+        except Exception as e:
+            raise ValueError(f"Key {key} not found in mapping") from e
+
+        converted_state_dict[new_key] = value
+    return converted_state_dict
+
+
+def load_and_convert_unsloth_to_meta(checkpoint_path: str) -> Dict[str, torch.Tensor]:
+    """
+    Load a checkpoint file and convert it to Meta's format.
+
+    Args:
+        checkpoint_path (str): Path to the checkpoint file.
+
+    Returns:
+        Dict[str, torch.Tensor]: State dict in Meta's format.
+    """
+    state_dict = load_file(checkpoint_path)
+    return unsloth_to_meta(state_dict)
diff --git a/examples/models/llama/feed_forward.py b/examples/models/llama/feed_forward.py
index 21a4e27df04..786567273c0 100644
--- a/examples/models/llama/feed_forward.py
+++ b/examples/models/llama/feed_forward.py
@@ -1,4 +1,7 @@
 import torch.nn.functional as F
+
+from executorch.examples.models.llama.lora import LoRALinear
+from executorch.examples.models.llama.model_args import ModelArgs
 from torch import nn
 
 
@@ -11,3 +14,55 @@ def __init__(self, dim: int, hidden_dim: int):
 
     def forward(self, x):
         return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
+class LoRAFeedForward(nn.Module):
+    def __init__(self, dim: int, hidden_dim: int, args: ModelArgs):
+        super().__init__()
+
+        if args.r is None or args.lora_alpha is None:
+            raise ValueError(
+                "LoRA rank and alpha must be specified for LoRAFeedForward."
+            )
+
+        self.w1 = (
+            LoRALinear(
+                in_dim=dim,
+                out_dim=hidden_dim,
+                rank=args.r,
+                alpha=args.lora_alpha,
+                dropout=0.0,
+                use_bias=False,
+            )
+            if "gate_proj" in args.target_modules
+            else nn.Linear(dim, hidden_dim, bias=False)
+        )
+
+        self.w2 = (
+            LoRALinear(
+                in_dim=hidden_dim,
+                out_dim=dim,
+                rank=args.r,
+                alpha=args.lora_alpha,
+                dropout=0.0,
+                use_bias=False,
+            )
+            if "down_proj" in args.target_modules
+            else nn.Linear(hidden_dim, dim, bias=False)
+        )
+
+        self.w3 = (
+            LoRALinear(
+                in_dim=dim,
+                out_dim=hidden_dim,
+                rank=args.r,
+                alpha=args.lora_alpha,
+                dropout=0.0,
+                use_bias=False,
+            )
+            if "up_proj" in args.target_modules
+            else nn.Linear(dim, hidden_dim, bias=False)
+        )
+
+    def forward(self, x):
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
diff --git a/examples/models/llama/install_requirements.sh b/examples/models/llama/install_requirements.sh
index 580a152a322..8a2fa25d244 100755
--- a/examples/models/llama/install_requirements.sh
+++ b/examples/models/llama/install_requirements.sh
@@ -10,7 +10,8 @@
 # Install tokenizers for hf .json tokenizer.
 # Install snakeviz for cProfile flamegraph
 # Install lm-eval for Model Evaluation with lm-evalution-harness.
-pip install hydra-core huggingface_hub tiktoken torchtune sentencepiece tokenizers snakeviz lm_eval==0.4.5 blobfile
+# Install safetensors to load safetensors checkpoints (currently adapter only).
+pip install hydra-core huggingface_hub tiktoken torchtune sentencepiece tokenizers snakeviz lm_eval==0.4.5 blobfile safetensors
 
 # Call the install helper for further setup
 python examples/models/llama/install_requirement_helper.py
diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index 6587f7e1a10..6c1a5c05d66 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -18,7 +18,7 @@
     AttentionSkip,
     ForwardOptions,
 )
-from executorch.examples.models.llama.feed_forward import FeedForward
+from executorch.examples.models.llama.feed_forward import FeedForward, LoRAFeedForward
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.norm import RMSNorm
 from executorch.examples.models.llama.rope import Rope
@@ -93,6 +93,12 @@ def __init__(self, args: ModelArgs, attention: Attention):
         ), "`hidden_dim` must be set in ModelArgs to construct a TransformerBlock."
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
+        elif args.target_modules is not None and (
+            "down_proj" in args.target_modules
+            or "up_proj" in args.target_modules
+            or "gate_proj" in args.target_modules
+        ):
+            self.feed_forward = LoRAFeedForward(args.dim, args.hidden_dim, args)
         else:
             self.feed_forward = FeedForward(dim=args.dim, hidden_dim=args.hidden_dim)
 
diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py
index ac2905ea4c4..62f8c502cbd 100644
--- a/examples/models/llama/model.py
+++ b/examples/models/llama/model.py
@@ -15,6 +15,7 @@
     get_checkpoint_dtype,
     get_default_model_resource_dir,
 )
+
 from executorch.examples.models.llama.llama_transformer import construct_transformer
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.rope import Rope
@@ -140,14 +141,41 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
         adapter_checkpoint = {}
         adapter_config = {}
         if adapter_checkpoint_path:
-            adapter_checkpoint = torch.load(
-                adapter_checkpoint_path, map_location=device, mmap=True
-            )
-            from torchtune.models import convert_weights
+            if adapter_checkpoint_path.endswith(".pt"):
+                adapter_checkpoint = torch.load(
+                    adapter_checkpoint_path, map_location=device, mmap=True
+                )
+                from torchtune.models import convert_weights
+
+                adapter_checkpoint = convert_weights.tune_to_meta(adapter_checkpoint)
+            elif adapter_checkpoint_path.endswith(".safetensors"):
+                from executorch.examples.models.llama.convert_weights import (
+                    load_and_convert_unsloth_to_meta,
+                )
+
+                adapter_checkpoint = load_and_convert_unsloth_to_meta(
+                    adapter_checkpoint_path
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported adapter checkpoint format: {adapter_checkpoint_path}"
+                )
 
-            adapter_checkpoint = convert_weights.tune_to_meta(adapter_checkpoint)
             with open(adapter_config_path, "r") as f:
-                adapter_config = json.loads(f.read())
+                adapter_config_full = json.loads(f.read())
+            if (
+                "r" not in adapter_config_full
+                or "lora_alpha" not in adapter_config_full
+                or "target_modules" not in adapter_config_full
+            ):
+                raise ValueError(
+                    "Adapter config must contain r, lora_alpha, and target_modules."
+                )
+            adapter_config = {
+                "r": adapter_config_full["r"],
+                "lora_alpha": adapter_config_full["lora_alpha"],
+                "target_modules": adapter_config_full["target_modules"],
+            }
 
         checkpoint.update(adapter_checkpoint)
         output_prune_map = None
diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py
index 3f9d3d8f2af..20663c81e7d 100644
--- a/examples/models/llama/model_args.py
+++ b/examples/models/llama/model_args.py
@@ -106,7 +106,8 @@ class ModelArgs:
     # These arguments come directly from a torchtune adapter_config.json file.
     r: Optional[int] = None  # Rank.
     lora_alpha: Optional[int] = None  # Alpha.
-    # Eg. q_proj, k_proj, v_proj, output_proj
+    # Modules that we can apply LoRA adapters to.
+    # Eg. q_proj, k_proj, v_proj, output_proj/o_proj, down_proj, gate_proj, up_proj
     target_modules: Optional[list] = None
    peft_type: Optional[str] = None  # PEFT type.
     base_model_name_or_path: Optional[str] = None  # Base model name or path.
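
A minimal usage sketch of the new Unsloth conversion path, assuming a locally saved Unsloth LoRA adapter. The directory and file name below are hypothetical (Unsloth/PEFT adapters are commonly written as adapter_model.safetensors next to adapter_config.json); only load_and_convert_unsloth_to_meta and the Meta-style key names come from the patch above.

    # Hypothetical adapter location; adjust to wherever the Unsloth adapter was saved.
    from executorch.examples.models.llama.convert_weights import (
        load_and_convert_unsloth_to_meta,
    )

    adapter_path = "/tmp/unsloth_lora/adapter_model.safetensors"
    converted = load_and_convert_unsloth_to_meta(adapter_path)

    # Keys now use Meta-style names such as "layers.0.attention.wq.lora_a.weight" or
    # "layers.0.feed_forward.w1.lora_b.weight", which is the format model.py merges
    # into the base checkpoint via checkpoint.update(adapter_checkpoint).
    assert all(".lora_a." in k or ".lora_b." in k for k in converted)
    print(f"Converted {len(converted)} LoRA adapter tensors")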