2 changes: 1 addition & 1 deletion .ci/scripts/test_llama_lora.sh
@@ -55,7 +55,7 @@ cmake_build_llama_runner
# Constants.
RUNTIME_ARGS="--tokenizer_path=${DOWNLOADED_PATH}/tokenizer.model --temperature=0 --seq_len=20 --warmup=1"
PROMPT="What happens if you eat watermelon seeds?"
-EXPECTED_PREFIX="What happens if you eat watermelon seeds? Watermelon seeds are a good source of vitamin C,"
+EXPECTED_PREFIX="What happens if you eat watermelon seeds? Watermelon seeds are a good source of vitamin C and"

# Export LoRA PTE file.
MODEL_NAME="llama_3_2_1B_lora"
7 changes: 5 additions & 2 deletions examples/models/llama/attention.py
@@ -409,14 +409,17 @@ def __init__(
        )
        self.wo = (
            LoRALinear(
-               in_dim=args.n_kv_heads * args.head_dim,
+               in_dim=args.n_heads * args.head_dim,
                out_dim=args.dim,
                rank=args.r,
                alpha=args.lora_alpha,
                dropout=0.0,
                use_bias=args.attention_qkv_bias,
            )
-           if args.target_modules is not None and "output_proj" in args.target_modules
+           if args.target_modules is not None
+           and (
+               "output_proj" in args.target_modules or "o_proj" in args.target_modules
+           )
            else nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
        )
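The in_dim fix matters because the output projection consumes the concatenation of all query heads, not the (possibly smaller) set of KV heads used under grouped-query attention, which is why the else branch also uses self.n_heads * self.head_dim. A minimal shape sketch (not part of this change), with made-up Llama-style sizes:

import torch
from torch import nn

# Hypothetical sizes; only the relationship between them matters here.
n_heads, n_kv_heads, head_dim, dim = 32, 8, 64, 2048
attn_out = torch.randn(1, 16, n_heads * head_dim)  # concatenated per-head outputs

wo = nn.Linear(n_heads * head_dim, dim, bias=False)  # matches the corrected in_dim
print(wo(attn_out).shape)  # torch.Size([1, 16, 2048])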

61 changes: 61 additions & 0 deletions examples/models/llama/convert_weights.py
@@ -0,0 +1,61 @@
from typing import Dict

import torch

from safetensors.torch import load_file
from torchtune.models.convert_weights import get_mapped_key

_UNSLOTH_TO_META = {
    "base_model.model.model.layers.{}.mlp.down_proj.lora_A.weight": "layers.{}.feed_forward.w2.lora_a.weight",
    "base_model.model.model.layers.{}.mlp.down_proj.lora_B.weight": "layers.{}.feed_forward.w2.lora_b.weight",
    "base_model.model.model.layers.{}.mlp.gate_proj.lora_A.weight": "layers.{}.feed_forward.w1.lora_a.weight",
    "base_model.model.model.layers.{}.mlp.gate_proj.lora_B.weight": "layers.{}.feed_forward.w1.lora_b.weight",
    "base_model.model.model.layers.{}.mlp.up_proj.lora_A.weight": "layers.{}.feed_forward.w3.lora_a.weight",
    "base_model.model.model.layers.{}.mlp.up_proj.lora_B.weight": "layers.{}.feed_forward.w3.lora_b.weight",
    "base_model.model.model.layers.{}.self_attn.k_proj.lora_A.weight": "layers.{}.attention.wk.lora_a.weight",
    "base_model.model.model.layers.{}.self_attn.k_proj.lora_B.weight": "layers.{}.attention.wk.lora_b.weight",
    "base_model.model.model.layers.{}.self_attn.o_proj.lora_A.weight": "layers.{}.attention.wo.lora_a.weight",
    "base_model.model.model.layers.{}.self_attn.o_proj.lora_B.weight": "layers.{}.attention.wo.lora_b.weight",
    "base_model.model.model.layers.{}.self_attn.q_proj.lora_A.weight": "layers.{}.attention.wq.lora_a.weight",
    "base_model.model.model.layers.{}.self_attn.q_proj.lora_B.weight": "layers.{}.attention.wq.lora_b.weight",
    "base_model.model.model.layers.{}.self_attn.v_proj.lora_A.weight": "layers.{}.attention.wv.lora_a.weight",
    "base_model.model.model.layers.{}.self_attn.v_proj.lora_B.weight": "layers.{}.attention.wv.lora_b.weight",
}


def unsloth_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
Contributor: I feel like the file name is okay since this function is specifically named unsloth; it follows the pattern for other models.

    """
    Convert a state dict from unsloth format to Meta's format. This function
    doesn't handle any sharding or splitting of state dicts. It follows the
    state_dict IN -> state_dict OUT pattern.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict in unsloth format.

    Returns:
        Dict[str, torch.Tensor]: State dict in Meta's format.
    """
    converted_state_dict = {}

    for key, value in state_dict.items():
        try:
            new_key = get_mapped_key(key, _UNSLOTH_TO_META)
        except Exception as e:
            raise ValueError(f"Key {key} not found in mapping") from e

        converted_state_dict[new_key] = value
    return converted_state_dict


def load_and_convert_unsloth_to_meta(checkpoint_path: str) -> Dict[str, torch.Tensor]:
    """
    Load a checkpoint file and convert it to Meta's format.

    Args:
        checkpoint_path (str): Path to the checkpoint file.

    Returns:
        Dict[str, torch.Tensor]: State dict in Meta's format.
    """
    state_dict = load_file(checkpoint_path)
    return unsloth_to_meta(state_dict)
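The conversion is a pure key-renaming pass: get_mapped_key fills the {} placeholder in the _UNSLOTH_TO_META templates with the layer index, so each unsloth/PEFT key maps one-to-one onto the Meta-style name. A hedged usage sketch (the checkpoint path below is a placeholder, not from the PR):

from executorch.examples.models.llama.convert_weights import (
    load_and_convert_unsloth_to_meta,
)

# "adapter_model.safetensors" stands in for an unsloth LoRA adapter checkpoint.
adapter_sd = load_and_convert_unsloth_to_meta("adapter_model.safetensors")

# e.g. "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"
#  ->  "layers.0.attention.wq.lora_a.weight"
for key in sorted(adapter_sd)[:4]:
    print(key, tuple(adapter_sd[key].shape))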
55 changes: 55 additions & 0 deletions examples/models/llama/feed_forward.py
@@ -1,4 +1,7 @@
import torch.nn.functional as F

from executorch.examples.models.llama.lora import LoRALinear
from executorch.examples.models.llama.model_args import ModelArgs
from torch import nn


@@ -11,3 +14,55 @@ def __init__(self, dim: int, hidden_dim: int):

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class LoRAFeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, args: ModelArgs):
        super().__init__()

Contributor: Validate that args.r and args.lora_alpha are specified.

Contributor: Can we inherit from FeedForward instead and just override the constructor?

Author: We have ConditionalFeedForward and MOEFeedForward as separate nn.Modules (inside llama_transformer.py), so it seemed fitting to have this separate, but let me know what you think. @jackzhxng
        if args.r is None or args.lora_alpha is None:
            raise ValueError(
                "LoRA rank and alpha must be specified for LoRAFeedForward."
            )

        self.w1 = (
            LoRALinear(
                in_dim=dim,
                out_dim=hidden_dim,
                rank=args.r,
                alpha=args.lora_alpha,
                dropout=0.0,
                use_bias=False,
            )
            if "gate_proj" in args.target_modules
            else nn.Linear(dim, hidden_dim, bias=False)
        )

        self.w2 = (
            LoRALinear(
                in_dim=hidden_dim,
                out_dim=dim,
                rank=args.r,
                alpha=args.lora_alpha,
                dropout=0.0,
                use_bias=False,
            )
            if "down_proj" in args.target_modules
            else nn.Linear(hidden_dim, dim, bias=False)
        )

        self.w3 = (
            LoRALinear(
                in_dim=dim,
                out_dim=hidden_dim,
                rank=args.r,
                alpha=args.lora_alpha,
                dropout=0.0,
                use_bias=False,
            )
            if "up_proj" in args.target_modules
            else nn.Linear(dim, hidden_dim, bias=False)
        )

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
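LoRALinear itself comes from executorch.examples.models.llama.lora and is not shown in this diff. For orientation, here is a self-contained sketch of the standard LoRA linear computation, y = W0 x + (alpha / r) * B A x; it illustrates the general technique and is not the actual implementation used above.

import torch
from torch import nn


class LoRALinearSketch(nn.Module):
    """Frozen base projection plus a scaled low-rank update (illustrative only)."""

    def __init__(self, in_dim: int, out_dim: int, rank: int, alpha: float):
        super().__init__()
        self.base = nn.Linear(in_dim, out_dim, bias=False)
        self.base.weight.requires_grad_(False)              # base weights stay frozen
        self.lora_a = nn.Linear(in_dim, rank, bias=False)    # A: in_dim -> rank
        self.lora_b = nn.Linear(rank, out_dim, bias=False)   # B: rank -> out_dim
        self.scaling = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scaling * self.lora_b(self.lora_a(x))


# Quick shape check with made-up sizes.
y = LoRALinearSketch(in_dim=2048, out_dim=8192, rank=16, alpha=32)(torch.randn(2, 2048))
print(y.shape)  # torch.Size([2, 8192])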
3 changes: 2 additions & 1 deletion examples/models/llama/install_requirements.sh
@@ -10,7 +10,8 @@
# Install tokenizers for hf .json tokenizer.
# Install snakeviz for cProfile flamegraph
# Install lm-eval for Model Evaluation with lm-evaluation-harness.
-pip install hydra-core huggingface_hub tiktoken torchtune sentencepiece tokenizers snakeviz lm_eval==0.4.5 blobfile
+# Install safetensors to load safetensors checkpoints (currently adapter only).
+pip install hydra-core huggingface_hub tiktoken torchtune sentencepiece tokenizers snakeviz lm_eval==0.4.5 blobfile safetensors

# Call the install helper for further setup
python examples/models/llama/install_requirement_helper.py
8 changes: 7 additions & 1 deletion examples/models/llama/llama_transformer.py
@@ -18,7 +18,7 @@
    AttentionSkip,
    ForwardOptions,
)
-from executorch.examples.models.llama.feed_forward import FeedForward
+from executorch.examples.models.llama.feed_forward import FeedForward, LoRAFeedForward
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.norm import RMSNorm
from executorch.examples.models.llama.rope import Rope
@@ -93,6 +93,12 @@ def __init__(self, args: ModelArgs, attention: Attention):
        ), "`hidden_dim` must be set in ModelArgs to construct a TransformerBlock."
        if args.moe:
            self.block_sparse_moe = MOEFeedForward(args)
+       elif args.target_modules is not None and (
+           "down_proj" in args.target_modules
+           or "up_proj" in args.target_modules
+           or "gate_proj" in args.target_modules
+       ):
+           self.feed_forward = LoRAFeedForward(args.dim, args.hidden_dim, args)
        else:
            self.feed_forward = FeedForward(dim=args.dim, hidden_dim=args.hidden_dim)
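In plain terms: MoE takes precedence, then any MLP projection named in target_modules routes the block to LoRAFeedForward, and everything else falls back to the plain FeedForward. A small standalone sketch of that selection (hypothetical helper, not part of the change):

# Hypothetical helper mirroring the branch above; not part of the actual diff.
def pick_feed_forward(moe: bool, target_modules):
    if moe:
        return "MOEFeedForward"
    if target_modules is not None and any(
        m in target_modules for m in ("down_proj", "up_proj", "gate_proj")
    ):
        return "LoRAFeedForward"
    return "FeedForward"


print(pick_feed_forward(False, ["q_proj", "v_proj"]))      # FeedForward
print(pick_feed_forward(False, ["gate_proj", "up_proj"]))  # LoRAFeedForward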

40 changes: 34 additions & 6 deletions examples/models/llama/model.py
@@ -15,6 +15,7 @@
    get_checkpoint_dtype,
    get_default_model_resource_dir,
)

from executorch.examples.models.llama.llama_transformer import construct_transformer
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import Rope
@@ -140,14 +141,41 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
        adapter_checkpoint = {}
        adapter_config = {}
        if adapter_checkpoint_path:
-           adapter_checkpoint = torch.load(
-               adapter_checkpoint_path, map_location=device, mmap=True
-           )
-           from torchtune.models import convert_weights
+           if adapter_checkpoint_path.endswith(".pt"):
+               adapter_checkpoint = torch.load(
+                   adapter_checkpoint_path, map_location=device, mmap=True
+               )
+               from torchtune.models import convert_weights
+
+               adapter_checkpoint = convert_weights.tune_to_meta(adapter_checkpoint)
+           elif adapter_checkpoint_path.endswith(".safetensors"):
+               from executorch.examples.models.llama.convert_weights import (
+                   load_and_convert_unsloth_to_meta,
+               )
+
+               adapter_checkpoint = load_and_convert_unsloth_to_meta(
+                   adapter_checkpoint_path
+               )
+           else:
+               raise ValueError(
+                   f"Unsupported adapter checkpoint format: {adapter_checkpoint_path}"
+               )

-           adapter_checkpoint = convert_weights.tune_to_meta(adapter_checkpoint)
            with open(adapter_config_path, "r") as f:
-               adapter_config = json.loads(f.read())
+               adapter_config_full = json.loads(f.read())
+               if (
+                   "r" not in adapter_config_full
+                   or "lora_alpha" not in adapter_config_full
+                   or "target_modules" not in adapter_config_full
+               ):
+                   raise ValueError(
+                       "Adapter config must contain r, lora_alpha, and target_modules."
+                   )
+               adapter_config = {
+                   "r": adapter_config_full["r"],
+                   "lora_alpha": adapter_config_full["lora_alpha"],
+                   "target_modules": adapter_config_full["target_modules"],
+               }
            checkpoint.update(adapter_checkpoint)

        output_prune_map = None
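The config read here is the PEFT-style adapter_config.json that tools such as unsloth write next to the adapter weights; only r, lora_alpha, and target_modules are consumed. A representative sketch with made-up values, showing what the loader extracts:

import json

# Made-up adapter_config.json contents; field values are illustrative only.
adapter_config_full = json.loads("""
{
    "peft_type": "LORA",
    "base_model_name_or_path": "meta-llama/Llama-3.2-1B-Instruct",
    "r": 16,
    "lora_alpha": 32,
    "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"],
    "lora_dropout": 0.0
}
""")

# Mirrors the extraction above: every other field in the file is ignored.
adapter_config = {
    "r": adapter_config_full["r"],
    "lora_alpha": adapter_config_full["lora_alpha"],
    "target_modules": adapter_config_full["target_modules"],
}
print(adapter_config)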
3 changes: 2 additions & 1 deletion examples/models/llama/model_args.py
@@ -106,7 +106,8 @@ class ModelArgs:
    # These arguments come directly from a torchtune adapter_config.json file.
    r: Optional[int] = None  # Rank.
    lora_alpha: Optional[int] = None  # Alpha.
-   # Eg. q_proj, k_proj, v_proj, output_proj
+   # Modules that we can apply lora adapters to.
+   # Eg. q_proj, k_proj, v_proj, output_proj/o_proj, down_proj, gate_proj, up_proj
    target_modules: Optional[list] = None
    peft_type: Optional[str] = None  # PEFT type.
    base_model_name_or_path: Optional[str] = None  # Base model name or path.