vllm/config/model.py (19 changes: 18 additions & 1 deletion)
@@ -595,7 +595,7 @@ def __post_init__(
 
         # Avoid running try_verify_and_update_config multiple times
         self.config_updated = False
-
+        self._try_verify_and_update_model_config()
         self._verify_quantization()
         self._verify_cuda_graph()
         self._verify_bnb_config()
@@ -1008,6 +1008,23 @@ def _verify_with_expert_parallelism(self) -> None:
                 "when expert parallelism is enabled."
             )
 
+    def _try_verify_and_update_model_config(self):
+        # Avoid running try_verify_and_update_config multiple times
+        if getattr(self, "config_updated", False):
+            return
+
+        architecture = self.architecture
+        if architecture is None:
+            return
+
+        from vllm.model_executor.models.config import (
+            MODELS_CONFIG_MAP,
+        )
+
+        cls = MODELS_CONFIG_MAP.get(architecture, None)
+        if cls is not None:
+            cls.verify_and_update_model_config(self)
+
     def verify_dual_chunk_attention_config(
         self,
         load_config: LoadConfig,
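
The new hook above is an architecture-keyed dispatch: ModelConfig looks up self.architecture in MODELS_CONFIG_MAP and, when a per-model config class is registered there, passes itself to that class's verify_and_update_model_config. Below is a minimal runnable sketch of the same pattern; DemoModelConfig and DemoConfigHook are hypothetical stand-ins, and only the MODELS_CONFIG_MAP and verify_and_update_model_config names mirror the diff.

# Hypothetical stand-ins illustrating the dispatch added in the hunk above.

class DemoConfigHook:
    @staticmethod
    def verify_and_update_model_config(model_config) -> None:
        # A per-architecture fixup, e.g. forcing bidirectional attention.
        model_config.hf_config["is_causal"] = False


# Architecture name -> config class, as in vllm/model_executor/models/config.py.
MODELS_CONFIG_MAP = {"DemoBidirectionalModel": DemoConfigHook}


class DemoModelConfig:
    def __init__(self, architecture, hf_config):
        self.architecture = architecture
        self.hf_config = hf_config
        self.config_updated = False
        self._try_verify_and_update_model_config()

    def _try_verify_and_update_model_config(self) -> None:
        # Mirrors the guard-then-dispatch structure from the diff.
        if getattr(self, "config_updated", False):
            return
        if self.architecture is None:
            return
        cls = MODELS_CONFIG_MAP.get(self.architecture, None)
        if cls is not None:
            cls.verify_and_update_model_config(self)


cfg = DemoModelConfig("DemoBidirectionalModel", {"is_causal": True})
assert cfg.hf_config["is_causal"] is False
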
vllm/model_executor/models/config.py (55 changes: 29 additions & 26 deletions)
@@ -13,28 +13,32 @@
 from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentionSpec
 
 if TYPE_CHECKING:
-    from vllm.config import VllmConfig
+    from vllm.config import ModelConfig, VllmConfig
 
 logger = init_logger(__name__)
 
 
 class VerifyAndUpdateConfig:
     @staticmethod
     def verify_and_update_config(vllm_config: "VllmConfig") -> None:
-        raise NotImplementedError
+        return
+
+    @staticmethod
+    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
+        return
 
 
-class Gemma3TextModelConfig:
+class Gemma3TextModelConfig(VerifyAndUpdateConfig):
     @staticmethod
-    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
-        hf_config = vllm_config.model_config.hf_config
+    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
+        hf_config = model_config.hf_config
         hf_config.is_causal = not hf_config.use_bidirectional_attention
 
 
 class GteNewModelConfig(VerifyAndUpdateConfig):
     @staticmethod
-    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
-        config = vllm_config.model_config.hf_config
+    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
+        config = model_config.hf_config
 
         assert config.__class__.__name__ == "NewConfig"
         assert config.hidden_act == "gelu"
@@ -53,16 +57,15 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None:
 
 class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig):
     @staticmethod
-    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
-        pooler_config = vllm_config.model_config.pooler_config
+    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
+        pooler_config = model_config.pooler_config
         if pooler_config.use_activation is None:
             pooler_config.use_activation = False
 
 
 class JinaRobertaModelConfig(VerifyAndUpdateConfig):
     @staticmethod
-    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
-        model_config = vllm_config.model_config
+    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
         config = model_config.hf_config
 
         if config.position_embedding_type == "rotary":
@@ -90,10 +93,10 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None:
 
 class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
     @staticmethod
-    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
+    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
         from vllm.config.pooler import PoolingTypeStr
 
-        hf_config = vllm_config.model_config.hf_config
+        hf_config = model_config.hf_config
         hf_config.is_causal = False
 
         pooling_type_map: dict[str, PoolingTypeStr] = {
@@ -105,7 +108,7 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None:
         pooling_type = pooling_type_map.get(hf_config.pooling, None)
         if pooling_type is None:
             raise ValueError(f"pool_type {hf_config.pooling} not supported")
-        vllm_config.model_config.pooler_config.pooling_type = pooling_type
+        model_config.pooler_config.pooling_type = pooling_type
 
 
 class NomicBertModelConfig(VerifyAndUpdateConfig):
@@ -204,26 +207,26 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None:
 
 class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig):
     @staticmethod
-    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
-        pooler_config = vllm_config.model_config.pooler_config
+    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
+        pooler_config = model_config.pooler_config
 
         if pooler_config.step_tag_id is None:
             pooler_config.step_tag_id = 151651
 
 
 class Qwen2ForRewardModelConfig(VerifyAndUpdateConfig):
     @staticmethod
-    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
-        pooler_config = vllm_config.model_config.pooler_config
+    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
+        pooler_config = model_config.pooler_config
 
         if pooler_config.softmax is None:
             pooler_config.softmax = False
 
 
 class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig):
     @staticmethod
-    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
-        config = vllm_config.model_config.hf_config
+    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
+        config = model_config.hf_config
 
         is_original_qwen3_reranker = getattr(
             config, "is_original_qwen3_reranker", False
@@ -237,23 +240,23 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None:
             "Try loading the original Qwen3 Reranker?, see: "
             "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/offline_reranker.py"
         )
-        vllm_config.model_config.hf_config.method = "from_2_way_softmax"
+        model_config.hf_config.method = "from_2_way_softmax"
 
 
 class JinaVLForSequenceClassificationConfig(VerifyAndUpdateConfig):
     @staticmethod
-    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
-        config = vllm_config.model_config.hf_config
+    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
+        config = model_config.hf_config
         config.num_labels = 1
-        pooler_config = vllm_config.model_config.pooler_config
+        pooler_config = model_config.pooler_config
         if pooler_config.logit_bias is None:
             pooler_config.logit_bias = 2.65
 
 
 class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
     @staticmethod
-    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
-        config = vllm_config.model_config.hf_config
+    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
+        config = model_config.hf_config
 
         assert config.__class__.__name__ == "GteConfig"
         assert config.hidden_act == "gelu"
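
The base class now exposes two no-op hooks: verify_and_update_config, which receives the full VllmConfig, and the new verify_and_update_model_config, which needs only the ModelConfig and can therefore run earlier, from ModelConfig.__post_init__. A hedged sketch of the subclassing pattern follows, using SimpleNamespace stand-ins for the real config objects; MinimalRewardModelConfig is hypothetical, modeled on Qwen2ForRewardModelConfig above.

# Sketch of the two-level hook pattern from VerifyAndUpdateConfig.
from types import SimpleNamespace


class VerifyAndUpdateConfig:
    @staticmethod
    def verify_and_update_config(vllm_config) -> None:
        return  # full-config hook: sees parallelism, cache config, etc.

    @staticmethod
    def verify_and_update_model_config(model_config) -> None:
        return  # model-only hook: needs nothing beyond ModelConfig


class MinimalRewardModelConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_model_config(model_config) -> None:
        # Default a pooler field only when the user left it unset,
        # mirroring Qwen2ForRewardModelConfig in the diff above.
        if model_config.pooler_config.softmax is None:
            model_config.pooler_config.softmax = False


model_config = SimpleNamespace(pooler_config=SimpleNamespace(softmax=None))
MinimalRewardModelConfig.verify_and_update_model_config(model_config)
assert model_config.pooler_config.softmax is False
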
vllm/model_executor/models/llama.py (3 changes: 0 additions & 3 deletions)
@@ -64,7 +64,6 @@
     SupportsLoRA,
     SupportsPP,
 )
-from .interfaces_base import attn_type
 from .utils import (
     AutoWeightsLoader,
     PPMissingLayer,
@@ -707,14 +706,12 @@ def permute(w: torch.Tensor, n_heads: int, attn_out: int):
         return name, loaded_weight
 
 
-@attn_type("encoder_only")
 class LlamaBidirectionalForSequenceClassification(as_seq_cls_model(LlamaForCausalLM)):
     # This class sets the correct attention type and pooling type
     # through LlamaBidirectionalConfig.
     pass
 
 
-@attn_type("encoder_only")
 class LlamaBidirectionalModel(as_embedding_model(LlamaForCausalLM)):
     # This class sets the correct attention type and pooling type
     # through LlamaBidirectionalConfig.
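
With the @attn_type("encoder_only") decorators removed, the bidirectional Llama classes rely on LlamaBidirectionalConfig setting hf_config.is_causal = False during config verification, as their comments note. A small illustrative sketch of that config-driven selection; pick_attn_type is a hypothetical helper, not a vLLM API.

# Sketch: config-driven attention typing replacing the removed decorator.
from types import SimpleNamespace


def pick_attn_type(hf_config) -> str:
    # Encoder-only (bidirectional) when the config disables causal masking.
    return "decoder" if getattr(hf_config, "is_causal", True) else "encoder_only"


hf_config = SimpleNamespace(is_causal=True)
hf_config.is_causal = False  # what LlamaBidirectionalConfig does in the diff
assert pick_attn_type(hf_config) == "encoder_only"
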