From 26d1aff10f2dbd85abc4ecc67d98c25aafbe8d30 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Sat, 7 Sep 2024 01:02:05 +0200
Subject: [PATCH] [Model] Allow loading from original Mistral format (#8168)

Co-authored-by: Michael Goin
Signed-off-by: Amit Garg
---
 tests/models/test_mistral.py               |  40 +++++
 vllm/config.py                             |  62 ++++---
 vllm/engine/arg_utils.py                   |  21 ++-
 vllm/model_executor/model_loader/loader.py |  12 +-
 .../model_loader/weight_utils.py           |  21 +--
 vllm/model_executor/models/llama.py        |  51 ++++++
 vllm/transformers_utils/config.py          | 165 ++++++++++++++----
 7 files changed, 291 insertions(+), 81 deletions(-)

diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py
index 4965354c0016b..0741174497e32 100644
--- a/tests/models/test_mistral.py
+++ b/tests/models/test_mistral.py
@@ -41,3 +41,43 @@ def test_models(
         name_0="hf",
         name_1="vllm",
     )
+
+
+@pytest.mark.parametrize("model", MODELS[1:])
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("max_tokens", [64])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_mistral_format(
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+) -> None:
+    with vllm_runner(
+            model,
+            dtype=dtype,
+            tokenizer_mode="auto",
+            load_format="safetensors",
+            config_format="hf",
+    ) as hf_format_model:
+        hf_format_outputs = hf_format_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
+
+    with vllm_runner(
+            model,
+            dtype=dtype,
+            tokenizer_mode="mistral",
+            load_format="mistral",
+            config_format="mistral",
+    ) as mistral_format_model:
+        mistral_format_outputs = mistral_format_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
+
+    check_logprobs_close(
+        outputs_0_lst=hf_format_outputs,
+        outputs_1_lst=mistral_format_outputs,
+        name_0="hf",
+        name_1="mistral",
+    )
diff --git a/vllm/config.py b/vllm/config.py
index 1c9e30b2682b9..8f5e02e35f28d 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -13,7 +13,7 @@
 from vllm.model_executor.models import ModelRegistry
 from vllm.platforms import current_platform
 from vllm.tracing import is_otel_available, otel_import_error_traceback
-from vllm.transformers_utils.config import (get_config,
+from vllm.transformers_utils.config import (ConfigFormat, get_config,
                                             get_hf_image_processor_config,
                                             get_hf_text_config)
 from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, GiB_bytes,
@@ -121,35 +121,37 @@ class ModelConfig:
             override default neuron config that are specific to Neuron devices,
             this argument will be used to configure the neuron config that
             can not be gathered from the vllm arguments.
+        config_format: The config format which shall be loaded.
+            Defaults to 'auto' which defaults to 'hf'.
     """
 
-    def __init__(
-        self,
-        model: str,
-        tokenizer: str,
-        tokenizer_mode: str,
-        trust_remote_code: bool,
-        dtype: Union[str, torch.dtype],
-        seed: int,
-        revision: Optional[str] = None,
-        code_revision: Optional[str] = None,
-        rope_scaling: Optional[dict] = None,
-        rope_theta: Optional[float] = None,
-        tokenizer_revision: Optional[str] = None,
-        max_model_len: Optional[int] = None,
-        spec_target_max_model_len: Optional[int] = None,
-        quantization: Optional[str] = None,
-        quantization_param_path: Optional[str] = None,
-        enforce_eager: Optional[bool] = None,
-        max_context_len_to_capture: Optional[int] = None,
-        max_seq_len_to_capture: Optional[int] = None,
-        max_logprobs: int = 20,
-        disable_sliding_window: bool = False,
-        skip_tokenizer_init: bool = False,
-        served_model_name: Optional[Union[str, List[str]]] = None,
-        limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
-        use_async_output_proc: bool = True,
-        override_neuron_config: Optional[Dict[str, Any]] = None) -> None:
+    def __init__(self,
+                 model: str,
+                 tokenizer: str,
+                 tokenizer_mode: str,
+                 trust_remote_code: bool,
+                 dtype: Union[str, torch.dtype],
+                 seed: int,
+                 revision: Optional[str] = None,
+                 code_revision: Optional[str] = None,
+                 rope_scaling: Optional[dict] = None,
+                 rope_theta: Optional[float] = None,
+                 tokenizer_revision: Optional[str] = None,
+                 max_model_len: Optional[int] = None,
+                 spec_target_max_model_len: Optional[int] = None,
+                 quantization: Optional[str] = None,
+                 quantization_param_path: Optional[str] = None,
+                 enforce_eager: Optional[bool] = None,
+                 max_context_len_to_capture: Optional[int] = None,
+                 max_seq_len_to_capture: Optional[int] = None,
+                 max_logprobs: int = 20,
+                 disable_sliding_window: bool = False,
+                 skip_tokenizer_init: bool = False,
+                 served_model_name: Optional[Union[str, List[str]]] = None,
+                 limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
+                 use_async_output_proc: bool = True,
+                 override_neuron_config: Optional[Dict[str, Any]] = None,
+                 config_format: ConfigFormat = ConfigFormat.AUTO) -> None:
         self.model = model
         self.tokenizer = tokenizer
         self.tokenizer_mode = tokenizer_mode
@@ -176,7 +178,8 @@ def __init__(
         self.skip_tokenizer_init = skip_tokenizer_init
 
         self.hf_config = get_config(self.model, trust_remote_code, revision,
-                                    code_revision, rope_scaling, rope_theta)
+                                    code_revision, rope_scaling, rope_theta,
+                                    config_format)
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.hf_image_processor_config = get_hf_image_processor_config(
             self.model, revision)
@@ -746,6 +749,7 @@ class LoadFormat(str, enum.Enum):
     SHARDED_STATE = "sharded_state"
     GGUF = "gguf"
     BITSANDBYTES = "bitsandbytes"
+    MISTRAL = "mistral"
 
 
 @dataclass
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index f0b866db64324..7620093660b43 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -8,10 +8,10 @@
 import torch
 
 import vllm.envs as envs
-from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
-                         EngineConfig, LoadConfig, LoadFormat, LoRAConfig,
-                         ModelConfig, ObservabilityConfig, ParallelConfig,
-                         PromptAdapterConfig, SchedulerConfig,
+from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig,
+                         DeviceConfig, EngineConfig, LoadConfig, LoadFormat,
+                         LoRAConfig, ModelConfig, ObservabilityConfig,
+                         ParallelConfig, PromptAdapterConfig, SchedulerConfig,
                          SpeculativeConfig, TokenizerPoolConfig)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
@@ -65,6 +65,7 @@ class EngineArgs:
     trust_remote_code: bool = False
     download_dir: Optional[str] = None
     load_format: str = 'auto'
+    config_format: str = 'auto'
     dtype: str = 'auto'
     kv_cache_dtype: str = 'auto'
     quantization_param_path: Optional[str] = None
@@ -234,6 +235,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             'section for more information.\n'
             '* "bitsandbytes" will load the weights using bitsandbytes '
             'quantization.\n')
+        parser.add_argument(
+            '--config-format',
+            default=EngineArgs.config_format,
+            choices=[f.value for f in ConfigFormat],
+            help='The format of the model config to load.\n\n'
+            '* "auto" will try to load the config in hf format '
+            'if available else it will try to load in mistral format ')
         parser.add_argument(
             '--dtype',
             type=str,
@@ -813,7 +821,10 @@ def create_engine_config(self) -> EngineConfig:
             served_model_name=self.served_model_name,
             limit_mm_per_prompt=self.limit_mm_per_prompt,
             use_async_output_proc=not self.disable_async_output_proc,
-            override_neuron_config=self.override_neuron_config)
+            override_neuron_config=self.override_neuron_config,
+            config_format=self.config_format,
+        )
+
         cache_config = CacheConfig(
             block_size=self.block_size if self.device != "neuron" else
             self.max_model_len,  # neuron needs block_size = max_model_len
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 553fa848489b2..bcc866a194637 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -17,6 +17,7 @@
 from huggingface_hub import HfApi, hf_hub_download
 from torch import nn
 from transformers import AutoModelForCausalLM, PretrainedConfig
+from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
                          LoRAConfig, ModelConfig, MultiModalConfig,
@@ -241,12 +242,17 @@ def _prepare_weights(self, model_name_or_path: str,
         is_local = os.path.isdir(model_name_or_path)
         load_format = self.load_config.load_format
         use_safetensors = False
+        index_file = SAFE_WEIGHTS_INDEX_NAME
         # Some quantized models use .pt files for storing the weights.
         if load_format == LoadFormat.AUTO:
             allow_patterns = ["*.safetensors", "*.bin"]
         elif load_format == LoadFormat.SAFETENSORS:
             use_safetensors = True
             allow_patterns = ["*.safetensors"]
+        elif load_format == LoadFormat.MISTRAL:
+            use_safetensors = True
+            allow_patterns = ["consolidated*.safetensors"]
+            index_file = "consolidated.safetensors.index.json"
         elif load_format == LoadFormat.PT:
             allow_patterns = ["*.pt"]
         elif load_format == LoadFormat.NPCACHE:
@@ -284,10 +290,10 @@ def _prepare_weights(self, model_name_or_path: str,
             # any files not found in the index.
             if not is_local:
                 download_safetensors_index_file_from_hf(
-                    model_name_or_path, self.load_config.download_dir,
-                    revision)
+                    model_name_or_path, index_file,
+                    self.load_config.download_dir, revision)
             hf_weights_files = filter_duplicate_safetensors_files(
-                hf_weights_files, hf_folder)
+                hf_weights_files, hf_folder, index_file)
         else:
             hf_weights_files = filter_files_not_needed_for_inference(
                 hf_weights_files)
diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
index 0666457756b02..075451292a8e4 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -16,7 +16,6 @@
 from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
 from safetensors.torch import load_file, safe_open, save_file
 from tqdm.auto import tqdm
-from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 
 from vllm.config import LoadConfig, ModelConfig
 from vllm.distributed import get_tensor_model_parallel_rank
@@ -251,6 +250,7 @@ def download_weights_from_hf(
 
 def download_safetensors_index_file_from_hf(
     model_name_or_path: str,
+    index_file: str,
     cache_dir: Optional[str],
     revision: Optional[str] = None,
 ) -> None:
@@ -269,36 +269,37 @@ def download_safetensors_index_file_from_hf(
         # Download the safetensors index file.
         hf_hub_download(
             repo_id=model_name_or_path,
-            filename=SAFE_WEIGHTS_INDEX_NAME,
+            filename=index_file,
             cache_dir=cache_dir,
             revision=revision,
             local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
         )
     # If file not found on remote or locally, we should not fail since
-    # only some models will have SAFE_WEIGHTS_INDEX_NAME.
+    # only some models will have index_file.
     except huggingface_hub.utils.EntryNotFoundError:
-        logger.info("No %s found in remote.", SAFE_WEIGHTS_INDEX_NAME)
+        logger.info("No %s found in remote.", index_file)
     except huggingface_hub.utils.LocalEntryNotFoundError:
-        logger.info("No %s found in local cache.", SAFE_WEIGHTS_INDEX_NAME)
+        logger.info("No %s found in local cache.", index_file)
 
 
 # For models like Mistral-7B-v0.3, there are both sharded
 # safetensors files and a consolidated safetensors file.
 # Passing both of these to the weight loader functionality breaks.
-# So, we use the SAFE_WEIGHTS_INDEX_NAME to
+# So, we use the index_file to
 # look up which safetensors files should be used.
 def filter_duplicate_safetensors_files(hf_weights_files: List[str],
-                                       hf_folder: str) -> List[str]:
+                                       hf_folder: str,
+                                       index_file: str) -> List[str]:
     # model.safetensors.index.json is a mapping from keys in the
     # torch state_dict to safetensors file holding that weight.
-    index_file_name = os.path.join(hf_folder, SAFE_WEIGHTS_INDEX_NAME)
+    index_file_name = os.path.join(hf_folder, index_file)
     if not os.path.isfile(index_file_name):
         return hf_weights_files
 
     # Iterate through the weight_map (weight_name: safetensors files)
     # to identify weights that we should use.
-    with open(index_file_name) as index_file:
-        weight_map = json.load(index_file)["weight_map"]
+    with open(index_file_name, "r") as f:
+        weight_map = json.load(f)["weight_map"]
     weight_files_in_index = set()
     for weight_name in weight_map:
         weight_files_in_index.add(
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index d59ab85041f85..dda9f7d3687f7 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -383,6 +383,25 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
         "gate_proj": ("gate_up_proj", 0),
         "up_proj": ("gate_up_proj", 1),
     }
+    # Mistral/Llama models can also be loaded with --load-format mistral
+    # from consolidated.safetensors checkpoints
+    mistral_mapping = {
+        "layers": "model.layers",
+        "attention": "self_attn",
+        "wq": "q_proj",
+        "wk": "k_proj",
+        "wv": "v_proj",
+        "wo": "o_proj",
+        "attention_norm": "input_layernorm",
+        "feed_forward": "mlp",
+        "w1": "gate_proj",
+        "w2": "down_proj",
+        "w3": "up_proj",
+        "ffn_norm": "post_attention_layernorm",
+        "tok_embeddings": "model.embed_tokens",
+        "output": "lm_head",
+        "norm": "model.norm"
+    }
 
     def __init__(
         self,
@@ -480,6 +499,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         ]
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
+            name, loaded_weight = self.maybe_remap_mistral(name, loaded_weight)
+
             if "rotary_emb.inv_freq" in name:
                 continue
             if ("rotary_emb.cos_cached" in name
@@ -557,3 +578,33 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
             else:
                 raise RuntimeError("Self attention has no KV cache scaling "
                                    "factor attribute!")
+
+    # This function is used to remap the mistral format as
+    # used by Mistral and Llama <=2
+    def maybe_remap_mistral(
+            self, name: str,
+            loaded_weight: torch.Tensor) -> Tuple[str, torch.Tensor]:
+
+        def permute(w, n_heads):
+            attn_in = self.config.head_dim * n_heads
+            attn_out = self.config.hidden_size
+
+            return w.view(n_heads, attn_in // n_heads // 2, 2,
+                          attn_out).transpose(1, 2).reshape(attn_in, attn_out)
+
+        mapping = self.mistral_mapping
+        modules = name.split(".")
+
+        # rotary embeds should be sliced
+        if "wk" in modules:
+            loaded_weight = permute(loaded_weight,
+                                    self.config.num_key_value_heads)
+        elif "wq" in modules:
+            loaded_weight = permute(loaded_weight,
+                                    self.config.num_attention_heads)
+
+        for item in modules:
+            if item in mapping and mapping[item] not in name:
+                name = name.replace(item, mapping[item])
+
+        return name, loaded_weight
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index 4f4e79d10a677..13fcf6b918603 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -1,12 +1,16 @@
 import contextlib
+import enum
+import json
 from pathlib import Path
 from typing import Any, Dict, Optional, Type, Union
 
+from huggingface_hub import file_exists, hf_hub_download
 from transformers import GenerationConfig, PretrainedConfig
 from transformers.models.auto.image_processing_auto import (
     get_image_processor_config)
 from transformers.models.auto.modeling_auto import (
     MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
+from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME
 
 from vllm.envs import VLLM_USE_MODELSCOPE
 from vllm.logger import init_logger
@@ -27,6 +31,8 @@
 else:
     from transformers import AutoConfig
 
+MISTRAL_CONFIG_NAME = "params.json"
+
 logger = init_logger(__name__)
 
 _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
@@ -53,6 +59,20 @@
         AutoConfig.register(name, cls)
 
 
+class ConfigFormat(str, enum.Enum):
+    AUTO = "auto"
+    HF = "hf"
+    MISTRAL = "mistral"
+
+
+def file_or_path_exists(model: Union[str, Path], config_name, revision,
+                        token) -> bool:
+    if Path(model).exists():
+        return (Path(model) / config_name).is_file()
+
+    return file_exists(model, HF_CONFIG_NAME, revision=revision, token=token)
+
+
 def get_config(
     model: Union[str, Path],
     trust_remote_code: bool,
@@ -60,45 +80,68 @@ def get_config(
     code_revision: Optional[str] = None,
     rope_scaling: Optional[dict] = None,
     rope_theta: Optional[float] = None,
+    config_format: ConfigFormat = ConfigFormat.AUTO,
     **kwargs,
 ) -> PretrainedConfig:
-    # Separate model folder from file path for GGUF models
+
     is_gguf = check_gguf_file(model)
     if is_gguf:
         kwargs["gguf_file"] = Path(model).name
         model = Path(model).parent
 
-    config_dict, _ = PretrainedConfig.get_config_dict(
-        model, revision=revision, code_revision=code_revision, **kwargs)
+    if config_format == ConfigFormat.AUTO:
+        if is_gguf or file_or_path_exists(model,
+                                          HF_CONFIG_NAME,
+                                          revision=revision,
+                                          token=kwargs.get("token")):
+            config_format = ConfigFormat.HF
+        elif file_or_path_exists(model,
+                                 MISTRAL_CONFIG_NAME,
+                                 revision=revision,
+                                 token=kwargs.get("token")):
+            config_format = ConfigFormat.MISTRAL
+        else:
+            raise ValueError(f"No supported config format found in {model}")
+
+    if config_format == ConfigFormat.HF:
+        config_dict, _ = PretrainedConfig.get_config_dict(
+            model, revision=revision, code_revision=code_revision, **kwargs)
+
+        # Use custom model class if it's in our registry
+        model_type = config_dict.get("model_type")
+        if model_type in _CONFIG_REGISTRY:
+            config_class = _CONFIG_REGISTRY[model_type]
+            config = config_class.from_pretrained(model,
+                                                  revision=revision,
+                                                  code_revision=code_revision)
+        else:
+            try:
+                config = AutoConfig.from_pretrained(
+                    model,
+                    trust_remote_code=trust_remote_code,
+                    revision=revision,
+                    code_revision=code_revision,
+                    **kwargs,
+                )
+            except ValueError as e:
+                if (not trust_remote_code
+                        and "requires you to execute the configuration file"
+                        in str(e)):
+                    err_msg = (
+                        "Failed to load the model config. If the model "
+                        "is a custom model not yet available in the "
+                        "HuggingFace transformers library, consider setting "
+                        "`trust_remote_code=True` in LLM or using the "
+                        "`--trust-remote-code` flag in the CLI.")
+                    raise RuntimeError(err_msg) from e
+                else:
+                    raise e
 
-    # Use custom model class if it's in our registry
-    model_type = config_dict.get("model_type")
-    if model_type in _CONFIG_REGISTRY:
-        config_class = _CONFIG_REGISTRY[model_type]
-        config = config_class.from_pretrained(model,
-                                              revision=revision,
-                                              code_revision=code_revision)
+    elif config_format == ConfigFormat.MISTRAL:
+        config = load_params_config(model, revision)
     else:
-        try:
-            config = AutoConfig.from_pretrained(
-                model,
-                trust_remote_code=trust_remote_code,
-                revision=revision,
-                code_revision=code_revision,
-                **kwargs)
-        except ValueError as e:
-            if (not trust_remote_code
-                    and "requires you to execute the configuration file"
-                    in str(e)):
-                err_msg = (
-                    "Failed to load the model config. If the model is a custom "
-                    "model not yet available in the HuggingFace transformers "
-                    "library, consider setting `trust_remote_code=True` in LLM "
-                    "or using the `--trust-remote-code` flag in the CLI.")
-                raise RuntimeError(err_msg) from e
-            else:
-                raise e
+        raise ValueError(f"Unsupported config format: {config_format}")
 
     # Special architecture mapping check for GGUF models
     if is_gguf:
@@ -108,16 +151,70 @@ def get_config(
         model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
         config.update({"architectures": [model_type]})
 
-    for key, value in [("rope_scaling", rope_scaling),
-                       ("rope_theta", rope_theta)]:
+    for key, value in [
+        ("rope_scaling", rope_scaling),
+        ("rope_theta", rope_theta),
+    ]:
         if value is not None:
-            logger.info("Updating %s from %r to %r", key,
-                        getattr(config, key, None), value)
+            logger.info(
+                "Updating %s from %r to %r",
+                key,
+                getattr(config, key, None),
+                value,
+            )
             config.update({key: value})
 
     return config
 
 
+def load_params_config(model, revision) -> PretrainedConfig:
+    # This function loads a params.json config which
+    # should be used when loading models in mistral format
+
+    config_file_name = "params.json"
+
+    config_path = Path(model) / config_file_name
+
+    if not config_path.is_file():
+        config_path = Path(
+            hf_hub_download(model, config_file_name, revision=revision))
+
+    with open(config_path, "r") as file:
+        config_dict = json.load(file)
+
+    config_mapping = {
+        "dim": "hidden_size",
+        "norm_eps": "rms_norm_eps",
+        "n_kv_heads": "num_key_value_heads",
+        "n_layers": "num_hidden_layers",
+        "n_heads": "num_attention_heads",
+        "hidden_dim": "intermediate_size",
+    }
+
+    def recurse_elems(elem: Any):
+        if isinstance(elem, dict):
+            config_dict = {}
+            for key, value in elem.items():
+                key = config_mapping.get(key, key)
+                config_dict[key] = recurse_elems(value)
+            return PretrainedConfig(**config_dict)
+        else:
+            return elem
+
+    config_dict["model_type"] = config_dict.get("model_type", "transformer")
+    config_dict["hidden_act"] = config_dict.get("activation", "silu")
+    config_dict["tie_word_embeddings"] = config_dict.get(
+        "tie_embeddings", False)
+
+    if config_dict["model_type"] == "transformer":
+        if "moe" in config_dict:
+            config_dict["architectures"] = ["MixtralForCausalLM"]
+        else:
+            config_dict["architectures"] = ["MistralForCausalLM"]
+
+    return recurse_elems(config_dict)
+
+
 def get_hf_image_processor_config(
     model: Union[str, Path],
     revision: Optional[str] = None,
@@ -134,7 +231,7 @@ def get_hf_image_processor_config(
 
 def get_hf_text_config(config: PretrainedConfig):
     """Get the "sub" config relevant to llm for multi modal models.
-    No op for pure text models. 
+    No op for pure text models.
     """
     if hasattr(config, "text_config"):
         # The code operates under the assumption that text_config should have