From 039c1aef6c1d2e616266888c24d6a38e1595bff7 Mon Sep 17 00:00:00 2001
From: Zhewen Li
Date: Thu, 26 Mar 2026 05:05:46 -0700
Subject: [PATCH] Revert "Various Transformers v5 fixes (#38127)"

This reverts commit 3c3c084240261a012531ab990be1d13f1e651c26.
---
 .../offline_mode/test_offline_mode.py         |  1 -
 vllm/config/model.py                          |  8 --
 vllm/model_executor/models/olmo2.py           |  3 +-
 vllm/transformers_utils/config.py             |  1 +
 vllm/transformers_utils/configs/__init__.py   |  2 +
 .../configs/deepseek_vl2.py                   |  3 -
 vllm/transformers_utils/configs/olmo3.py      | 83 +++++++++++++++++++
 7 files changed, 88 insertions(+), 13 deletions(-)
 create mode 100644 vllm/transformers_utils/configs/olmo3.py

diff --git a/tests/entrypoints/offline_mode/test_offline_mode.py b/tests/entrypoints/offline_mode/test_offline_mode.py
index 0708597079fc..8ca15c286c75 100644
--- a/tests/entrypoints/offline_mode/test_offline_mode.py
+++ b/tests/entrypoints/offline_mode/test_offline_mode.py
@@ -112,7 +112,6 @@ def _re_import_modules():
     aliased_module_patterns = [
         r".+\.tokenization_utils$",
         r".+\.tokenization_utils_fast$",
-        r".+\.image_processing_utils_fast$",
         r".+\.models\..+\.image_processing_.+_fast$",
     ]
 
diff --git a/vllm/config/model.py b/vllm/config/model.py
index e934ff554437..e51723009618 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -586,14 +586,6 @@ def __post_init__(
             config_format=self.config_format,
         )
 
-        # Some checkpoints set sliding_window to 0 to indicate that sliding window is
-        # disabled, but vLLM uses None for that. Convert 0 to None to avoid errors.
-        # Set before get_and_verify_max_len to ensure that max_model_len does not get
-        # capped to 0.
-        if self.get_sliding_window() == 0:
-            self.disable_sliding_window = True
-            self.hf_text_config.sliding_window = None
-
         self.original_max_model_len = self.max_model_len
         self.max_model_len = self.get_and_verify_max_len(self.max_model_len)
 
diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py
index 212140fe15ea..250c3892acb4 100644
--- a/vllm/model_executor/models/olmo2.py
+++ b/vllm/model_executor/models/olmo2.py
@@ -30,7 +30,7 @@
 
 import torch
 from torch import nn
-from transformers import Olmo2Config, Olmo3Config
+from transformers import Olmo2Config
 
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
@@ -63,6 +63,7 @@
     maybe_prefix,
 )
 from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.configs.olmo3 import Olmo3Config
 
 
 class Olmo2Attention(nn.Module):
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index 7836a44e73ea..9894a6a88e16 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -105,6 +105,7 @@ def __getitem__(self, key):
     eagle="EAGLEConfig",
     speculators="SpeculatorsConfig",
     nemotron="NemotronConfig",
+    olmo3="Olmo3Config",
     olmo_hybrid="OlmoHybridConfig",
     ovis="OvisConfig",
     ultravox="UltravoxConfig",
diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py
index 75bfda3fbdfe..4364829d9ef5 100644
--- a/vllm/transformers_utils/configs/__init__.py
+++ b/vllm/transformers_utils/configs/__init__.py
@@ -50,6 +50,7 @@
     "KimiK25Config": "vllm.transformers_utils.configs.kimi_k25",
     "NemotronConfig": "vllm.transformers_utils.configs.nemotron",
     "NemotronHConfig": "vllm.transformers_utils.configs.nemotron_h",
+    "Olmo3Config": "vllm.transformers_utils.configs.olmo3",
     "OlmoHybridConfig": "vllm.transformers_utils.configs.olmo_hybrid",
     "OvisConfig": "vllm.transformers_utils.configs.ovis",
     "PixelShuffleSiglip2VisionConfig": "vllm.transformers_utils.configs.isaac",
@@ -105,6 +106,7 @@
     "KimiK25Config",
     "NemotronConfig",
    "NemotronHConfig",
+    "Olmo3Config",
     "OlmoHybridConfig",
     "OvisConfig",
     "PixelShuffleSiglip2VisionConfig",
diff --git a/vllm/transformers_utils/configs/deepseek_vl2.py b/vllm/transformers_utils/configs/deepseek_vl2.py
index 9c816488c087..80fedd1017ca 100644
--- a/vllm/transformers_utils/configs/deepseek_vl2.py
+++ b/vllm/transformers_utils/configs/deepseek_vl2.py
@@ -114,9 +114,6 @@ def __init__(
         self.projector_config = MlpProjectorConfig(**projector_config)
 
         language_config = kwargs.get("language_config", {})
-        # remove kv_lora_rank if not specified, passing None is prohibited
-        if language_config.get("kv_lora_rank") is None:
-            language_config.pop("kv_lora_rank", None)
         self.text_config = DeepseekV2Config(**language_config)
 
         self.tile_tag = tile_tag
diff --git a/vllm/transformers_utils/configs/olmo3.py b/vllm/transformers_utils/configs/olmo3.py
new file mode 100644
index 000000000000..c4691b661af3
--- /dev/null
+++ b/vllm/transformers_utils/configs/olmo3.py
@@ -0,0 +1,83 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from transformers.configuration_utils import PretrainedConfig
+
+
+class Olmo3Config(PretrainedConfig):
+    model_type = "olmo3"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        vocab_size=50304,
+        hidden_size=4096,
+        intermediate_size=11008,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        num_key_value_heads=None,
+        hidden_act="silu",
+        max_position_embeddings=2048,
+        initializer_range=0.02,
+        use_cache=True,
+        pad_token_id=1,
+        bos_token_id=None,
+        eos_token_id=50279,
+        tie_word_embeddings=False,
+        rope_parameters=None,
+        attention_bias=False,
+        attention_dropout=0.0,
+        rms_norm_eps=1e-5,
+        sliding_window=4096,
+        layer_types=None,
+        **kwargs,
+    ):
+        # This model uses Olmo3ForCausalLM in transformers but Olmo2ForCausalLM
+        # in vLLM.
+        if "architectures" not in kwargs:
+            kwargs["architectures"] = ["Olmo2ForCausalLM"]
+        elif "Olmo3ForCausalLM" in kwargs["architectures"]:
+            kwargs["architectures"].remove("Olmo3ForCausalLM")
+            kwargs["architectures"].append("Olmo2ForCausalLM")
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.use_cache = use_cache
+        # Try to set `rope_scaling` if available, otherwise use `rope_parameters`
+        rope_scaling = kwargs.pop("rope_scaling", None)
+        rope_parameters = rope_scaling or rope_parameters or {"rope_type": "default"}
+        rope_theta = kwargs.pop("rope_theta", 10000.0)
+        if "rope_theta" not in rope_parameters:
+            rope_parameters["rope_theta"] = rope_theta
+        self.rope_parameters = rope_parameters
+        self.attention_bias = attention_bias
+        self.attention_dropout = attention_dropout
+
+        self.rms_norm_eps = rms_norm_eps
+
+        self.sliding_window = sliding_window
+        self.layer_types = layer_types
+        if self.layer_types is None:
+            self.layer_types = [
+                "sliding_attention" if (i + 1) % 4 != 0 else "full_attention"
+                for i in range(self.num_hidden_layers)
+            ]