
Commit 29f49cd

[Model] Allow loading from original Mistral format (vllm-project#8168)
Co-authored-by: Michael Goin <[email protected]>
1 parent 23f3222 commit 29f49cd
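
For orientation, here is a minimal usage sketch, not part of the commit: it assumes the offline LLM entrypoint forwards these keyword arguments to EngineArgs, and it uses mistralai/Mistral-7B-Instruct-v0.3 purely as an example of a repository that ships a consolidated.safetensors checkpoint.

from vllm import LLM, SamplingParams

# Sketch: load a checkpoint stored in the original Mistral layout instead of
# the HF Transformers layout. The three "mistral" options correspond to the
# tokenizer mode, the new --config-format, and the new --load-format.
llm = LLM(
    model="mistralai/Mistral-7B-Instruct-v0.3",  # example repo (assumption)
    tokenizer_mode="mistral",
    config_format="mistral",
    load_format="mistral",
)

outputs = llm.generate(["Tell me about Mistral AI."],
                       SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)

The test added in tests/models/test_mistral.py below exercises the same three options and checks that the resulting logprobs match the HF-format load path.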

File tree: 7 files changed (+291 −81 lines)

tests/models/test_mistral.py (+40)

@@ -41,3 +41,43 @@ def test_models(
         name_0="hf",
         name_1="vllm",
     )
+
+
+@pytest.mark.parametrize("model", MODELS[1:])
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("max_tokens", [64])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_mistral_format(
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+) -> None:
+    with vllm_runner(
+            model,
+            dtype=dtype,
+            tokenizer_mode="auto",
+            load_format="safetensors",
+            config_format="hf",
+    ) as hf_format_model:
+        hf_format_outputs = hf_format_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
+
+    with vllm_runner(
+            model,
+            dtype=dtype,
+            tokenizer_mode="mistral",
+            load_format="mistral",
+            config_format="mistral",
+    ) as mistral_format_model:
+        mistral_format_outputs = mistral_format_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
+
+    check_logprobs_close(
+        outputs_0_lst=hf_format_outputs,
+        outputs_1_lst=mistral_format_outputs,
+        name_0="hf",
+        name_1="mistral",
+    )

vllm/config.py (+33 −29)

@@ -13,7 +13,7 @@
 from vllm.model_executor.models import ModelRegistry
 from vllm.platforms import current_platform
 from vllm.tracing import is_otel_available, otel_import_error_traceback
-from vllm.transformers_utils.config import (get_config,
+from vllm.transformers_utils.config import (ConfigFormat, get_config,
                                             get_hf_image_processor_config,
                                             get_hf_text_config)
 from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, GiB_bytes,
@@ -121,35 +121,37 @@ class ModelConfig:
            override default neuron config that are specific to Neuron devices,
            this argument will be used to configure the neuron config that
            can not be gathered from the vllm arguments.
+        config_format: The config format which shall be loaded.
+            Defaults to 'auto' which defaults to 'hf'.
     """

-    def __init__(
-        self,
-        model: str,
-        tokenizer: str,
-        tokenizer_mode: str,
-        trust_remote_code: bool,
-        dtype: Union[str, torch.dtype],
-        seed: int,
-        revision: Optional[str] = None,
-        code_revision: Optional[str] = None,
-        rope_scaling: Optional[dict] = None,
-        rope_theta: Optional[float] = None,
-        tokenizer_revision: Optional[str] = None,
-        max_model_len: Optional[int] = None,
-        spec_target_max_model_len: Optional[int] = None,
-        quantization: Optional[str] = None,
-        quantization_param_path: Optional[str] = None,
-        enforce_eager: Optional[bool] = None,
-        max_context_len_to_capture: Optional[int] = None,
-        max_seq_len_to_capture: Optional[int] = None,
-        max_logprobs: int = 20,
-        disable_sliding_window: bool = False,
-        skip_tokenizer_init: bool = False,
-        served_model_name: Optional[Union[str, List[str]]] = None,
-        limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
-        use_async_output_proc: bool = True,
-        override_neuron_config: Optional[Dict[str, Any]] = None) -> None:
+    def __init__(self,
+                 model: str,
+                 tokenizer: str,
+                 tokenizer_mode: str,
+                 trust_remote_code: bool,
+                 dtype: Union[str, torch.dtype],
+                 seed: int,
+                 revision: Optional[str] = None,
+                 code_revision: Optional[str] = None,
+                 rope_scaling: Optional[dict] = None,
+                 rope_theta: Optional[float] = None,
+                 tokenizer_revision: Optional[str] = None,
+                 max_model_len: Optional[int] = None,
+                 spec_target_max_model_len: Optional[int] = None,
+                 quantization: Optional[str] = None,
+                 quantization_param_path: Optional[str] = None,
+                 enforce_eager: Optional[bool] = None,
+                 max_context_len_to_capture: Optional[int] = None,
+                 max_seq_len_to_capture: Optional[int] = None,
+                 max_logprobs: int = 20,
+                 disable_sliding_window: bool = False,
+                 skip_tokenizer_init: bool = False,
+                 served_model_name: Optional[Union[str, List[str]]] = None,
+                 limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
+                 use_async_output_proc: bool = True,
+                 override_neuron_config: Optional[Dict[str, Any]] = None,
+                 config_format: ConfigFormat = ConfigFormat.AUTO) -> None:
         self.model = model
         self.tokenizer = tokenizer
         self.tokenizer_mode = tokenizer_mode
@@ -176,7 +178,8 @@ def __init__(
         self.skip_tokenizer_init = skip_tokenizer_init

         self.hf_config = get_config(self.model, trust_remote_code, revision,
-                                    code_revision, rope_scaling, rope_theta)
+                                    code_revision, rope_scaling, rope_theta,
+                                    config_format)
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.hf_image_processor_config = get_hf_image_processor_config(
             self.model, revision)
@@ -746,6 +749,7 @@ class LoadFormat(str, enum.Enum):
     SHARDED_STATE = "sharded_state"
     GGUF = "gguf"
     BITSANDBYTES = "bitsandbytes"
+    MISTRAL = "mistral"


 @dataclass

vllm/engine/arg_utils.py (+16 −5)

@@ -8,10 +8,10 @@
 import torch

 import vllm.envs as envs
-from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
-                         EngineConfig, LoadConfig, LoadFormat, LoRAConfig,
-                         ModelConfig, ObservabilityConfig, ParallelConfig,
-                         PromptAdapterConfig, SchedulerConfig,
+from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig,
+                         DeviceConfig, EngineConfig, LoadConfig, LoadFormat,
+                         LoRAConfig, ModelConfig, ObservabilityConfig,
+                         ParallelConfig, PromptAdapterConfig, SchedulerConfig,
                          SpeculativeConfig, TokenizerPoolConfig)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
@@ -65,6 +65,7 @@ class EngineArgs:
     trust_remote_code: bool = False
     download_dir: Optional[str] = None
     load_format: str = 'auto'
+    config_format: str = 'auto'
     dtype: str = 'auto'
     kv_cache_dtype: str = 'auto'
     quantization_param_path: Optional[str] = None
@@ -234,6 +235,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
            'section for more information.\n'
            '* "bitsandbytes" will load the weights using bitsandbytes '
            'quantization.\n')
+        parser.add_argument(
+            '--config-format',
+            default=EngineArgs.config_format,
+            choices=[f.value for f in ConfigFormat],
+            help='The format of the model config to load.\n\n'
+            '* "auto" will try to load the config in hf format '
+            'if available else it will try to load in mistral format ')
         parser.add_argument(
             '--dtype',
             type=str,
@@ -813,7 +821,10 @@ def create_engine_config(self) -> EngineConfig:
             served_model_name=self.served_model_name,
             limit_mm_per_prompt=self.limit_mm_per_prompt,
             use_async_output_proc=not self.disable_async_output_proc,
-            override_neuron_config=self.override_neuron_config)
+            override_neuron_config=self.override_neuron_config,
+            config_format=self.config_format,
+        )
+
         cache_config = CacheConfig(
             block_size=self.block_size if self.device != "neuron" else
             self.max_model_len,  # neuron needs block_size = max_model_len
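
A hedged sketch of the same plumbing driven from Python rather than the CLI; it assumes the model config can be resolved (network access or a local path) and a supported accelerator is available, and the repo name is only an example.

from vllm.engine.arg_utils import EngineArgs

# --config-format and --load-format map onto these EngineArgs fields;
# create_engine_config() then passes config_format through to ModelConfig.
args = EngineArgs(
    model="mistralai/Mistral-7B-Instruct-v0.3",  # example repo (assumption)
    tokenizer_mode="mistral",
    config_format="mistral",
    load_format="mistral",
)
engine_config = args.create_engine_config()
print(type(engine_config.model_config).__name__)  # ModelConfig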

vllm/model_executor/model_loader/loader.py (+9 −3)

@@ -17,6 +17,7 @@
 from huggingface_hub import HfApi, hf_hub_download
 from torch import nn
 from transformers import AutoModelForCausalLM, PretrainedConfig
+from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
                          LoRAConfig, ModelConfig, MultiModalConfig,
@@ -241,12 +242,17 @@ def _prepare_weights(self, model_name_or_path: str,
         is_local = os.path.isdir(model_name_or_path)
         load_format = self.load_config.load_format
         use_safetensors = False
+        index_file = SAFE_WEIGHTS_INDEX_NAME
         # Some quantized models use .pt files for storing the weights.
         if load_format == LoadFormat.AUTO:
             allow_patterns = ["*.safetensors", "*.bin"]
         elif load_format == LoadFormat.SAFETENSORS:
             use_safetensors = True
             allow_patterns = ["*.safetensors"]
+        elif load_format == LoadFormat.MISTRAL:
+            use_safetensors = True
+            allow_patterns = ["consolidated*.safetensors"]
+            index_file = "consolidated.safetensors.index.json"
         elif load_format == LoadFormat.PT:
             allow_patterns = ["*.pt"]
         elif load_format == LoadFormat.NPCACHE:
@@ -284,10 +290,10 @@ def _prepare_weights(self, model_name_or_path: str,
                 # any files not found in the index.
                 if not is_local:
                     download_safetensors_index_file_from_hf(
-                        model_name_or_path, self.load_config.download_dir,
-                        revision)
+                        model_name_or_path, index_file,
+                        self.load_config.download_dir, revision)
                 hf_weights_files = filter_duplicate_safetensors_files(
-                    hf_weights_files, hf_folder)
+                    hf_weights_files, hf_folder, index_file)
             else:
                 hf_weights_files = filter_files_not_needed_for_inference(
                     hf_weights_files)
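
To make the new branch concrete, here is a small illustrative sketch (file names are made up) of which files the SAFETENSORS and the new MISTRAL load formats end up considering in _prepare_weights.

import fnmatch

# Hypothetical contents of a checkpoint repo that ships both sharded HF-style
# weights and a consolidated Mistral-style checkpoint.
files = [
    "model-00001-of-00002.safetensors",
    "model-00002-of-00002.safetensors",
    "model.safetensors.index.json",
    "consolidated.safetensors",
    "consolidated.safetensors.index.json",
]

# LoadFormat.SAFETENSORS: allow_patterns = ["*.safetensors"]
print(fnmatch.filter(files, "*.safetensors"))
# -> both sharded files and the consolidated file

# LoadFormat.MISTRAL: allow_patterns = ["consolidated*.safetensors"], and the
# index lookup switches to "consolidated.safetensors.index.json".
print(fnmatch.filter(files, "consolidated*.safetensors"))
# -> ['consolidated.safetensors']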

vllm/model_executor/model_loader/weight_utils.py (+11 −10)

@@ -16,7 +16,6 @@
 from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
 from safetensors.torch import load_file, safe_open, save_file
 from tqdm.auto import tqdm
-from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

 from vllm.config import LoadConfig, ModelConfig
 from vllm.distributed import get_tensor_model_parallel_rank
@@ -251,6 +250,7 @@ def download_weights_from_hf(

 def download_safetensors_index_file_from_hf(
     model_name_or_path: str,
+    index_file: str,
     cache_dir: Optional[str],
     revision: Optional[str] = None,
 ) -> None:
@@ -269,36 +269,37 @@ def download_safetensors_index_file_from_hf(
         # Download the safetensors index file.
         hf_hub_download(
             repo_id=model_name_or_path,
-            filename=SAFE_WEIGHTS_INDEX_NAME,
+            filename=index_file,
             cache_dir=cache_dir,
             revision=revision,
             local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
         )
         # If file not found on remote or locally, we should not fail since
-        # only some models will have SAFE_WEIGHTS_INDEX_NAME.
+        # only some models will have index_file.
     except huggingface_hub.utils.EntryNotFoundError:
-        logger.info("No %s found in remote.", SAFE_WEIGHTS_INDEX_NAME)
+        logger.info("No %s found in remote.", index_file)
     except huggingface_hub.utils.LocalEntryNotFoundError:
-        logger.info("No %s found in local cache.", SAFE_WEIGHTS_INDEX_NAME)
+        logger.info("No %s found in local cache.", index_file)


 # For models like Mistral-7B-v0.3, there are both sharded
 # safetensors files and a consolidated safetensors file.
 # Passing both of these to the weight loader functionality breaks.
-# So, we use the SAFE_WEIGHTS_INDEX_NAME to
+# So, we use the index_file to
 # look up which safetensors files should be used.
 def filter_duplicate_safetensors_files(hf_weights_files: List[str],
-                                       hf_folder: str) -> List[str]:
+                                       hf_folder: str,
+                                       index_file: str) -> List[str]:
     # model.safetensors.index.json is a mapping from keys in the
     # torch state_dict to safetensors file holding that weight.
-    index_file_name = os.path.join(hf_folder, SAFE_WEIGHTS_INDEX_NAME)
+    index_file_name = os.path.join(hf_folder, index_file)
     if not os.path.isfile(index_file_name):
         return hf_weights_files

     # Iterate through the weight_map (weight_name: safetensors files)
     # to identify weights that we should use.
-    with open(index_file_name) as index_file:
-        weight_map = json.load(index_file)["weight_map"]
+    with open(index_file_name, "r") as f:
+        weight_map = json.load(f)["weight_map"]
     weight_files_in_index = set()
     for weight_name in weight_map:
         weight_files_in_index.add(
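
The toy example below (every file name and weight key is made up for illustration) shows the dedup behaviour end to end: with the Mistral index file passed in, only the weight files referenced by its weight_map survive filter_duplicate_safetensors_files.

import json
import os
import tempfile

from vllm.model_executor.model_loader.weight_utils import (
    filter_duplicate_safetensors_files)

with tempfile.TemporaryDirectory() as folder:
    # A fake checkpoint folder with both sharded and consolidated weights.
    for fname in ("model-00001-of-00002.safetensors",
                  "model-00002-of-00002.safetensors",
                  "consolidated.safetensors"):
        open(os.path.join(folder, fname), "w").close()

    # A minimal consolidated index whose weight_map references only the
    # consolidated file (the key name is arbitrary).
    index = {"weight_map": {"tok_embeddings.weight": "consolidated.safetensors"}}
    with open(os.path.join(folder, "consolidated.safetensors.index.json"),
              "w") as f:
        json.dump(index, f)

    all_files = [
        os.path.join(folder, name) for name in os.listdir(folder)
        if name.endswith(".safetensors")
    ]
    kept = filter_duplicate_safetensors_files(
        all_files, folder, "consolidated.safetensors.index.json")
    print(sorted(os.path.basename(p) for p in kept))
    # -> ['consolidated.safetensors']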

vllm/model_executor/models/llama.py (+51)

@@ -375,6 +375,25 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
         "gate_proj": ("gate_up_proj", 0),
         "up_proj": ("gate_up_proj", 1),
     }
+    # Mistral/Llama models can also be loaded with --load-format mistral
+    # from consolidated.safetensors checkpoints
+    mistral_mapping = {
+        "layers": "model.layers",
+        "attention": "self_attn",
+        "wq": "q_proj",
+        "wk": "k_proj",
+        "wv": "v_proj",
+        "wo": "o_proj",
+        "attention_norm": "input_layernorm",
+        "feed_forward": "mlp",
+        "w1": "gate_proj",
+        "w2": "down_proj",
+        "w3": "up_proj",
+        "ffn_norm": "post_attention_layernorm",
+        "tok_embeddings": "model.embed_tokens",
+        "output": "lm_head",
+        "norm": "model.norm"
+    }

     def __init__(
         self,
@@ -472,6 +491,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         ]
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
+            name, loaded_weight = self.maybe_remap_mistral(name, loaded_weight)
+
             if "rotary_emb.inv_freq" in name:
                 continue
             if ("rotary_emb.cos_cached" in name
@@ -549,3 +570,33 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
             else:
                 raise RuntimeError("Self attention has no KV cache scaling "
                                    "factor attribute!")
+
+    # This function is used to remap the mistral format as
+    # used by Mistral and Llama <=2
+    def maybe_remap_mistral(
+            self, name: str,
+            loaded_weight: torch.Tensor) -> Tuple[str, torch.Tensor]:
+
+        def permute(w, n_heads):
+            attn_in = self.config.head_dim * n_heads
+            attn_out = self.config.hidden_size
+
+            return w.view(n_heads, attn_in // n_heads // 2, 2,
+                          attn_out).transpose(1, 2).reshape(attn_in, attn_out)
+
+        mapping = self.mistral_mapping
+        modules = name.split(".")
+
+        # rotary embeds should be sliced
+        if "wk" in modules:
+            loaded_weight = permute(loaded_weight,
+                                    self.config.num_key_value_heads)
+        elif "wq" in modules:
+            loaded_weight = permute(loaded_weight,
+                                    self.config.num_attention_heads)
+
+        for item in modules:
+            if item in mapping and mapping[item] not in name:
+                name = name.replace(item, mapping[item])
+
+        return name, loaded_weight
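
To see what the mapping buys, here is a standalone, hedged sketch of just the renaming loop in maybe_remap_mistral (the wq/wk weight permutation is left out); the input names are typical consolidated-checkpoint keys written here only for illustration.

# Same table as LlamaForCausalLM.mistral_mapping above, copied so the demo is
# self-contained.
MISTRAL_MAPPING = {
    "layers": "model.layers",
    "attention": "self_attn",
    "wq": "q_proj",
    "wk": "k_proj",
    "wv": "v_proj",
    "wo": "o_proj",
    "attention_norm": "input_layernorm",
    "feed_forward": "mlp",
    "w1": "gate_proj",
    "w2": "down_proj",
    "w3": "up_proj",
    "ffn_norm": "post_attention_layernorm",
    "tok_embeddings": "model.embed_tokens",
    "output": "lm_head",
    "norm": "model.norm",
}


def remap_name(name: str) -> str:
    # Mirror of the loop at the end of maybe_remap_mistral: replace each dotted
    # component that has a mapping, unless the target already appears in name.
    for item in name.split("."):
        if item in MISTRAL_MAPPING and MISTRAL_MAPPING[item] not in name:
            name = name.replace(item, MISTRAL_MAPPING[item])
    return name


print(remap_name("layers.0.attention.wq.weight"))
# -> model.layers.0.self_attn.q_proj.weight
print(remap_name("layers.0.feed_forward.w1.weight"))
# -> model.layers.0.mlp.gate_proj.weight
print(remap_name("tok_embeddings.weight"))
# -> model.embed_tokens.weight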
