Skip to content

Commit 03a0962

Browse files
orionrmalfet
authored and committed
[Llama3] Support Llama3 download from Hugging Face (pytorch#323)
1 parent 99c2f4b commit 03a0962

File tree

4 files changed

+34
-10
lines changed

4 files changed

+34
-10
lines changed

build/convert_hf_checkpoint.py

+21-2
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,26 @@ def convert_hf_checkpoint(
3838
# Load the json file containing weight mapping
3939
model_map_json = model_dir / "pytorch_model.bin.index.json"
4040

41-
assert model_map_json.is_file()
41+
# If there is no weight mapping, check for a consolidated model and
42+
# tokenizer we can move. Llama 2 and Mistral have weight mappings, while
43+
# Llama 3 has a consolidated model and tokenizer.
44+
# Otherwise raise an error.
45+
if not model_map_json.is_file():
46+
consolidated_pth = model_dir / "original" / "consolidated.00.pth"
47+
tokenizer_pth = model_dir / "original" / "tokenizer.model"
48+
if consolidated_pth.is_file() and tokenizer_pth.is_file():
49+
# Confirm we can load it
50+
loaded_result = torch.load(
51+
str(consolidated_pth), map_location="cpu", mmap=True, weights_only=True
52+
)
53+
del loaded_result # No longer needed
54+
print(f"Moving checkpoint to {model_dir / 'model.pth'}.")
55+
os.rename(consolidated_pth, model_dir / "model.pth")
56+
os.rename(tokenizer_pth, model_dir / "tokenizer.model")
57+
print("Done.")
58+
return
59+
else:
60+
raise RuntimeError(f"Could not find {model_map_json} or {consolidated_pth} plus {tokenizer_pth}")
4261

4362
with open(model_map_json) as json_map:
4463
bin_index = json.load(json_map)
@@ -111,7 +130,7 @@ def permute(w, n_heads):
111130
if __name__ == "__main__":
112131
import argparse
113132

114-
parser = argparse.ArgumentParser(description="Convert HuggingFace checkpoint.")
133+
parser = argparse.ArgumentParser(description="Convert Hugging Face checkpoint.")
115134
parser.add_argument(
116135
"--checkpoint-dir",
117136
type=Path,

build/model.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class ModelArgs:
4242
multiple_of: int = 256
4343
ffn_dim_multiplier: Optional[int] = None
4444
use_tiktoken: Optional[bool] = None
45-
45+
4646
def __post_init__(self):
4747
if self.n_local_heads == -1:
4848
self.n_local_heads = self.n_heads
@@ -60,7 +60,7 @@ def __post_init__(self):
6060
if isinstance(self.use_tiktoken, str):
6161
self.use_tiktoken = (self.use_tiktoken == "True")
6262

63-
63+
6464
@classmethod
6565
def from_params(cls, params_path):
6666
replace = [("rope_theta", "rope_base"), ("n_kv_heads", "n_local_heads")]
@@ -85,19 +85,19 @@ def from_table(cls, name: str):
8585

8686
@classmethod
8787
def from_name(cls, name: str):
88-
print(f"name {name}")
88+
print(f"Name {name}")
8989
json_path=f"{config_dir}/{name}.json"
9090
if Path(json_path).is_file():
9191
return ModelArgs.from_params(json_path)
9292

9393
known_model_params = [config.replace(".json", "") for config in os.listdir(config_dir)]
9494

95-
print(f"known configs: {known_model_params}")
96-
# fuzzy search
95+
# Fuzzy search by name (e.g. "7B" and "Mistral-7B")
96+
print(f"Known configs: {known_model_params}")
9797
config = [
9898
config
9999
for config in known_model_params
100-
if config.replace in str(name).upper() or config in str(name)
100+
if config in str(name).upper() or config in str(name)
101101
]
102102

103103
# We may have two or more configs matched (e.g. "7B" and "Mistral-7B"). Find the best config match,

config/data/models.json

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
{
2+
"meta-llama/Meta-Llama-3-8B-Instruct": {
3+
"aliases": ["llama3", "llama3-8b"],
4+
"distribution_channel": "HuggingFaceSnapshot",
5+
"distribution_path": "meta-llama/Meta-Llama-3-8B-Instruct"
6+
},
27
"meta-llama/Llama-2-7b-chat-hf": {
38
"aliases": ["llama2", "llama2-7b"],
49
"distribution_channel": "HuggingFaceSnapshot",
510
"distribution_path": "meta-llama/Llama-2-7b-chat-hf"
611
},
712
"mistralai/Mistral-7B-Instruct-v0.2": {
8-
"aliases": ["mistral-7b-instruct"],
13+
"aliases": ["mistral-7b", "mistral-7b-instruct"],
914
"distribution_channel": "HuggingFaceSnapshot",
1015
"distribution_path": "mistralai/Mistral-7B-Instruct-v0.2"
1116
},

download.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def _download_and_convert_hf_snapshot(
2626
from huggingface_hub import snapshot_download
2727

2828
# Download and store the HF model artifacts.
29-
print(f"Downloading {model} from HuggingFace...")
29+
print(f"Downloading {model} from Hugging Face...")
3030
try:
3131
snapshot_download(
3232
model,

0 commit comments

Comments (0)