[Model] Allow loading from original Mistral format #8168

Merged
33 commits merged on Sep 6, 2024
Commits (33)
a36ff88  WIP  (patrickvonplaten, Sep 4, 2024)
a5472c5  WIP  (patrickvonplaten, Sep 4, 2024)
ab85a2e  up  (patrickvonplaten, Sep 4, 2024)
e0264e0  up  (patrickvonplaten, Sep 4, 2024)
0c1b115  up  (patrickvonplaten, Sep 4, 2024)
fa44979  up  (patrickvonplaten, Sep 4, 2024)
2c0bc8c  up  (patrickvonplaten, Sep 4, 2024)
e33e127  Up  (patrickvonplaten, Sep 4, 2024)
ffdcdaf  WIP  (patrickvonplaten, Sep 4, 2024)
f13b059  up  (patrickvonplaten, Sep 4, 2024)
f4557bb  Merge branch 'main' into add_mistral_model_format  (patrickvonplaten, Sep 4, 2024)
0fbf430  up  (patrickvonplaten, Sep 4, 2024)
8eaac25  up  (patrickvonplaten, Sep 5, 2024)
a8091a6  up  (patrickvonplaten, Sep 5, 2024)
b9ad975  Up  (patrickvonplaten, Sep 5, 2024)
928e47a  Up  (patrickvonplaten, Sep 5, 2024)
990ee19  WIP  (patrickvonplaten, Sep 5, 2024)
4a1044b  WIP  (patrickvonplaten, Sep 5, 2024)
b3814d5  up  (patrickvonplaten, Sep 5, 2024)
bbfb843  up  (patrickvonplaten, Sep 5, 2024)
302397f  up  (patrickvonplaten, Sep 5, 2024)
1fc792e  Up  (patrickvonplaten, Sep 5, 2024)
0428eaa  up  (patrickvonplaten, Sep 5, 2024)
6efb51c  Update vllm/config.py  (patrickvonplaten, Sep 5, 2024)
5f05f1e  Merge branch 'main' into add_mistral_model_format  (patrickvonplaten, Sep 6, 2024)
402d77c  Up  (patrickvonplaten, Sep 6, 2024)
a79832d  Up  (patrickvonplaten, Sep 6, 2024)
d46c359  up  (patrickvonplaten, Sep 6, 2024)
5f41771  WIP  (patrickvonplaten, Sep 6, 2024)
516b84a  up  (patrickvonplaten, Sep 6, 2024)
43d051c  WIP  (patrickvonplaten, Sep 6, 2024)
c947a16  Merge branch 'add_mistral_model_format' of https://github.com/patrick…  (patrickvonplaten, Sep 6, 2024)
7679950  up  (patrickvonplaten, Sep 6, 2024)
37 changes: 37 additions & 0 deletions tests/models/test_mistral.py
@@ -46,3 +46,40 @@ def test_models(
name_0="hf",
name_1="vllm",
)


@pytest.mark.parametrize("model", MODELS[1:])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_mistral_format(
Comment from patrickvonplaten (Contributor, PR author):
New test that checks that loading with tokenizer_mode, load_format, and config_format all set to "mistral" produces outputs matching the HF format.

    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    with vllm_runner(model,
                     dtype=dtype,
                     tokenizer_mode="auto",
                     load_format="safetensors",
                     config_format="hf") as hf_format_model:
        hf_format_outputs = hf_format_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    with vllm_runner(model,
                     dtype=dtype,
                     tokenizer_mode="mistral",
                     load_format="mistral",
                     config_format="mistral") as mistral_format_model:
        mistral_format_outputs = mistral_format_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    check_logprobs_close(
        outputs_0_lst=hf_format_outputs,
        outputs_1_lst=mistral_format_outputs,
        name_0="hf",
        name_1="mistral",
    )
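For context, a minimal offline-inference sketch of what this test exercises through the public API; the model name is only an illustrative assumption, and load_format/config_format are forwarded to the engine arguments as keyword arguments:

from vllm import LLM, SamplingParams

# Hypothetical example repo; any model shipped in the original Mistral
# format (params.json plus consolidated weights) would do.
llm = LLM(
    model="mistralai/Mistral-7B-Instruct-v0.3",
    tokenizer_mode="mistral",
    load_format="mistral",
    config_format="mistral",
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)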
10 changes: 5 additions & 5 deletions vllm/config.py
@@ -15,7 +15,7 @@
from vllm.tracing import is_otel_available, otel_import_error_traceback
from vllm.transformers_utils.config import (get_config,
get_hf_image_processor_config,
get_hf_text_config)
get_hf_text_config, ConfigFormat)
from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, GiB_bytes,
cuda_device_count_stateless, get_cpu_memory, is_cpu,
is_hip, is_neuron, is_openvino, is_xpu,
@@ -119,8 +119,8 @@ class ModelConfig:
override default neuron config that are specific to Neuron devices,
this argument will be used to configure the neuron config that
can not be gathered from the vllm arguments.
load_params_config: Load the config from mistral format
(params.json) instead of config.json.
config_format: The config format which shall be loaded.
Defaults to 'auto' which defaults to 'hf'.
patrickvonplaten marked this conversation as resolved.
"""

def __init__(self,
@@ -149,7 +149,7 @@ def __init__(self,
limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
use_async_output_proc: bool = True,
override_neuron_config: Optional[Dict[str, Any]] = None,
load_params_config: bool = False) -> None:
config_format: ConfigFormat = ConfigFormat.AUTO) -> None:
self.model = model
self.tokenizer = tokenizer
self.tokenizer_mode = tokenizer_mode
@@ -177,7 +177,7 @@ def __init__(self,

self.hf_config = get_config(self.model, trust_remote_code, revision,
code_revision, rope_scaling, rope_theta,
load_params_config)
config_format)
self.hf_text_config = get_hf_text_config(self.hf_config)
self.hf_image_processor_config = get_hf_image_processor_config(
self.model, revision)
12 changes: 10 additions & 2 deletions vllm/engine/arg_utils.py
@@ -12,7 +12,7 @@
EngineConfig, LoadConfig, LoadFormat, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig, TokenizerPoolConfig)
SpeculativeConfig, TokenizerPoolConfig, ConfigFormat)
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
@@ -65,6 +65,7 @@ class EngineArgs:
trust_remote_code: bool = False
download_dir: Optional[str] = None
load_format: str = 'auto'
config_format: str = 'auto'
dtype: str = 'auto'
kv_cache_dtype: str = 'auto'
quantization_param_path: Optional[str] = None
@@ -234,6 +235,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'section for more information.\n'
'* "bitsandbytes" will load the weights using bitsandbytes '
'quantization.\n')
parser.add_argument(
'--config-format',
default=EngineArgs.config_format,
choices=[f.value for f in ConfigFormat],
help='The format of the model config to load.\n\n'
'* "auto" will try to load the config in hf format '
'if available else it will try to load in mistral format ')
parser.add_argument(
'--dtype',
type=str,
@@ -814,7 +822,7 @@ def create_engine_config(self) -> EngineConfig:
limit_mm_per_prompt=self.limit_mm_per_prompt,
use_async_output_proc=not self.disable_async_output_proc,
override_neuron_config=self.override_neuron_config,
load_params_config=self.load_format == "mistral",
config_format=self.config_format,
)

cache_config = CacheConfig(
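A small sketch (not part of the diff) of how the new --config-format flag could be exercised via the CLI-argument path; the model id is a placeholder assumption, and EngineArgs.from_cli_args / FlexibleArgumentParser are the standard vLLM helpers assumed to be available here:

from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser

# Build the engine CLI parser, parse the new flag, and turn the result into
# EngineArgs. "--config-format mistral" forces the params.json code path;
# the default "auto" probes for config.json first.
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args([
    "--model", "mistralai/Mistral-7B-Instruct-v0.3",  # placeholder model id
    "--tokenizer-mode", "mistral",
    "--load-format", "mistral",
    "--config-format", "mistral",
])
engine_args = EngineArgs.from_cli_args(args)
engine_config = engine_args.create_engine_config()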
26 changes: 22 additions & 4 deletions vllm/transformers_utils/config.py
@@ -1,9 +1,12 @@
import contextlib
import enum
import json
from pathlib import Path
from huggingface_hub import file_exists
from typing import Any, Dict, Optional, Type, Union

from huggingface_hub import hf_hub_download
from transformers import GenerationConfig, PretrainedConfig
from transformers.models.auto.image_processing_auto import (
get_image_processor_config)
@@ -48,14 +51,21 @@
AutoConfig.register(name, cls)


class ConfigFormat(str, enum.Enum):
    AUTO = "auto"
    HF = "hf"
    MISTRAL = "mistral"



def get_config(
model: Union[str, Path],
trust_remote_code: bool,
revision: Optional[str] = None,
code_revision: Optional[str] = None,
rope_scaling: Optional[dict] = None,
rope_theta: Optional[float] = None,
load_from_params: bool = False,
config_format: ConfigFormat = ConfigFormat.AUTO,
**kwargs,
) -> PretrainedConfig:
# Separate model folder from file path for GGUF models
@@ -65,16 +75,24 @@ def get_config(
kwargs["gguf_file"] = Path(model).name
model = Path(model).parent

try:
if load_from_params:
config = load_params_config(model, revision)
if config_format == ConfigFormat.AUTO:
if file_exists(model, "config.json", revision=revision, token=kwargs.get("token")):
config_format = ConfigFormat.HF
else:
config_format = ConfigFormat.MISTRAL
mgoin marked this conversation as resolved.

try:
if config_format == ConfigFormat.HF:
config = AutoConfig.from_pretrained(
model,
trust_remote_code=trust_remote_code,
revision=revision,
code_revision=code_revision,
**kwargs)
elif config_format == ConfigFormat.MISTRAL:
config = load_params_config(model, revision)
else:
raise ValueError(f"Unsupported config format: {config_format}")
except ValueError as e:
if (not trust_remote_code and
"requires you to execute the configuration file" in str(e)):
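Below, an illustrative call (not part of the diff) of the updated helper; the repo id is an assumption. With ConfigFormat.AUTO the helper picks HF when a config.json exists in the repo and otherwise falls back to the Mistral params.json format; passing ConfigFormat.MISTRAL forces that path explicitly:

from vllm.transformers_utils.config import ConfigFormat, get_config

# Force the Mistral (params.json) code path. The repo id below is only an
# example; any repo shipping a params.json should work the same way.
config = get_config(
    "mistralai/Mistral-7B-Instruct-v0.3",
    trust_remote_code=False,
    config_format=ConfigFormat.MISTRAL,
)
print(config)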