
Commit cd88afe

Add phi4 mini
1 parent 3ffd24e commit cd88afe

File tree

8 files changed: +141 -2 lines changed

.ci/scripts/gather_test_models.py

Lines changed: 1 addition & 1 deletion

@@ -90,7 +90,7 @@ def model_should_run_on_event(model: str, event: str) -> bool:
     We put higher priority and fast models to pull request and rest to push.
     """
     if event == "pull_request":
-        return model in ["mv3", "vit"]
+        return model in ["mv3", "vit", "phi4_mini"]  # TODO: remove
     elif event == "push":
         # These are super slow. Only run it periodically
         return model not in ["dl3", "edsr", "emformer_predict"]
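
For illustration, a minimal sketch of what the gating above does: phi4_mini is temporarily pulled into the fast pull_request set, while known-slow models still only run on the periodic schedule. The import path below is an assumption for the sketch (it presumes .ci/scripts is on the Python path), not a documented entry point.

    # Sketch only; assumes gather_test_models.py is importable as a module.
    from gather_test_models import model_should_run_on_event

    assert model_should_run_on_event("phi4_mini", "pull_request")      # added above (TODO: remove)
    assert not model_should_run_on_event("emformer_predict", "push")   # slow models skipped on push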

.ci/scripts/test_model.sh

Lines changed: 8 additions & 0 deletions

@@ -100,6 +100,14 @@ test_model() {
     rm "./${MODEL_NAME}.pte"
     return # Skip running with portable executor runner since portable doesn't support Qwen's biased linears.
   fi
+  if [[ "${MODEL_NAME}" == "phi4_mini" ]]; then
+    # Install requirements for export_llama
+    bash examples/models/llama/install_requirements.sh
+    # Test export_llama script: python3 -m examples.models.llama.export_llama.
+    "${PYTHON_EXECUTABLE}" -m examples.models.llama.export_llama --model "${MODEL_NAME}" -c examples/models/llama/params/demo_rand_params.pth -p examples/models/phi_4_mini/config.json
+    run_portable_executor_runner
+    rm "./${MODEL_NAME}.pte"
+  fi

   # Export a basic .pte and run the model.
   "${PYTHON_EXECUTABLE}" -m examples.portable.scripts.export --model_name="${MODEL_NAME}" "${STRICT}"

examples/models/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -35,6 +35,7 @@
     "llava": ("llava", "LlavaModel"),
     "efficient_sam": ("efficient_sam", "EfficientSAM"),
     "qwen2_5": ("qwen2_5", "Qwen2_5Model"),
+    "phi4_mini": ("phi4_mini", "Phi4MiniModel"),
 }

 __all__ = [
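
For context, a hedged sketch of how a registry entry like ("phi4_mini", "Phi4MiniModel") is typically consumed: the tuple names a module and a class to import lazily. The loader below is illustrative only, and the package path is an assumption; the repo's actual model factory may differ.

    import importlib

    MODEL_NAME_TO_MODEL = {
        "qwen2_5": ("qwen2_5", "Qwen2_5Model"),
        "phi4_mini": ("phi4_mini", "Phi4MiniModel"),
    }

    def load_model_class(model_name, package="executorch.examples.models"):
        # Resolve (module_name, class_name) into the actual class at call time.
        module_name, class_name = MODEL_NAME_TO_MODEL[model_name]
        module = importlib.import_module(f"{package}.{module_name}")
        return getattr(module, class_name)

    # e.g. Phi4MiniModel = load_model_class("phi4_mini")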

examples/models/llama/export_llama_lib.py

Lines changed: 1 addition & 0 deletions

@@ -93,6 +93,7 @@
     "llama3_2",
     "static_llama",
     "qwen2_5",
+    "phi4_mini",
 ]
 TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]

examples/models/llama/rope.py

Lines changed: 11 additions & 1 deletion

@@ -134,11 +134,21 @@ def forward(


 # Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L77
-def hf_precompute_freqs_cis(dim: int, end: int, theta: float):
+# and https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L242.
+# Currently only supports non-long rope.
+def hf_precompute_freqs_cis(
+    dim: int, end: int, theta: float, partial_rotary_factor: float = 1.0
+):
+    # Partial rotary embeddings.
+    dim = int(dim * partial_rotary_factor)
+
+    # Short factor scaling.
     freqs = 1.0 / (
         theta
         ** (torch.arange(0, dim, 2, device="cpu", dtype=torch.int64).float() / dim)
     )
+    # TODO: support long factor scaling.
     # pyre-ignore Undefined attribute [16]: `float` has no attribute `device`.
     t = torch.arange(end, device=freqs.device, dtype=torch.int64).type_as(
         freqs  # pyre-ignore
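
For context, a minimal sketch of what the new partial_rotary_factor parameter means: only the first dim * partial_rotary_factor channels of each head receive rotary embeddings, and the inverse frequencies are computed over that reduced width. The function name, the example values, and the cos/sin return layout below are assumptions based on the Hugging Face references in the diff, not the repo's API.

    import torch

    def precompute_cos_sin_sketch(dim, end, theta, partial_rotary_factor=1.0):
        # Only the first `dim * partial_rotary_factor` channels are rotated.
        rot_dim = int(dim * partial_rotary_factor)
        inv_freq = 1.0 / (
            theta ** (torch.arange(0, rot_dim, 2, dtype=torch.int64).float() / rot_dim)
        )
        t = torch.arange(end, dtype=torch.int64).type_as(inv_freq)
        freqs = torch.outer(t, inv_freq)         # (end, rot_dim // 2)
        emb = torch.cat((freqs, freqs), dim=-1)  # (end, rot_dim), HF half-split layout
        return emb.cos(), emb.sin()

    # e.g. head_dim=128 with partial_rotary_factor=0.75 rotates the first 96 channels (assumed values)
    cos, sin = precompute_cos_sin_sketch(128, 4096, 10000.0, partial_rotary_factor=0.75)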

examples/models/phi_4_mini/config.json

Lines changed: 14 additions & 0 deletions

@@ -0,0 +1,14 @@
+{
+  "dim": 3072,
+  "ffn_dim_multiplier": 1,
+  "hidden_dim": 8192,
+  "n_heads": 24,
+  "n_kv_heads": 8,
+  "n_layers": 32,
+  "norm_eps": 1e-05,
+  "rope_theta": 10000.0,
+  "use_scaled_rope": false,
+  "vocab_size": 200064,
+  "use_hf_rope": true,
+  "attention_qkv_bias": false
+}
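
As a quick sanity check on the params above (plain arithmetic, not a repo API): 24 query heads over a 3072-wide model give a head dimension of 128, and 8 KV heads imply grouped-query attention with 3 query heads per KV head. The path below is the one referenced by the CI change earlier in this commit.

    import json

    with open("examples/models/phi_4_mini/config.json") as f:
        params = json.load(f)

    head_dim = params["dim"] // params["n_heads"]           # 3072 // 24 = 128
    gqa_groups = params["n_heads"] // params["n_kv_heads"]  # 24 // 8 = 3
    print(head_dim, gqa_groups, params["vocab_size"])       # 128 3 200064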

Lines changed: 85 additions & 0 deletions

@@ -0,0 +1,85 @@
+import argparse
+from typing import Dict
+
+import torch
+
+from torchtune.models.convert_weights import get_mapped_key
+
+from torchtune.training import FullModelHFCheckpointer
+
+
+# Standard _FROM_META weight mapping of Meta weights to TorchTune.
+_PHI_4_FROM_META = {
+    "tok_embeddings.weight": "tok_embeddings.weight",
+    "norm.weight": "norm.scale",
+    "layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight",
+    "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight",
+    "layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight",
+    "layers.{}.attention.wo.weight": "layers.{}.attn.output_proj.weight",
+    "layers.{}.attention_norm.weight": "layers.{}.sa_norm.scale",
+    "layers.{}.ffn_norm.weight": "layers.{}.mlp_norm.scale",
+    "layers.{}.feed_forward.w1.weight": "layers.{}.mlp.w1.weight",
+    "layers.{}.feed_forward.w2.weight": "layers.{}.mlp.w2.weight",
+    "layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight",
+}
+
+
+def phi_4_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+    """
+    Convert a state dict from torchtune's format to Meta's format. This function
+    doesn't handle any sharding or splitting of state dicts. It follows the
+    state_dict IN -> state_dict OUT pattern.
+
+    Args:
+        state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format.
+
+    Returns:
+        Dict[str, torch.Tensor]: State dict in Meta's format.
+    """
+    converted_state_dict = {}
+    inverted_mapping_dict = {v: k for k, v in _PHI_4_FROM_META.items()}
+
+    for key, value in state_dict.items():
+        new_key = get_mapped_key(key, inverted_mapping_dict)
+        converted_state_dict[new_key] = value
+
+    # Input and output embeddings are tied.
+    converted_state_dict["output.weight"] = converted_state_dict[
+        "tok_embeddings.weight"
+    ]
+
+    return converted_state_dict
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Convert Phi-4-mini weights to Meta format."
+    )
+    parser.add_argument(
+        "input_dir",
+        type=str,
+        help="Path to directory containing checkpoint files",
+    )
+    parser.add_argument("output", type=str, help="Path to the output checkpoint")
+
+    args = parser.parse_args()
+
+    checkpointer = FullModelHFCheckpointer(
+        checkpoint_dir=args.input_dir,
+        checkpoint_files=["model-00001-of-00003.safetensors", "model-00002-of-00003.safetensors", "model-00003-of-00003.safetensors"],
+        output_dir=".",
+        model_type="PHI4_MINI",
+    )
+
+    print("Loading checkpoint...")
+    sd = checkpointer.load_checkpoint()
+
+    print("Converting checkpoint...")
+    sd = phi_4_tune_to_meta(sd["model"])
+
+    torch.save(sd, args.output)
+    print(f"Checkpoint saved to {args.output}")
+
+
+if __name__ == "__main__":
+    main()
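
A possible invocation of the conversion script above, followed by a small check that the output is a flat Meta-format state dict with tied embeddings. The script filename and checkpoint paths are placeholders, not taken from the commit.

    # Hypothetical usage (names are placeholders):
    #   python convert_phi4_weights.py ./Phi-4-mini-checkpoint phi4_mini_meta.pth
    import torch

    sd = torch.load("phi4_mini_meta.pth")  # dict written by torch.save(sd, args.output)
    assert torch.equal(sd["output.weight"], sd["tok_embeddings.weight"])  # tied embeddings
    print(f"{len(sd)} tensors, e.g. {sorted(sd)[0]}")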

Lines changed: 20 additions & 0 deletions

@@ -0,0 +1,20 @@
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
+
+set_seed(2024)
+
+prompt = "Tell me a story."
+
+model_checkpoint = "microsoft/Phi-4-multimodal-instruct"
+
+tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained(model_checkpoint,
+                                             trust_remote_code=True,
+                                             torch_dtype="auto",
+                                             device_map="cpu")
+
+inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
+outputs = model.generate(**inputs, do_sample=True, max_new_tokens=120)
+response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+print(response)
