[Model] Allow loading from original Mistral format #8168

Merged · 33 commits · Sep 6, 2024

Commits (all by patrickvonplaten):

a36ff88  WIP (Sep 4, 2024)
a5472c5  WIP (Sep 4, 2024)
ab85a2e  up (Sep 4, 2024)
e0264e0  up (Sep 4, 2024)
0c1b115  up (Sep 4, 2024)
fa44979  up (Sep 4, 2024)
2c0bc8c  up (Sep 4, 2024)
e33e127  Up (Sep 4, 2024)
ffdcdaf  WIP (Sep 4, 2024)
f13b059  up (Sep 4, 2024)
f4557bb  Merge branch 'main' into add_mistral_model_format (Sep 4, 2024)
0fbf430  up (Sep 4, 2024)
8eaac25  up (Sep 5, 2024)
a8091a6  up (Sep 5, 2024)
b9ad975  Up (Sep 5, 2024)
928e47a  Up (Sep 5, 2024)
990ee19  WIP (Sep 5, 2024)
4a1044b  WIP (Sep 5, 2024)
b3814d5  up (Sep 5, 2024)
bbfb843  up (Sep 5, 2024)
302397f  up (Sep 5, 2024)
1fc792e  Up (Sep 5, 2024)
0428eaa  up (Sep 5, 2024)
6efb51c  Update vllm/config.py (Sep 5, 2024)
5f05f1e  Merge branch 'main' into add_mistral_model_format (Sep 6, 2024)
402d77c  Up (Sep 6, 2024)
a79832d  Up (Sep 6, 2024)
d46c359  up (Sep 6, 2024)
5f41771  WIP (Sep 6, 2024)
516b84a  up (Sep 6, 2024)
43d051c  WIP (Sep 6, 2024)
c947a16  Merge branch 'add_mistral_model_format' of https://github.com/patrick… (Sep 6, 2024)
7679950  up (Sep 6, 2024)
40 changes: 40 additions & 0 deletions tests/models/test_mistral.py
@@ -41,3 +41,43 @@ def test_models(
name_0="hf",
name_1="vllm",
)


@pytest.mark.parametrize("model", MODELS[1:])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_mistral_format(
[Inline review comment from patrickvonplaten (Contributor, Author): new test checking that outputs produced with tokenizer_mode, load_format, and config_format all set to "mistral" match the outputs produced with the HF format.]

vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
with vllm_runner(
model,
dtype=dtype,
tokenizer_mode="auto",
load_format="safetensors",
config_format="hf",
) as hf_format_model:
hf_format_outputs = hf_format_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)

with vllm_runner(
model,
dtype=dtype,
tokenizer_mode="mistral",
load_format="mistral",
config_format="mistral",
) as mistral_format_model:
mistral_format_outputs = mistral_format_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)

check_logprobs_close(
outputs_0_lst=hf_format_outputs,
outputs_1_lst=mistral_format_outputs,
name_0="hf",
name_1="mistral",
)
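
The same three options can also be exercised from vLLM's offline API. A minimal sketch, assuming the LLM constructor forwards these keyword arguments to EngineArgs; the model id is only an illustration:

```python
# Sketch: load a Mistral model from its original consolidated.safetensors
# checkpoint rather than the HF-converted weights. Assumes LLM forwards
# tokenizer_mode / load_format / config_format on to EngineArgs.
from vllm import LLM, SamplingParams

llm = LLM(
    model="mistralai/Mistral-7B-Instruct-v0.3",  # illustrative model id
    tokenizer_mode="mistral",   # use the Mistral tokenizer files
    load_format="mistral",      # read consolidated*.safetensors weights
    config_format="mistral",    # read the config in Mistral format
)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```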
62 changes: 33 additions & 29 deletions vllm/config.py
@@ -13,7 +13,7 @@
from vllm.model_executor.models import ModelRegistry
from vllm.platforms import current_platform
from vllm.tracing import is_otel_available, otel_import_error_traceback
from vllm.transformers_utils.config import (get_config,
from vllm.transformers_utils.config import (ConfigFormat, get_config,
get_hf_image_processor_config,
get_hf_text_config)
from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, GiB_bytes,
@@ -121,35 +121,37 @@ class ModelConfig:
override default neuron config that are specific to Neuron devices,
this argument will be used to configure the neuron config that
can not be gathered from the vllm arguments.
config_format: The config format which shall be loaded.
Defaults to 'auto' which defaults to 'hf'.
"""

def __init__(
self,
model: str,
tokenizer: str,
tokenizer_mode: str,
trust_remote_code: bool,
dtype: Union[str, torch.dtype],
seed: int,
revision: Optional[str] = None,
code_revision: Optional[str] = None,
rope_scaling: Optional[dict] = None,
rope_theta: Optional[float] = None,
tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None,
spec_target_max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
quantization_param_path: Optional[str] = None,
enforce_eager: Optional[bool] = None,
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 20,
disable_sliding_window: bool = False,
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None,
limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
use_async_output_proc: bool = True,
override_neuron_config: Optional[Dict[str, Any]] = None) -> None:
def __init__(self,
model: str,
tokenizer: str,
tokenizer_mode: str,
trust_remote_code: bool,
dtype: Union[str, torch.dtype],
seed: int,
revision: Optional[str] = None,
code_revision: Optional[str] = None,
rope_scaling: Optional[dict] = None,
rope_theta: Optional[float] = None,
tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None,
spec_target_max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
quantization_param_path: Optional[str] = None,
enforce_eager: Optional[bool] = None,
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 20,
disable_sliding_window: bool = False,
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None,
limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
use_async_output_proc: bool = True,
override_neuron_config: Optional[Dict[str, Any]] = None,
config_format: ConfigFormat = ConfigFormat.AUTO) -> None:
self.model = model
self.tokenizer = tokenizer
self.tokenizer_mode = tokenizer_mode
@@ -176,7 +178,8 @@ def __init__(
self.skip_tokenizer_init = skip_tokenizer_init

self.hf_config = get_config(self.model, trust_remote_code, revision,
code_revision, rope_scaling, rope_theta)
code_revision, rope_scaling, rope_theta,
config_format)
self.hf_text_config = get_hf_text_config(self.hf_config)
self.hf_image_processor_config = get_hf_image_processor_config(
self.model, revision)
@@ -746,6 +749,7 @@ class LoadFormat(str, enum.Enum):
SHARDED_STATE = "sharded_state"
GGUF = "gguf"
BITSANDBYTES = "bitsandbytes"
MISTRAL = "mistral"


@dataclass
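
The import above references ConfigFormat from vllm.transformers_utils.config; its definition is not part of this diff. Judging from the LoadFormat enum above and the --config-format choices added in arg_utils.py below, it is presumably a small string enum along these lines (a sketch, not the actual definition):

```python
# Sketch of what ConfigFormat presumably looks like, mirroring LoadFormat above.
# The real definition lives in vllm/transformers_utils/config.py.
import enum


class ConfigFormat(str, enum.Enum):
    AUTO = "auto"        # try the HF config first, fall back to Mistral format
    HF = "hf"
    MISTRAL = "mistral"
```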
21 changes: 16 additions & 5 deletions vllm/engine/arg_utils.py
@@ -8,10 +8,10 @@
import torch

import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoadFormat, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig,
DeviceConfig, EngineConfig, LoadConfig, LoadFormat,
LoRAConfig, ModelConfig, ObservabilityConfig,
ParallelConfig, PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig, TokenizerPoolConfig)
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
@@ -65,6 +65,7 @@ class EngineArgs:
trust_remote_code: bool = False
download_dir: Optional[str] = None
load_format: str = 'auto'
config_format: str = 'auto'
dtype: str = 'auto'
kv_cache_dtype: str = 'auto'
quantization_param_path: Optional[str] = None
@@ -234,6 +235,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'section for more information.\n'
'* "bitsandbytes" will load the weights using bitsandbytes '
'quantization.\n')
parser.add_argument(
'--config-format',
default=EngineArgs.config_format,
choices=[f.value for f in ConfigFormat],
help='The format of the model config to load.\n\n'
'* "auto" will try to load the config in hf format '
'if available else it will try to load in mistral format ')
parser.add_argument(
'--dtype',
type=str,
@@ -813,7 +821,10 @@ def create_engine_config(self) -> EngineConfig:
served_model_name=self.served_model_name,
limit_mm_per_prompt=self.limit_mm_per_prompt,
use_async_output_proc=not self.disable_async_output_proc,
override_neuron_config=self.override_neuron_config)
override_neuron_config=self.override_neuron_config,
config_format=self.config_format,
)

cache_config = CacheConfig(
block_size=self.block_size if self.device != "neuron" else
self.max_model_len, # neuron needs block_size = max_model_len
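
For completeness, a sketch of how the new engine argument flows through create_engine_config() into ModelConfig. Only arguments visible in this diff are assumed; the model id is illustrative, and building the config requires the model files to be reachable:

```python
# Sketch: --config-format becomes EngineArgs.config_format and is handed to
# ModelConfig by create_engine_config() (see the diff above).
from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(
    model="mistralai/Mistral-7B-Instruct-v0.3",  # illustrative model id
    tokenizer_mode="mistral",
    load_format="mistral",
    config_format="mistral",
)
engine_config = args.create_engine_config()
print(engine_config.model_config.hf_config)  # HF-style config built from Mistral format
```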
12 changes: 9 additions & 3 deletions vllm/model_executor/model_loader/loader.py
@@ -17,6 +17,7 @@
from huggingface_hub import HfApi, hf_hub_download
from torch import nn
from transformers import AutoModelForCausalLM, PretrainedConfig
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
LoRAConfig, ModelConfig, MultiModalConfig,
@@ -241,12 +242,17 @@ def _prepare_weights(self, model_name_or_path: str,
is_local = os.path.isdir(model_name_or_path)
load_format = self.load_config.load_format
use_safetensors = False
index_file = SAFE_WEIGHTS_INDEX_NAME
# Some quantized models use .pt files for storing the weights.
if load_format == LoadFormat.AUTO:
allow_patterns = ["*.safetensors", "*.bin"]
elif load_format == LoadFormat.SAFETENSORS:
use_safetensors = True
allow_patterns = ["*.safetensors"]
elif load_format == LoadFormat.MISTRAL:
use_safetensors = True
allow_patterns = ["consolidated*.safetensors"]
index_file = "consolidated.safetensors.index.json"
elif load_format == LoadFormat.PT:
allow_patterns = ["*.pt"]
elif load_format == LoadFormat.NPCACHE:
@@ -284,10 +290,10 @@ def _prepare_weights(self, model_name_or_path: str,
# any files not found in the index.
if not is_local:
download_safetensors_index_file_from_hf(
model_name_or_path, self.load_config.download_dir,
revision)
model_name_or_path, index_file,
self.load_config.download_dir, revision)
hf_weights_files = filter_duplicate_safetensors_files(
hf_weights_files, hf_folder)
hf_weights_files, hf_folder, index_file)
else:
hf_weights_files = filter_files_not_needed_for_inference(
hf_weights_files)
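
To make the new branch concrete, a small illustration of which files the mistral load format selects from a checkpoint directory. The file listing is hypothetical; the glob pattern and index file name are the ones added above:

```python
# Illustration only: the "mistral" load format matches consolidated*.safetensors
# and resolves duplicates via consolidated.safetensors.index.json.
import fnmatch

repo_files = [
    "model-00001-of-00003.safetensors",     # HF-style shards (not selected)
    "model-00002-of-00003.safetensors",
    "model.safetensors.index.json",
    "consolidated.safetensors",             # original Mistral checkpoint
    "consolidated.safetensors.index.json",  # Mistral index file
]

allow_patterns = ["consolidated*.safetensors"]
index_file = "consolidated.safetensors.index.json"

selected = [f for pattern in allow_patterns
            for f in repo_files if fnmatch.fnmatch(f, pattern)]
print(selected)                  # ['consolidated.safetensors']
print(index_file in repo_files)  # True: the index file is fetched separately
```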
21 changes: 11 additions & 10 deletions vllm/model_executor/model_loader/weight_utils.py
@@ -16,7 +16,6 @@
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
from safetensors.torch import load_file, safe_open, save_file
from tqdm.auto import tqdm
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

from vllm.config import LoadConfig, ModelConfig
from vllm.distributed import get_tensor_model_parallel_rank
@@ -251,6 +250,7 @@ def download_weights_from_hf(

def download_safetensors_index_file_from_hf(
model_name_or_path: str,
index_file: str,
cache_dir: Optional[str],
revision: Optional[str] = None,
) -> None:
@@ -269,36 +269,37 @@ def download_safetensors_index_file_from_hf(
# Download the safetensors index file.
hf_hub_download(
repo_id=model_name_or_path,
filename=SAFE_WEIGHTS_INDEX_NAME,
filename=index_file,
cache_dir=cache_dir,
revision=revision,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
)
# If file not found on remote or locally, we should not fail since
# only some models will have SAFE_WEIGHTS_INDEX_NAME.
# only some models will have index_file.
except huggingface_hub.utils.EntryNotFoundError:
logger.info("No %s found in remote.", SAFE_WEIGHTS_INDEX_NAME)
logger.info("No %s found in remote.", index_file)
except huggingface_hub.utils.LocalEntryNotFoundError:
logger.info("No %s found in local cache.", SAFE_WEIGHTS_INDEX_NAME)
logger.info("No %s found in local cache.", index_file)


# For models like Mistral-7B-v0.3, there are both sharded
# safetensors files and a consolidated safetensors file.
# Passing both of these to the weight loader functionality breaks.
# So, we use the SAFE_WEIGHTS_INDEX_NAME to
# So, we use the index_file to
# look up which safetensors files should be used.
def filter_duplicate_safetensors_files(hf_weights_files: List[str],
hf_folder: str) -> List[str]:
hf_folder: str,
index_file: str) -> List[str]:
# model.safetensors.index.json is a mapping from keys in the
# torch state_dict to safetensors file holding that weight.
index_file_name = os.path.join(hf_folder, SAFE_WEIGHTS_INDEX_NAME)
index_file_name = os.path.join(hf_folder, index_file)
if not os.path.isfile(index_file_name):
return hf_weights_files

# Iterate through the weight_map (weight_name: safetensors files)
# to identify weights that we should use.
with open(index_file_name) as index_file:
weight_map = json.load(index_file)["weight_map"]
with open(index_file_name, "r") as f:
weight_map = json.load(f)["weight_map"]
weight_files_in_index = set()
for weight_name in weight_map:
weight_files_in_index.add(
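
A short usage sketch of the generalized helper, showing how passing the Mistral index file name drives the filtering (directory contents and weight names are hypothetical, in Mistral format):

```python
# Sketch: filter_duplicate_safetensors_files() now takes the index file name,
# so the Mistral loader can pass "consolidated.safetensors.index.json" instead
# of the HF default. Only files referenced by the index are kept.
import json
import os
import tempfile

from vllm.model_executor.model_loader.weight_utils import (
    filter_duplicate_safetensors_files)

with tempfile.TemporaryDirectory() as hf_folder:
    index = {"weight_map": {
        "tok_embeddings.weight": "consolidated-00001-of-00002.safetensors",
        "output.weight": "consolidated-00002-of-00002.safetensors",
    }}
    with open(os.path.join(hf_folder, "consolidated.safetensors.index.json"),
              "w") as f:
        json.dump(index, f)

    files = [os.path.join(hf_folder, name) for name in [
        "consolidated-00001-of-00002.safetensors",
        "consolidated-00002-of-00002.safetensors",
        "consolidated.safetensors",  # stale duplicate, not in the index
    ]]

    kept = filter_duplicate_safetensors_files(
        files, hf_folder, "consolidated.safetensors.index.json")
    print(sorted(os.path.basename(p) for p in kept))
    # ['consolidated-00001-of-00002.safetensors',
    #  'consolidated-00002-of-00002.safetensors']
```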
51 changes: 51 additions & 0 deletions vllm/model_executor/models/llama.py
@@ -375,6 +375,25 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
# Mistral/Llama models can also be loaded with --load-format mistral
# from consolidated.safetensors checkpoints
mistral_mapping = {
"layers": "model.layers",
"attention": "self_attn",
"wq": "q_proj",
"wk": "k_proj",
"wv": "v_proj",
"wo": "o_proj",
"attention_norm": "input_layernorm",
"feed_forward": "mlp",
"w1": "gate_proj",
"w2": "down_proj",
"w3": "up_proj",
"ffn_norm": "post_attention_layernorm",
"tok_embeddings": "model.embed_tokens",
"output": "lm_head",
"norm": "model.norm"
}

def __init__(
self,
@@ -472,6 +491,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
name, loaded_weight = self.maybe_remap_mistral(name, loaded_weight)

if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
@@ -549,3 +570,33 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
else:
raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!")

# This function is used to remap the mistral format as
# used by Mistral and Llama <=2
def maybe_remap_mistral(
self, name: str,
loaded_weight: torch.Tensor) -> Tuple[str, torch.Tensor]:

def permute(w, n_heads):
attn_in = self.config.head_dim * n_heads
attn_out = self.config.hidden_size

return w.view(n_heads, attn_in // n_heads // 2, 2,
attn_out).transpose(1, 2).reshape(attn_in, attn_out)

mapping = self.mistral_mapping
modules = name.split(".")

# rotary embeds should be sliced
if "wk" in modules:
loaded_weight = permute(loaded_weight,
self.config.num_key_value_heads)
elif "wq" in modules:
loaded_weight = permute(loaded_weight,
self.config.num_attention_heads)

for item in modules:
if item in mapping and mapping[item] not in name:
name = name.replace(item, mapping[item])

return name, loaded_weight
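
To illustrate what maybe_remap_mistral does to checkpoint keys, here is a standalone sketch that applies the same substring mapping to a few example Mistral-format names; the wq/wk rotary permutation is left out:

```python
# Standalone illustration of the mistral_mapping remap above: Mistral-format
# parameter names are rewritten piece by piece into the HF/vLLM Llama names.
mistral_mapping = {
    "layers": "model.layers", "attention": "self_attn",
    "wq": "q_proj", "wk": "k_proj", "wv": "v_proj", "wo": "o_proj",
    "attention_norm": "input_layernorm", "feed_forward": "mlp",
    "w1": "gate_proj", "w2": "down_proj", "w3": "up_proj",
    "ffn_norm": "post_attention_layernorm",
    "tok_embeddings": "model.embed_tokens",
    "output": "lm_head", "norm": "model.norm",
}


def remap(name: str) -> str:
    # Same loop as maybe_remap_mistral, minus the weight permutation.
    for item in name.split("."):
        if item in mistral_mapping and mistral_mapping[item] not in name:
            name = name.replace(item, mistral_mapping[item])
    return name


print(remap("layers.0.attention.wq.weight"))     # model.layers.0.self_attn.q_proj.weight
print(remap("layers.0.feed_forward.w1.weight"))  # model.layers.0.mlp.gate_proj.weight
print(remap("tok_embeddings.weight"))            # model.embed_tokens.weight
```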