Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 7 additions & 11 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,16 +534,12 @@ def __post_init__(self) -> None:
self.config_format = ConfigFormat(self.config_format)

hf_config = get_config(self.hf_config_path or self.model,
self.trust_remote_code, self.revision,
self.code_revision, self.config_format)

if hf_overrides_kw:
logger.debug("Overriding HF config with %s", hf_overrides_kw)
hf_config.update(hf_overrides_kw)
if hf_overrides_fn:
logger.debug("Overriding HF config with %s", hf_overrides_fn)
hf_config = hf_overrides_fn(hf_config)

self.trust_remote_code,
self.revision,
self.code_revision,
self.config_format,
hf_overrides_kw=hf_overrides_kw,
hf_overrides_fn=hf_overrides_fn)
self.hf_config = hf_config

self.hf_text_config = get_hf_text_config(self.hf_config)
Expand Down Expand Up @@ -4988,4 +4984,4 @@ class SpeechToTextConfig:

@property
def allow_audio_chunking(self) -> bool:
return self.min_energy_split_window_size is not None
return self.min_energy_split_window_size is not None
10 changes: 10 additions & 0 deletions vllm/transformers_utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,9 @@ def get_config(
revision: Optional[str] = None,
code_revision: Optional[str] = None,
config_format: ConfigFormat = ConfigFormat.AUTO,
hf_overrides_kw: Optional[dict[str, Any]] = None,
hf_overrides_fn: Optional[Callable[[PretrainedConfig],
PretrainedConfig]] = None,
**kwargs,
) -> PretrainedConfig:
# Separate model folder from file path for GGUF models
Expand Down Expand Up @@ -423,6 +426,13 @@ def get_config(
model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
config.update({"architectures": [model_type]})

if hf_overrides_kw:
logger.debug("Overriding HF config with %s", hf_overrides_kw)
config.update(hf_overrides_kw)
if hf_overrides_fn:
logger.debug("Overriding HF config with %s", hf_overrides_fn)
config = hf_overrides_fn(config)

patch_rope_scaling(config)

if trust_remote_code:
Expand Down