2 changes: 1 addition & 1 deletion .ci/scripts/gather_test_models.py
@@ -90,7 +90,7 @@ def model_should_run_on_event(model: str, event: str) -> bool:
We put higher-priority and fast models on pull request and the rest on push.
"""
if event == "pull_request":
return model in ["mv3", "vit"]
return model in ["mv3", "vit", "qwen2_5"] # TODO: remove, just to test the ci
elif event == "push":
# These are super slow. Only run it periodically
return model not in ["dl3", "edsr", "emformer_predict"]
9 changes: 9 additions & 0 deletions .ci/scripts/test_model.sh
@@ -91,6 +91,15 @@ test_model() {
# Install requirements for llama vision.
bash examples/models/llama3_2_vision/install_requirements.sh
fi
if [[ "${MODEL_NAME}" == "qwen2_5" ]]; then
# Install requirements for export_llama
bash examples/models/llama/install_requirements.sh
# Test export_llama script: python3 -m examples.models.llama.export_llama.
# Use Llama random checkpoint with Qwen 2.5 1.5b model configuration.
"${PYTHON_EXECUTABLE}" -m examples.models.llama.export_llama --model "${MODEL_NAME}" -c examples/models/llama/params/demo_rand_params.pth -p examples/models/qwen2_5/1_5b_config.json
run_portable_executor_runner
rm "./${MODEL_NAME}.pte"
fi
# python3 -m examples.portable.scripts.export --model_name="llama2" should work too
"${PYTHON_EXECUTABLE}" -m examples.portable.scripts.export --model_name="${MODEL_NAME}" "${STRICT}"
run_portable_executor_runner
1 change: 1 addition & 0 deletions examples/models/__init__.py
@@ -34,6 +34,7 @@
"resnet50": ("resnet", "ResNet50Model"),
"llava": ("llava", "LlavaModel"),
"efficient_sam": ("efficient_sam", "EfficientSAM"),
"qwen2_5": ("qwen2_5", "Qwen2_5Model"),
}

__all__ = [
13 changes: 10 additions & 3 deletions examples/models/llama/attention.py
@@ -175,9 +175,16 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
self.max_batch_size = args.max_batch_size
self.max_context_len = args.max_context_len
self.dim = args.dim
self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
self.attention_qkv_bias = args.attention_qkv_bias
self.wq = nn.Linear(
self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
)
self.wk = nn.Linear(
self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
)
self.wv = nn.Linear(
self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
)
self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)

self.layer_id = layer_id
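For reference, a minimal standalone sketch (not the PR's module) of what the new flag changes: with attention_qkv_bias=True the q/k/v projections gain .bias parameters, which is where the layers.{}.attention.w{q,k,v}.bias entries produced by convert_weights.py below will land. The dimensions are taken from 1_5b_config.json.

import torch.nn as nn

# Illustrative only: mirrors how attention_qkv_bias toggles bias on wq/wk/wv.
dim, n_heads, n_kv_heads, head_dim = 1536, 12, 2, 128  # values from 1_5b_config.json
attention_qkv_bias = True  # Qwen 2.5 ships biases on the q/k/v projections

wq = nn.Linear(dim, n_heads * head_dim, bias=attention_qkv_bias)
wk = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_qkv_bias)
wv = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_qkv_bias)
wo = nn.Linear(n_heads * head_dim, dim, bias=False)  # output projection stays bias-free

assert wq.bias is not None and wk.bias is not None and wv.bias is not None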
2 changes: 2 additions & 0 deletions examples/models/llama/export_llama_lib.py
@@ -84,13 +84,15 @@
verbosity_setting = None


# All models that leverage the transformer architecture defined in llama_transformer.py.
EXECUTORCH_DEFINED_MODELS = [
"stories110m",
"llama2",
"llama3",
"llama3_1",
"llama3_2",
"static_llama",
"qwen2_5",
Contributor Author: Sorry, I accidentally deleted the original comment about ordering, but I was going to say that I think it's clearer to list all the llama models first.
]
TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]

25 changes: 17 additions & 8 deletions examples/models/llama/model.py
@@ -236,14 +236,23 @@ def __init__(self, **kwargs):
eviction_batch_size=eviction_batch_size,
)

# assign=True: load params/buffers by assignment instead of performing an in-place copy.
# Because we are using device="meta", tensors do not have memory associated with them
# and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
missing, unexpected = self.model_.load_state_dict(
checkpoint,
strict=False,
assign=True,
) # self.model_ = Transformer(gptconf)
missing, unexpected = None, None
try:
# assign=True: load params/buffers by assignment instead of performing an in-place copy.
# Because we are using device="meta", tensors do not have memory associated with them
# and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
missing, unexpected = self.model_.load_state_dict(
checkpoint,
strict=False,
assign=True,
) # self.model_ = Transformer(gptconf)
except RuntimeError as e:
Comment on lines +240 to +249

Contributor: Why is this needed?

Contributor Author: So it doesn't error out when loading examples/models/llama/params/demo_rand_params.pth or any checkpoint that is incompatible with the model architecture. We also have no way to not specify a checkpoint; I looked into removing the default value for that arg, but it's going to take some work since it's relied on internally in a lot of places.
print(
"Could not load checkpoint into mode, defaulting to random uninitialized weights."
)
print(f"Error: {e}")
# Need to provide concrete (empty) values for meta-initialized tensors for quantization.
self.model_.to_empty(device="cpu")

if missing:
missing_weights = [fqn for fqn in missing if fqn.endswith(".weight")]
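A standalone sketch of the load pattern above (nn.Linear stands in for the Transformer; the real code loads a full checkpoint): parameters created under device="meta" carry shapes but no storage, so load_state_dict must use assign=True to attach real tensors, and to_empty() materializes uninitialized CPU tensors when the checkpoint doesn't fit the architecture.

import torch
import torch.nn as nn

# Parameters built on the meta device have shape/dtype only, no storage.
with torch.device("meta"):
    model = nn.Linear(8, 4, bias=False)

checkpoint = {"weight": torch.randn(4, 8)}  # stand-in for a real state dict

try:
    # assign=True attaches the checkpoint tensors instead of copying in-place
    # (an in-place copy into meta tensors would be a no-op).
    model.load_state_dict(checkpoint, strict=False, assign=True)
except RuntimeError:
    # Incompatible checkpoint: fall back to concrete, uninitialized weights so
    # later steps (e.g. quantization) have real tensors to operate on.
    model.to_empty(device="cpu")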
1 change: 1 addition & 0 deletions examples/models/llama/model_args.py
@@ -21,6 +21,7 @@ class ModelArgs:
num_experts: int = 8 # Number of experts
num_activated_experts: int = 2 # Number of experts to activate
attention_type: str = "mha" # Attention type, registered in attention.py
attention_qkv_bias: bool = False
use_kv_cache: bool = False # Use key/value cache
use_sdpa_with_kv_cache_op: bool = (
False # Use custom sdpa op that updates kv cache in-place
11 changes: 7 additions & 4 deletions examples/models/llama/rope.py
@@ -114,6 +114,7 @@ def apply_rotary_emb_to_k(
return xk_out.type_as(xk)


# Wrap apply_rotary_emb in a module to enable it to be module swapped out.
class RotaryEmbedding(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -213,14 +214,20 @@ class Rope(torch.nn.Module):
def __init__(self, params: ModelArgs):
super().__init__()
self.params = params

# Choose the appropriate RoPE implementation
if self.params.use_hf_rope:
self.precompute_freqs_cis = hf_precompute_freqs_cis
self.apply_rotary_emb = hf_apply_rotary_emb
else:
self.precompute_freqs_cis = partial(
precompute_freqs_cis,
use_scaled=self.params.use_scaled_rope,
scale_factor=self.params.rope_scale_factor,
)
self.apply_rotary_emb = RotaryEmbedding()

# Precompute frequencies
freqs_cos, freqs_sin = self.precompute_freqs_cis(
self.params.head_dim,
(
@@ -232,10 +239,6 @@ def __init__(self, params: ModelArgs):
)
self.register_buffer("freqs_cos", freqs_cos, persistent=False)
self.register_buffer("freqs_sin", freqs_sin, persistent=False)
if self.params.use_hf_rope:
self.apply_rotary_emb = hf_apply_rotary_emb
else:
self.apply_rotary_emb = RotaryEmbedding()

def forward(
self,
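A short sketch of the module-swap rationale mentioned in the new comment (the swap helper here is hypothetical, not an ExecuTorch API, and the import path is an assumption): because apply_rotary_emb is an nn.Module (RotaryEmbedding) rather than a bare function, it shows up as a named submodule and can be replaced by an export-time pass, e.g. with a backend-specific rotary implementation.

import torch.nn as nn

from executorch.examples.models.llama.rope import RotaryEmbedding  # path assumed

def swap_rotary(module: nn.Module, replacement_cls) -> None:
    # Recursively replace every RotaryEmbedding child with replacement_cls().
    for name, child in module.named_children():
        if isinstance(child, RotaryEmbedding):
            setattr(module, name, replacement_cls())
        else:
            swap_rotary(child, replacement_cls)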
14 changes: 14 additions & 0 deletions examples/models/qwen2_5/1_5b_config.json
@@ -0,0 +1,14 @@
{
"dim": 1536,
"ffn_dim_multiplier": 1,
"hidden_dim": 8960,
"n_heads": 12,
"n_kv_heads": 2,
"n_layers": 28,
"norm_eps": 1e-06,
"rope_theta": 1000000.0,
"use_scaled_rope": false,
"vocab_size": 151936,
"use_hf_rope": true,
"attention_qkv_bias": true
}
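Roughly how this config is consumed (a sketch of the assumed flow, mirroring examples/models/llama/model.py; the import path is an assumption): the JSON is read into a dict and splatted into ModelArgs, so attention_qkv_bias and use_hf_rope reach the attention and RoPE code paths changed above.

import json

from executorch.examples.models.llama.model_args import ModelArgs  # path assumed

with open("examples/models/qwen2_5/1_5b_config.json") as f:
    params = json.load(f)

args = ModelArgs(**params)
assert args.attention_qkv_bias and args.use_hf_rope
# Assuming head_dim defaults to dim // n_heads: 1536 // 12 = 128, with
# n_kv_heads=2 giving grouped-query attention (6 query heads per KV head).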
14 changes: 14 additions & 0 deletions examples/models/qwen2_5/__init__.py
@@ -0,0 +1,14 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.examples.models.llama.model import Llama2Model


class Qwen2_5Model(Llama2Model):
def __init__(self, **kwargs):
super().__init__(**kwargs)


__all__ = [
"Qwen2_5Model",
]
70 changes: 70 additions & 0 deletions examples/models/qwen2_5/convert_weights.py
@@ -0,0 +1,70 @@
from typing import Dict

import torch

from torchtune.models.convert_weights import get_mapped_key

from torchtune.training import FullModelHFCheckpointer

# Standard _FROM_META weight mapping of Meta weights to TorchTune + additional bias weight mappings.
_QWEN_2_FROM_META = {
"tok_embeddings.weight": "tok_embeddings.weight",
"norm.weight": "norm.scale",
"layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight",
"layers.{}.attention.wk.bias": "layers.{}.attn.k_proj.bias",
"layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight",
"layers.{}.attention.wq.bias": "layers.{}.attn.q_proj.bias",
"layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight",
"layers.{}.attention.wv.bias": "layers.{}.attn.v_proj.bias",
"layers.{}.attention.wo.weight": "layers.{}.attn.output_proj.weight",
"layers.{}.attention_norm.weight": "layers.{}.sa_norm.scale",
"layers.{}.ffn_norm.weight": "layers.{}.mlp_norm.scale",
"layers.{}.feed_forward.w1.weight": "layers.{}.mlp.w1.weight",
"layers.{}.feed_forward.w2.weight": "layers.{}.mlp.w2.weight",
"layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight",
}


def qwen_2_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Convert a state dict from torchtune's format to Meta's format. This function
doesn't handle any sharding or splitting of state dicts. It follows the
state_dict IN -> state_dict OUT pattern.

Args:
state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format.

Returns:
Dict[str, torch.Tensor]: State dict in Meta's format.
"""
converted_state_dict = {}
inverted_mapping_dict = {v: k for k, v in _QWEN_2_FROM_META.items()}

for key, value in state_dict.items():
new_key = get_mapped_key(key, inverted_mapping_dict)
converted_state_dict[new_key] = value

# 0.5b and 1.5b models share the same weights for tok_embeddings and output embeddings, see https://github.com/QwenLM/Qwen2.5/issues/733.
converted_state_dict["output.weight"] = converted_state_dict[
"tok_embeddings.weight"
]

return converted_state_dict


# We don't necessarily need to use the TorchTune checkpointer; we could aggregate the checkpoint files ourselves.
checkpointer = FullModelHFCheckpointer(
checkpoint_dir="/home/jackzhxng/.cache/huggingface/hub/models--Qwen--Qwen2.5-1.5B/snapshots/8faed761d45a263340a0528343f099c05c9a4323/",
checkpoint_files=["model.safetensors"],
output_dir=".",
model_type="QWEN2",
)

print("Loading checkpoint")
sd = checkpointer.load_checkpoint()

# Convert from TorchTune to Meta (PyTorch native).
sd = qwen_2_tune_to_meta(sd["model"])

print("Saving checkpoint")
torch.save(sd, "/home/jackzhxng/models/qwen2_5-1_5b.pth")
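A small usage illustration of the mapping logic above (standalone subset, not part of the PR): get_mapped_key abstracts the layer index behind the "{}" placeholder before the lookup, so a single mapping entry covers every layer; the exact substitution behavior is assumed from how the mapping table is written.

from torchtune.models.convert_weights import get_mapped_key

# Inverted (tune -> Meta) subset of the table above.
tune_to_meta = {
    "layers.{}.attn.q_proj.bias": "layers.{}.attention.wq.bias",
    "norm.scale": "norm.weight",
}

print(get_mapped_key("layers.3.attn.q_proj.bias", tune_to_meta))  # layers.3.attention.wq.bias
print(get_mapped_key("norm.scale", tune_to_meta))                 # norm.weight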