Skip to content

Commit

Permalink
[Model] Allow loading from original Mistral format (vllm-project#8168)
Browse files Browse the repository at this point in the history
Co-authored-by: Michael Goin <[email protected]>
Signed-off-by: Alvant <[email protected]>
  • Loading branch information
2 people authored and Alvant committed Oct 26, 2024
1 parent efd0301 commit 0bc27fe
Show file tree
Hide file tree
Showing 7 changed files with 291 additions and 81 deletions.
40 changes: 40 additions & 0 deletions tests/models/test_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,43 @@ def test_models(
name_0="hf",
name_1="vllm",
)


@pytest.mark.parametrize("model", MODELS[1:])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_mistral_format(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
with vllm_runner(
model,
dtype=dtype,
tokenizer_mode="auto",
load_format="safetensors",
config_format="hf",
) as hf_format_model:
hf_format_outputs = hf_format_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)

with vllm_runner(
model,
dtype=dtype,
tokenizer_mode="mistral",
load_format="mistral",
config_format="mistral",
) as mistral_format_model:
mistral_format_outputs = mistral_format_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)

check_logprobs_close(
outputs_0_lst=hf_format_outputs,
outputs_1_lst=mistral_format_outputs,
name_0="hf",
name_1="mistral",
)
62 changes: 33 additions & 29 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from vllm.model_executor.models import ModelRegistry
from vllm.platforms import current_platform
from vllm.tracing import is_otel_available, otel_import_error_traceback
from vllm.transformers_utils.config import (get_config,
from vllm.transformers_utils.config import (ConfigFormat, get_config,
get_hf_image_processor_config,
get_hf_text_config)
from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, GiB_bytes,
Expand Down Expand Up @@ -121,35 +121,37 @@ class ModelConfig:
override default neuron config that are specific to Neuron devices,
this argument will be used to configure the neuron config that
can not be gathered from the vllm arguments.
config_format: The config format which shall be loaded.
Defaults to 'auto' which defaults to 'hf'.
"""

def __init__(
self,
model: str,
tokenizer: str,
tokenizer_mode: str,
trust_remote_code: bool,
dtype: Union[str, torch.dtype],
seed: int,
revision: Optional[str] = None,
code_revision: Optional[str] = None,
rope_scaling: Optional[dict] = None,
rope_theta: Optional[float] = None,
tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None,
spec_target_max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
quantization_param_path: Optional[str] = None,
enforce_eager: Optional[bool] = None,
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 20,
disable_sliding_window: bool = False,
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None,
limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
use_async_output_proc: bool = True,
override_neuron_config: Optional[Dict[str, Any]] = None) -> None:
def __init__(self,
model: str,
tokenizer: str,
tokenizer_mode: str,
trust_remote_code: bool,
dtype: Union[str, torch.dtype],
seed: int,
revision: Optional[str] = None,
code_revision: Optional[str] = None,
rope_scaling: Optional[dict] = None,
rope_theta: Optional[float] = None,
tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None,
spec_target_max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
quantization_param_path: Optional[str] = None,
enforce_eager: Optional[bool] = None,
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 20,
disable_sliding_window: bool = False,
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None,
limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
use_async_output_proc: bool = True,
override_neuron_config: Optional[Dict[str, Any]] = None,
config_format: ConfigFormat = ConfigFormat.AUTO) -> None:
self.model = model
self.tokenizer = tokenizer
self.tokenizer_mode = tokenizer_mode
Expand All @@ -176,7 +178,8 @@ def __init__(
self.skip_tokenizer_init = skip_tokenizer_init

self.hf_config = get_config(self.model, trust_remote_code, revision,
code_revision, rope_scaling, rope_theta)
code_revision, rope_scaling, rope_theta,
config_format)
self.hf_text_config = get_hf_text_config(self.hf_config)
self.hf_image_processor_config = get_hf_image_processor_config(
self.model, revision)
Expand Down Expand Up @@ -746,6 +749,7 @@ class LoadFormat(str, enum.Enum):
SHARDED_STATE = "sharded_state"
GGUF = "gguf"
BITSANDBYTES = "bitsandbytes"
MISTRAL = "mistral"


@dataclass
Expand Down
21 changes: 16 additions & 5 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
import torch

import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoadFormat, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig,
DeviceConfig, EngineConfig, LoadConfig, LoadFormat,
LoRAConfig, ModelConfig, ObservabilityConfig,
ParallelConfig, PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig, TokenizerPoolConfig)
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
Expand Down Expand Up @@ -65,6 +65,7 @@ class EngineArgs:
trust_remote_code: bool = False
download_dir: Optional[str] = None
load_format: str = 'auto'
config_format: str = 'auto'
dtype: str = 'auto'
kv_cache_dtype: str = 'auto'
quantization_param_path: Optional[str] = None
Expand Down Expand Up @@ -234,6 +235,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'section for more information.\n'
'* "bitsandbytes" will load the weights using bitsandbytes '
'quantization.\n')
parser.add_argument(
'--config-format',
default=EngineArgs.config_format,
choices=[f.value for f in ConfigFormat],
help='The format of the model config to load.\n\n'
'* "auto" will try to load the config in hf format '
'if available else it will try to load in mistral format ')
parser.add_argument(
'--dtype',
type=str,
Expand Down Expand Up @@ -813,7 +821,10 @@ def create_engine_config(self) -> EngineConfig:
served_model_name=self.served_model_name,
limit_mm_per_prompt=self.limit_mm_per_prompt,
use_async_output_proc=not self.disable_async_output_proc,
override_neuron_config=self.override_neuron_config)
override_neuron_config=self.override_neuron_config,
config_format=self.config_format,
)

cache_config = CacheConfig(
block_size=self.block_size if self.device != "neuron" else
self.max_model_len, # neuron needs block_size = max_model_len
Expand Down
12 changes: 9 additions & 3 deletions vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from huggingface_hub import HfApi, hf_hub_download
from torch import nn
from transformers import AutoModelForCausalLM, PretrainedConfig
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
LoRAConfig, ModelConfig, MultiModalConfig,
Expand Down Expand Up @@ -241,12 +242,17 @@ def _prepare_weights(self, model_name_or_path: str,
is_local = os.path.isdir(model_name_or_path)
load_format = self.load_config.load_format
use_safetensors = False
index_file = SAFE_WEIGHTS_INDEX_NAME
# Some quantized models use .pt files for storing the weights.
if load_format == LoadFormat.AUTO:
allow_patterns = ["*.safetensors", "*.bin"]
elif load_format == LoadFormat.SAFETENSORS:
use_safetensors = True
allow_patterns = ["*.safetensors"]
elif load_format == LoadFormat.MISTRAL:
use_safetensors = True
allow_patterns = ["consolidated*.safetensors"]
index_file = "consolidated.safetensors.index.json"
elif load_format == LoadFormat.PT:
allow_patterns = ["*.pt"]
elif load_format == LoadFormat.NPCACHE:
Expand Down Expand Up @@ -284,10 +290,10 @@ def _prepare_weights(self, model_name_or_path: str,
# any files not found in the index.
if not is_local:
download_safetensors_index_file_from_hf(
model_name_or_path, self.load_config.download_dir,
revision)
model_name_or_path, index_file,
self.load_config.download_dir, revision)
hf_weights_files = filter_duplicate_safetensors_files(
hf_weights_files, hf_folder)
hf_weights_files, hf_folder, index_file)
else:
hf_weights_files = filter_files_not_needed_for_inference(
hf_weights_files)
Expand Down
21 changes: 11 additions & 10 deletions vllm/model_executor/model_loader/weight_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
from safetensors.torch import load_file, safe_open, save_file
from tqdm.auto import tqdm
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

from vllm.config import LoadConfig, ModelConfig
from vllm.distributed import get_tensor_model_parallel_rank
Expand Down Expand Up @@ -251,6 +250,7 @@ def download_weights_from_hf(

def download_safetensors_index_file_from_hf(
model_name_or_path: str,
index_file: str,
cache_dir: Optional[str],
revision: Optional[str] = None,
) -> None:
Expand All @@ -269,36 +269,37 @@ def download_safetensors_index_file_from_hf(
# Download the safetensors index file.
hf_hub_download(
repo_id=model_name_or_path,
filename=SAFE_WEIGHTS_INDEX_NAME,
filename=index_file,
cache_dir=cache_dir,
revision=revision,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
)
# If file not found on remote or locally, we should not fail since
# only some models will have SAFE_WEIGHTS_INDEX_NAME.
# only some models will have index_file.
except huggingface_hub.utils.EntryNotFoundError:
logger.info("No %s found in remote.", SAFE_WEIGHTS_INDEX_NAME)
logger.info("No %s found in remote.", index_file)
except huggingface_hub.utils.LocalEntryNotFoundError:
logger.info("No %s found in local cache.", SAFE_WEIGHTS_INDEX_NAME)
logger.info("No %s found in local cache.", index_file)


# For models like Mistral-7B-v0.3, there are both sharded
# safetensors files and a consolidated safetensors file.
# Passing both of these to the weight loader functionality breaks.
# So, we use the SAFE_WEIGHTS_INDEX_NAME to
# So, we use the index_file to
# look up which safetensors files should be used.
def filter_duplicate_safetensors_files(hf_weights_files: List[str],
hf_folder: str) -> List[str]:
hf_folder: str,
index_file: str) -> List[str]:
# model.safetensors.index.json is a mapping from keys in the
# torch state_dict to safetensors file holding that weight.
index_file_name = os.path.join(hf_folder, SAFE_WEIGHTS_INDEX_NAME)
index_file_name = os.path.join(hf_folder, index_file)
if not os.path.isfile(index_file_name):
return hf_weights_files

# Iterate through the weight_map (weight_name: safetensors files)
# to identify weights that we should use.
with open(index_file_name) as index_file:
weight_map = json.load(index_file)["weight_map"]
with open(index_file_name, "r") as f:
weight_map = json.load(f)["weight_map"]
weight_files_in_index = set()
for weight_name in weight_map:
weight_files_in_index.add(
Expand Down
51 changes: 51 additions & 0 deletions vllm/model_executor/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,25 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
# Mistral/Llama models can also be loaded with --load-format mistral
# from consolidated.safetensors checkpoints
mistral_mapping = {
"layers": "model.layers",
"attention": "self_attn",
"wq": "q_proj",
"wk": "k_proj",
"wv": "v_proj",
"wo": "o_proj",
"attention_norm": "input_layernorm",
"feed_forward": "mlp",
"w1": "gate_proj",
"w2": "down_proj",
"w3": "up_proj",
"ffn_norm": "post_attention_layernorm",
"tok_embeddings": "model.embed_tokens",
"output": "lm_head",
"norm": "model.norm"
}

def __init__(
self,
Expand Down Expand Up @@ -472,6 +491,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
name, loaded_weight = self.maybe_remap_mistral(name, loaded_weight)

if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
Expand Down Expand Up @@ -549,3 +570,33 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
else:
raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!")

# This function is used to remap the mistral format as
# used by Mistral and Llama <=2
def maybe_remap_mistral(
self, name: str,
loaded_weight: torch.Tensor) -> Tuple[str, torch.Tensor]:

def permute(w, n_heads):
attn_in = self.config.head_dim * n_heads
attn_out = self.config.hidden_size

return w.view(n_heads, attn_in // n_heads // 2, 2,
attn_out).transpose(1, 2).reshape(attn_in, attn_out)

mapping = self.mistral_mapping
modules = name.split(".")

# rotary embeds should be sliced
if "wk" in modules:
loaded_weight = permute(loaded_weight,
self.config.num_key_value_heads)
elif "wq" in modules:
loaded_weight = permute(loaded_weight,
self.config.num_attention_heads)

for item in modules:
if item in mapping and mapping[item] not in name:
name = name.replace(item, mapping[item])

return name, loaded_weight
Loading

0 comments on commit 0bc27fe

Please sign in to comment.