[Model] Allow loading from original Mistral format #8168
```diff
@@ -16,7 +16,6 @@
 from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
 from safetensors.torch import load_file, safe_open, save_file
 from tqdm.auto import tqdm
-from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

 from vllm.config import LoadConfig, ModelConfig
 from vllm.distributed import get_tensor_model_parallel_rank
@@ -252,6 +251,7 @@ def download_weights_from_hf(
 def download_safetensors_index_file_from_hf(
     model_name_or_path: str,
     cache_dir: Optional[str],
+    index_file: str,
     revision: Optional[str] = None,
 ) -> None:
     """Download hf safetensors index file from Hugging Face Hub.
```

Review comment (on the new index_file parameter): super nit: I think it makes more sense to have cache_dir after index_file, similar to how hf_hub_download is called.
Reply: happy to change
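Under the new signature, callers pass the index file name explicitly. A rough illustration of a call that keeps the pre-existing default behaviour (SAFE_WEIGHTS_INDEX_NAME comes from transformers and resolves to "model.safetensors.index.json"; the repo name is just an example):

```python
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

# Passing the standard HF index name preserves the old behaviour; the
# consolidated/Mistral code path can pass its own index file name instead.
download_safetensors_index_file_from_hf(
    "mistralai/Mistral-7B-v0.3",
    cache_dir=None,
    index_file=SAFE_WEIGHTS_INDEX_NAME,
    revision=None,
)
```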
```diff
@@ -269,36 +269,37 @@ def download_safetensors_index_file_from_hf(
         # Download the safetensors index file.
         hf_hub_download(
             repo_id=model_name_or_path,
-            filename=SAFE_WEIGHTS_INDEX_NAME,
+            filename=index_file,
             cache_dir=cache_dir,
             revision=revision,
             local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
         )
     # If file not found on remote or locally, we should not fail since
-    # only some models will have SAFE_WEIGHTS_INDEX_NAME.
+    # only some models will have index_file.
     except huggingface_hub.utils.EntryNotFoundError:
-        logger.info("No %s found in remote.", SAFE_WEIGHTS_INDEX_NAME)
+        logger.info("No %s found in remote.", index_file)
     except huggingface_hub.utils.LocalEntryNotFoundError:
-        logger.info("No %s found in local cache.", SAFE_WEIGHTS_INDEX_NAME)
+        logger.info("No %s found in local cache.", index_file)


 # For models like Mistral-7B-v0.3, there are both sharded
 # safetensors files and a consolidated safetensors file.
 # Passing both of these to the weight loader functionality breaks.
-# So, we use the SAFE_WEIGHTS_INDEX_NAME to
+# So, we use the index_file to
 # look up which safetensors files should be used.
 def filter_duplicate_safetensors_files(hf_weights_files: List[str],
-                                       hf_folder: str) -> List[str]:
+                                       hf_folder: str,
+                                       index_file: str) -> List[str]:
     # model.safetensors.index.json is a mapping from keys in the
     # torch state_dict to safetensors file holding that weight.
-    index_file_name = os.path.join(hf_folder, SAFE_WEIGHTS_INDEX_NAME)
+    index_file_name = os.path.join(hf_folder, index_file)
     if not os.path.isfile(index_file_name):
         return hf_weights_files

     # Iterate through the weight_map (weight_name: safetensors files)
     # to identify weights that we should use.
-    with open(index_file_name) as index_file:
-        weight_map = json.load(index_file)["weight_map"]
+    with open(index_file_name, "r") as f:
+        weight_map = json.load(f)["weight_map"]
     weight_files_in_index = set()
     for weight_name in weight_map:
         weight_files_in_index.add(
```
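The diff is cut off above, so for reference here is a self-contained sketch of the complete filtering idea (function and variable names are illustrative; the tail that the excerpt truncates is reconstructed from context):

```python
import json
import os
from typing import List


def filter_duplicates_sketch(hf_weights_files: List[str], hf_folder: str,
                             index_file: str) -> List[str]:
    """Keep only the safetensors shards that the index file references."""
    index_path = os.path.join(hf_folder, index_file)
    # No index file: nothing to disambiguate, keep the list unchanged.
    if not os.path.isfile(index_path):
        return hf_weights_files

    # The index maps each weight name to the shard that stores it.
    with open(index_path) as f:
        weight_map = json.load(f)["weight_map"]

    # Drop files the index never mentions, e.g. a consolidated.safetensors
    # that duplicates the sharded model-0000x-of-0000y.safetensors files.
    referenced = {os.path.join(hf_folder, shard) for shard in weight_map.values()}
    return [path for path in hf_weights_files if path in referenced]
```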
```diff
@@ -375,6 +375,25 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
         "gate_proj": ("gate_up_proj", 0),
         "up_proj": ("gate_up_proj", 1),
     }
+    # Mistral/Llama models can also be loaded with --load-format consolidated
+    # from consolidated.safetensors checkpoints
+    consolidated_mapping = {
+        "layers": "model.layers",
+        "attention": "self_attn",
+        "wq": "q_proj",
+        "wk": "k_proj",
+        "wv": "v_proj",
+        "wo": "o_proj",
+        "attention_norm": "input_layernorm",
+        "feed_forward": "mlp",
+        "w1": "gate_proj",
+        "w2": "down_proj",
+        "w3": "up_proj",
+        "ffn_norm": "post_attention_layernorm",
+        "tok_embeddings": "model.embed_tokens",
+        "output": "lm_head",
+        "norm": "model.norm"
+    }

     def __init__(
         self,
```

Review comment (on consolidated_mapping): Is this a standard naming scheme used by Llama models as well? I have only seen Mistral models with this style of checkpoints.
Reply: Yeah, the original Llama checkpoints have this naming as well: https://github.com/meta-llama/llama/blob/8fac8befd776bc03242fe7bc2236cdb41b6c609c/llama/model.py#L207 (guess most people use the HF format indeed, though).
```diff
@@ -472,6 +491,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         ]
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
+            name, loaded_weight = self.maybe_remap_consolidated(
+                name, loaded_weight)
+
             if "rotary_emb.inv_freq" in name:
                 continue
             if ("rotary_emb.cos_cached" in name
```
```diff
@@ -549,3 +571,33 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
                 else:
                     raise RuntimeError("Self attention has no KV cache scaling "
                                        "factor attribute!")
+
+    # This function is used to remap the consolidated format as
+    # used by Mistral and Llama <=2
+    def maybe_remap_consolidated(
+            self, name: str,
+            loaded_weight: torch.Tensor) -> Tuple[str, torch.Tensor]:
+
+        def permute(w, n_heads):
+            attn_in = self.config.head_dim * n_heads
+            attn_out = self.config.hidden_size
+
+            return w.view(n_heads, attn_in // n_heads // 2, 2,
+                          attn_out).transpose(1, 2).reshape(attn_in, attn_out)
+
+        mapping = self.consolidated_mapping
+        modules = name.split(".")
+
+        # rotary embeds should be sliced
+        if "wk" in modules:
+            loaded_weight = permute(loaded_weight,
+                                    self.config.num_key_value_heads)
+        elif "wq" in modules:
+            loaded_weight = permute(loaded_weight,
+                                    self.config.num_attention_heads)
+
+        for item in modules:
+            if item in mapping and mapping[item] not in name:
+                name = name.replace(item, mapping[item])
+
+        return name, loaded_weight
```
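As a quick illustration of what this substring remapping does to consolidated-format parameter names (a toy sketch, not the PR's exact code; the wq/wk permutation for the rotary-embedding layout is left out here):

```python
# Subset of the consolidated -> HF-style name mapping shown above.
mapping = {
    "layers": "model.layers",
    "attention": "self_attn",
    "wq": "q_proj",
    "feed_forward": "mlp",
    "w2": "down_proj",
}


def remap(name: str) -> str:
    # Iterate over the original dot-separated pieces and substitute each
    # piece that has an HF-style counterpart, skipping already-mapped ones.
    for part in name.split("."):
        if part in mapping and mapping[part] not in name:
            name = name.replace(part, mapping[part])
    return name


assert remap("layers.0.attention.wq.weight") == \
    "model.layers.0.self_attn.q_proj.weight"
assert remap("layers.3.feed_forward.w2.weight") == \
    "model.layers.3.mlp.down_proj.weight"
```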
```diff
@@ -1,7 +1,9 @@
 import contextlib
+import json
 from pathlib import Path
 from typing import Any, Dict, Optional, Type, Union

+from huggingface_hub import hf_hub_download
 from transformers import GenerationConfig, PretrainedConfig
 from transformers.models.auto.image_processing_auto import (
     get_image_processor_config)
```
```diff
@@ -53,22 +55,26 @@ def get_config(
     code_revision: Optional[str] = None,
     rope_scaling: Optional[dict] = None,
     rope_theta: Optional[float] = None,
+    load_from_params: bool = False,
     **kwargs,
 ) -> PretrainedConfig:

     # Separate model folder from file path for GGUF models
     is_gguf = check_gguf_file(model)
     if is_gguf:
         kwargs["gguf_file"] = Path(model).name
         model = Path(model).parent

     try:
-        config = AutoConfig.from_pretrained(
-            model,
-            trust_remote_code=trust_remote_code,
-            revision=revision,
-            code_revision=code_revision,
-            **kwargs)
+        if load_from_params:
+            config = load_params_config(model, revision)
+        else:
+            config = AutoConfig.from_pretrained(
+                model,
+                trust_remote_code=trust_remote_code,
+                revision=revision,
+                code_revision=code_revision,
+                **kwargs)
     except ValueError as e:
         if (not trust_remote_code and
                 "requires you to execute the configuration file" in str(e)):
```
```diff
@@ -104,6 +110,56 @@ def get_config(
     return config


+def load_params_config(model, revision) -> PretrainedConfig:
+    # This function loads a params.json config which
+    # should be used when loading models in consolidated format
+
+    config_file_name = "params.json"
+
+    config_path = Path(model) / config_file_name
+
+    if not config_path.is_file():
+        config_path = Path(
+            hf_hub_download(model, config_file_name, revision=revision))
+
+    with open(config_path, 'r') as file:
+        config_dict = json.load(file)
+
+    config_mapping = {
+        "dim": "hidden_size",
+        "norm_eps": "rms_norm_eps",
+        "n_kv_heads": "num_key_value_heads",
+        "n_layers": "num_hidden_layers",
+        "n_heads": "num_attention_heads",
+        "hidden_dim": "intermediate_size",
+    }
+
+    def recurse_elems(elem: Any):
+        if isinstance(elem, dict):
+            config_dict = {}
+            for key, value in elem.items():
+                key = config_mapping.get(key, key)
+                config_dict[key] = recurse_elems(value)
+            return PretrainedConfig(**config_dict)
+        else:
+            return elem
+
+    config_dict["model_type"] = config_dict.get("model_type", "transformer")
+    config_dict["hidden_act"] = config_dict.get("activation", "silu")
+    config_dict["max_position_embeddings"] = 32768
+    config_dict["tie_word_embeddings"] = config_dict.get(
+        "tie_embeddings", False)
+    config_dict["torch_dtype"] = "bfloat16"
+
+    if config_dict["model_type"] == "transformer":
+        if "moe" in config_dict:
+            config_dict["architectures"] = ["MixtralForCausalLM"]
+        else:
+            config_dict["architectures"] = ["MistralForCausalLM"]
+
+    return recurse_elems(config_dict)
+
+
 def get_hf_image_processor_config(
     model: Union[str, Path],
     revision: Optional[str] = None,
```

Review comment (on load_params_config): Is there a reason why you have this new config? It seems to have the same information you would have in the config.json, just named differently.
Reply: The main reason is that the original format is always stored in params.json, which accompanies the consolidated.safetensors checkpoints. Guess there are two problems with config.json
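For intuition, here is a minimal sketch of how the key mapping above turns a params.json-style dict into a PretrainedConfig (the field values below are made up for illustration; the nested recurse_elems handling is simplified to a flat dict):

```python
from transformers import PretrainedConfig

# Hypothetical params.json contents, for illustration only.
params = {
    "dim": 4096,
    "n_layers": 32,
    "n_heads": 32,
    "n_kv_heads": 8,
    "hidden_dim": 14336,
    "norm_eps": 1e-5,
    "vocab_size": 32768,
}

# Same key renaming as config_mapping in the diff above.
key_map = {
    "dim": "hidden_size",
    "norm_eps": "rms_norm_eps",
    "n_kv_heads": "num_key_value_heads",
    "n_layers": "num_hidden_layers",
    "n_heads": "num_attention_heads",
    "hidden_dim": "intermediate_size",
}

config = PretrainedConfig(**{key_map.get(k, k): v for k, v in params.items()})
print(config.hidden_size, config.num_hidden_layers)  # -> 4096 32
```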
Review comment (on the new load-format name): CONSOLIDATED is a bit broad as a name, no? Especially since there is SHARDED_STATE above, I think it could lead to confusion. Thoughts on LoadFormat.MISTRAL? This makes it clear that the intent is to have us support / maintain the integration of our models into vLLM.
Reply: Yes, makes sense!
Reply: Changed to "mistral" now - keen to hear what @simon-mo @mgoin think :-)
Reply: +1 on calling it "MISTRAL" format
Reply: I agree, it's better to be explicit if possible
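Assuming the name settles on "mistral" as discussed, end-user usage would presumably look roughly like the following (a sketch of the intent only; the exact flag names depend on the merged version of the PR):

```python
from vllm import LLM

# Hypothetical usage once the PR lands: load the original consolidated
# Mistral checkpoint format instead of the HF-converted one.
llm = LLM(
    model="mistralai/Mistral-7B-Instruct-v0.3",
    load_format="mistral",
)
print(llm.generate("Hello, my name is"))
```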