Skip to content

Commit baa5467

Browse files
authored
[BugFix] Fix Granite model configuration (vllm-project#8216)
1 parent db3bf7c commit baa5467

File tree

2 files changed

+42
-24
lines changed

2 files changed

+42
-24
lines changed

vllm/transformers_utils/config.py

+38-24
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,16 @@
1010

1111
from vllm.envs import VLLM_USE_MODELSCOPE
1212
from vllm.logger import init_logger
13+
# yapf conflicts with isort for this block
14+
# yapf: disable
1315
from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
1416
EAGLEConfig, ExaoneConfig,
15-
InternVLChatConfig, JAISConfig,
16-
MedusaConfig, MLPSpeculatorConfig,
17-
MPTConfig, NemotronConfig,
18-
RWConfig, UltravoxConfig)
17+
GraniteConfig, InternVLChatConfig,
18+
JAISConfig, MedusaConfig,
19+
MLPSpeculatorConfig, MPTConfig,
20+
NemotronConfig, RWConfig,
21+
UltravoxConfig)
22+
# yapf: enable
1923
from vllm.transformers_utils.utils import check_gguf_file
2024

2125
if VLLM_USE_MODELSCOPE:
@@ -39,6 +43,9 @@
3943
"internvl_chat": InternVLChatConfig,
4044
"nemotron": NemotronConfig,
4145
"ultravox": UltravoxConfig,
46+
# Granite can be removed from here once we have upgraded to
47+
# transformers 4.45+
48+
"granite": GraniteConfig,
4249
}
4350

4451
for name, cls in _CONFIG_REGISTRY.items():
@@ -62,29 +69,36 @@ def get_config(
6269
kwargs["gguf_file"] = Path(model).name
6370
model = Path(model).parent
6471

65-
try:
66-
config = AutoConfig.from_pretrained(
67-
model,
68-
trust_remote_code=trust_remote_code,
69-
revision=revision,
70-
code_revision=code_revision,
71-
**kwargs)
72-
except ValueError as e:
73-
if (not trust_remote_code and
74-
"requires you to execute the configuration file" in str(e)):
75-
err_msg = (
76-
"Failed to load the model config. If the model is a custom "
77-
"model not yet available in the HuggingFace transformers "
78-
"library, consider setting `trust_remote_code=True` in LLM "
79-
"or using the `--trust-remote-code` flag in the CLI.")
80-
raise RuntimeError(err_msg) from e
81-
else:
82-
raise e
83-
if config.model_type in _CONFIG_REGISTRY:
84-
config_class = _CONFIG_REGISTRY[config.model_type]
72+
config_dict, _ = PretrainedConfig.get_config_dict(
73+
model, revision=revision, code_revision=code_revision, **kwargs)
74+
75+
# Use custom model class if it's in our registry
76+
model_type = config_dict.get("model_type")
77+
if model_type in _CONFIG_REGISTRY:
78+
config_class = _CONFIG_REGISTRY[model_type]
8579
config = config_class.from_pretrained(model,
8680
revision=revision,
8781
code_revision=code_revision)
82+
else:
83+
try:
84+
config = AutoConfig.from_pretrained(
85+
model,
86+
trust_remote_code=trust_remote_code,
87+
revision=revision,
88+
code_revision=code_revision,
89+
**kwargs)
90+
except ValueError as e:
91+
if (not trust_remote_code
92+
and "requires you to execute the configuration file"
93+
in str(e)):
94+
err_msg = (
95+
"Failed to load the model config. If the model is a custom "
96+
"model not yet available in the HuggingFace transformers "
97+
"library, consider setting `trust_remote_code=True` in LLM "
98+
"or using the `--trust-remote-code` flag in the CLI.")
99+
raise RuntimeError(err_msg) from e
100+
else:
101+
raise e
88102

89103
# Special architecture mapping check for GGUF models
90104
if is_gguf:

vllm/transformers_utils/configs/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
77
# `FalconConfig` class from the official HuggingFace transformers library.
88
from vllm.transformers_utils.configs.falcon import RWConfig
9+
from vllm.transformers_utils.configs.granite import GraniteConfig
910
from vllm.transformers_utils.configs.internvl import InternVLChatConfig
1011
from vllm.transformers_utils.configs.jais import JAISConfig
1112
from vllm.transformers_utils.configs.medusa import MedusaConfig
@@ -27,4 +28,7 @@
2728
"MLPSpeculatorConfig",
2829
"NemotronConfig",
2930
"UltravoxConfig",
31+
# Granite can be removed from here once we have upgraded to
32+
# transformers 4.45+
33+
"GraniteConfig",
3034
]

0 commit comments

Comments
 (0)