From f46685f4e23eeadda1fec3f9b11734634847a167 Mon Sep 17 00:00:00 2001 From: Shane A Date: Tue, 9 Sep 2025 17:36:32 +0000 Subject: [PATCH 1/7] Copy Olmo2 model as Olmo3 Signed-off-by: Shane A --- vllm/model_executor/models/olmo3.py | 428 ++++++++++++++++++++++++++++ 1 file changed, 428 insertions(+) create mode 100644 vllm/model_executor/models/olmo3.py diff --git a/vllm/model_executor/models/olmo3.py b/vllm/model_executor/models/olmo3.py new file mode 100644 index 000000000000..bccd1b87043a --- /dev/null +++ b/vllm/model_executor/models/olmo3.py @@ -0,0 +1,428 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmo2/modeling_olmo2.py +# Copyright 2024 The vLLM team. +# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only OLMo2 model compatible with HuggingFace weights.""" + +from collections.abc import Iterable +from functools import partial +from itertools import islice +from typing import Optional, Union + +import torch +from torch import nn +from transformers import Olmo2Config + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.distributed.communication_op import tensor_model_parallel_all_gather +from vllm.distributed.parallel_state import get_tensor_model_parallel_rank +from vllm.distributed.utils import split_tensor_along_last_dim +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + + +class Olmo2Attention(nn.Module): + """ + This is the attention block where the output is computed as + ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.config = vllm_config.model_config.hf_config + assert isinstance(self.config, Olmo2Config) + + hidden_size = self.config.hidden_size + self.tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = self.config.num_attention_heads + + assert hidden_size % self.total_num_heads == 0 + assert self.total_num_heads % self.tp_size == 0 + + self.num_heads = self.total_num_heads // self.tp_size + self.total_num_kv_heads = (self.config.num_key_value_heads + or self.total_num_heads) + if self.total_num_kv_heads >= self.tp_size: + assert self.total_num_kv_heads % self.tp_size == 0 + else: + assert self.tp_size % self.total_num_kv_heads == 0 + + self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.max_position_embeddings = self.config.max_position_embeddings + self.rope_theta = self.config.rope_theta + + # Attention input projection. Projects x -> (q, k, v) + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=vllm_config.quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.tp_rank = get_tensor_model_parallel_rank() + self.k_norm = RMSNorm( + self.total_num_kv_heads * self.head_dim, + eps=self.config.rms_norm_eps, + ) + self.q_norm = RMSNorm(self.config.hidden_size, + eps=self.config.rms_norm_eps) + + # Rotary embeddings. + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, # type: ignore + ) + self.scaling = self.head_dim**-0.5 + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=vllm_config.cache_config, + quant_config=vllm_config.quant_config, + prefix=prefix, + ) + + # Attention output projection. + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=vllm_config.quant_config, + prefix=f"{prefix}.o_proj", + ) + + def _apply_qk_norm(self, q: torch.Tensor, + k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + if self.tp_size > 1: + q = tensor_model_parallel_all_gather(q.contiguous()) + k = tensor_model_parallel_all_gather(k.contiguous()) + q = self.q_norm(q) + k = self.k_norm(k) + if self.tp_size > 1: + splitter = partial(split_tensor_along_last_dim, + num_partitions=self.tp_size) + q = splitter(q)[self.tp_rank] + k = splitter(k)[self.tp_rank] + return q, k + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self._apply_qk_norm(q, k) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class Olmo2MLP(nn.Module): + """ + This is the MLP block where the output is computed as + ``MLP(x)`` in ``LN(MLP(x + LN(Attention(x))))`` + (plus another skip connection). + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + assert isinstance(config, Olmo2Config) + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size + + # Feed-forward input projection. + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=vllm_config.quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + + # Activation function. + self.act_fn = SiluAndMul() + + # Feed-forward output projection. + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=vllm_config.quant_config, + prefix=f"{prefix}.down_proj", + ) + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Olmo2DecoderLayer(nn.Module): + """ + This is a typical transformer block where the output is + computed as ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + assert isinstance(config, Olmo2Config) + # Attention block. + self.self_attn = Olmo2Attention(vllm_config=vllm_config, + prefix=f"{prefix}.self_attn") + + # MLP block. + self.mlp = Olmo2MLP(vllm_config=vllm_config, prefix=f"{prefix}.mlp") + + # LayerNorm + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + self.post_feedforward_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + # Attention block. + residual = hidden_states + hidden_states = self.self_attn(positions, hidden_states) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = hidden_states + residual + + # MLP block. + residual = hidden_states + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +@support_torch_compile +class Olmo2Model(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.config = vllm_config.model_config.hf_config + assert isinstance(self.config, Olmo2Config) + + self.embed_tokens = VocabParallelEmbedding( + self.config.vocab_size, + self.config.hidden_size, + prefix=f"{prefix}.embed_tokens", + ) + self.start_layer, self.end_layer, self.layers = make_layers( + self.config.num_hidden_layers, + lambda prefix: Olmo2DecoderLayer(vllm_config=vllm_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + self.norm = RMSNorm( + self.config.hidden_size, + eps=self.config.rms_norm_eps, + ) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + self.config.hidden_size)) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + """ + :param input_ids: A tensor of shape `(batch_size, seq_len)`. + """ + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + # Get embeddings of input. + # shape: (batch_size, seq_len, d_model) + else: + hidden_states = self.embed_tokens(input_ids) + + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + assert isinstance(hidden_states, torch.Tensor) + + # Apply blocks one-by-one. + for layer in islice(self.layers, self.start_layer, self.end_layer): + # shape: (batch_size, seq_len, d_model) + hidden_states = layer(positions, hidden_states) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + + # Apply final layer norm. + # shape: (batch_size, seq_len or 1, d_model) + hidden_states = self.norm(hidden_states) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if is_pp_missing_parameter(name, self): + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader # type: ignore + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): + """ + Extremely barebones HF model wrapper. + """ + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + assert isinstance(config, Olmo2Config) + self.config = config + self.model = Olmo2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.unpadded_vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=vllm_config.quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head.weight"] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) From 5f1e4073bb53bb084e2b78f9d819d199bc92c752 Mon Sep 17 00:00:00 2001 From: Shane A Date: Tue, 9 Sep 2025 18:32:25 +0000 Subject: [PATCH 2/7] Add Olmo3Config Signed-off-by: Shane A --- vllm/transformers_utils/config.py | 1 + vllm/transformers_utils/configs/__init__.py | 2 + vllm/transformers_utils/configs/olmo3.py | 77 +++++++++++++++++++++ 3 files changed, 80 insertions(+) create mode 100644 vllm/transformers_utils/configs/olmo3.py diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 95e4ed1ccf07..defb66284f00 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -75,6 +75,7 @@ def __getitem__(self, key): eagle="EAGLEConfig", speculators="SpeculatorsConfig", nemotron="NemotronConfig", + olmo3="Olmo3Config", ovis="OvisConfig", ultravox="UltravoxConfig", step3_vl="Step3VLConfig", diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index f651ecb078b9..f0ffb9b40f88 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -23,6 +23,7 @@ from vllm.transformers_utils.configs.nemotron import NemotronConfig from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config +from vllm.transformers_utils.configs.olmo3 import Olmo3Config from vllm.transformers_utils.configs.ovis import OvisConfig from vllm.transformers_utils.configs.speculators.base import SpeculatorsConfig from vllm.transformers_utils.configs.step3_vl import (Step3TextConfig, @@ -44,6 +45,7 @@ "NemotronConfig", "NemotronHConfig", "Nemotron_Nano_VL_Config", + "Olmo3Config", "OvisConfig", "SpeculatorsConfig", "UltravoxConfig", diff --git a/vllm/transformers_utils/configs/olmo3.py b/vllm/transformers_utils/configs/olmo3.py new file mode 100644 index 000000000000..b0e926d516ec --- /dev/null +++ b/vllm/transformers_utils/configs/olmo3.py @@ -0,0 +1,77 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/faf18023f24b962b32d9f0a2d89e402a8d383a78/deepseek_vl2/models/modeling_deepseek_vl_v2.py#L115-L268 + +from transformers.configuration_utils import PretrainedConfig + + +class Olmo3Config(PretrainedConfig): + + model_type = "olmo3" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=50304, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + use_cache=True, + pad_token_id=1, + bos_token_id=None, + eos_token_id=50279, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + rms_norm_eps=1e-5, + sliding_window=4096, + layer_types=None, + **kwargs, + ): + if "architectures" not in kwargs: + kwargs["architectures"] = ["Olmo3ForCausalLM"] + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + self.rms_norm_eps = rms_norm_eps + + self.sliding_window = sliding_window + self.layer_types = layer_types + if self.layer_types is None: + self.layer_types = [ + "sliding_attention" if (i + 1) % 4 != 0 else "full_attention" + for i in range(self.num_hidden_layers) + ] From becfaaf18c2a4041e85403499af752fcafc0c9fd Mon Sep 17 00:00:00 2001 From: Shane A Date: Tue, 9 Sep 2025 18:43:35 +0000 Subject: [PATCH 3/7] Add Olmo3 model implementation Signed-off-by: Shane A --- vllm/model_executor/models/olmo3.py | 63 +++++++++++++++----------- vllm/model_executor/models/registry.py | 1 + 2 files changed, 38 insertions(+), 26 deletions(-) diff --git a/vllm/model_executor/models/olmo3.py b/vllm/model_executor/models/olmo3.py index bccd1b87043a..0bbd2de6336f 100644 --- a/vllm/model_executor/models/olmo3.py +++ b/vllm/model_executor/models/olmo3.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # Adapted from -# https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmo2/modeling_olmo2.py +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmo3/modeling_olmo3.py # Copyright 2024 The vLLM team. # Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved. # @@ -22,7 +22,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Inference-only OLMo2 model compatible with HuggingFace weights.""" +"""Inference-only Olmo3 model compatible with HuggingFace weights.""" from collections.abc import Iterable from functools import partial @@ -31,7 +31,6 @@ import torch from torch import nn -from transformers import Olmo2Config from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile @@ -52,13 +51,14 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP from vllm.model_executor.models.utils import ( - AutoWeightsLoader, is_pp_missing_parameter, + AutoWeightsLoader, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import Olmo3Config -class Olmo2Attention(nn.Module): +class Olmo3Attention(nn.Module): """ This is the attention block where the output is computed as ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))`` @@ -68,7 +68,7 @@ class Olmo2Attention(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() self.config = vllm_config.model_config.hf_config - assert isinstance(self.config, Olmo2Config) + assert isinstance(self.config, Olmo3Config) hidden_size = self.config.hidden_size self.tp_size = get_tensor_model_parallel_world_size() @@ -111,14 +111,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.q_norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) - # Rotary embeddings. - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=self.max_position_embeddings, - base=self.rope_theta, # type: ignore - ) self.scaling = self.head_dim**-0.5 + + layer_idx = extract_layer_index(prefix) + sliding_window = (self.config.sliding_window + if self.config.layer_types[layer_idx] + == "sliding_attention" else None) self.attn = Attention( self.num_heads, self.head_dim, @@ -126,7 +124,20 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): num_kv_heads=self.num_kv_heads, cache_config=vllm_config.cache_config, quant_config=vllm_config.quant_config, - prefix=prefix, + per_layer_sliding_window=sliding_window, + prefix=f"{prefix}.attn", + ) + + # Rotary embeddings. Rope scaling is only applied on full attention + # layers. + self.rope_scaling = (self.config.rope_scaling + if sliding_window is None else None) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, # type: ignore + rope_scaling=self.rope_scaling, ) # Attention output projection. @@ -166,7 +177,7 @@ def forward( return output -class Olmo2MLP(nn.Module): +class Olmo3MLP(nn.Module): """ This is the MLP block where the output is computed as ``MLP(x)`` in ``LN(MLP(x + LN(Attention(x))))`` @@ -176,7 +187,7 @@ class Olmo2MLP(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - assert isinstance(config, Olmo2Config) + assert isinstance(config, Olmo3Config) hidden_size = config.hidden_size intermediate_size = config.intermediate_size @@ -211,7 +222,7 @@ def forward( return x -class Olmo2DecoderLayer(nn.Module): +class Olmo3DecoderLayer(nn.Module): """ This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))`` @@ -221,13 +232,13 @@ class Olmo2DecoderLayer(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - assert isinstance(config, Olmo2Config) + assert isinstance(config, Olmo3Config) # Attention block. - self.self_attn = Olmo2Attention(vllm_config=vllm_config, + self.self_attn = Olmo3Attention(vllm_config=vllm_config, prefix=f"{prefix}.self_attn") # MLP block. - self.mlp = Olmo2MLP(vllm_config=vllm_config, prefix=f"{prefix}.mlp") + self.mlp = Olmo3MLP(vllm_config=vllm_config, prefix=f"{prefix}.mlp") # LayerNorm self.post_attention_layernorm = RMSNorm(config.hidden_size, @@ -256,12 +267,12 @@ def forward( @support_torch_compile -class Olmo2Model(nn.Module): +class Olmo3Model(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() self.config = vllm_config.model_config.hf_config - assert isinstance(self.config, Olmo2Config) + assert isinstance(self.config, Olmo3Config) self.embed_tokens = VocabParallelEmbedding( self.config.vocab_size, @@ -270,7 +281,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.start_layer, self.end_layer, self.layers = make_layers( self.config.num_hidden_layers, - lambda prefix: Olmo2DecoderLayer(vllm_config=vllm_config, + lambda prefix: Olmo3DecoderLayer(vllm_config=vllm_config, prefix=prefix), prefix=f"{prefix}.layers", ) @@ -357,7 +368,7 @@ def load_weights(self, weights: Iterable[tuple[str, return loaded_params -class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): +class Olmo3ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): """ Extremely barebones HF model wrapper. """ @@ -376,9 +387,9 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - assert isinstance(config, Olmo2Config) + assert isinstance(config, Olmo3Config) self.config = config - self.model = Olmo2Model(vllm_config=vllm_config, + self.model = Olmo3Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) if config.tie_word_embeddings: self.lm_head = self.model.embed_tokens diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 43075956b450..490deb416dc5 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -118,6 +118,7 @@ "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"), "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"), "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"), + "Olmo3ForCausalLM": ("olmo3", "Olmo3ForCausalLM"), "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"), "OPTForCausalLM": ("opt", "OPTForCausalLM"), "OrionForCausalLM": ("orion", "OrionForCausalLM"), From 097ef6a6c9c3fab534b2188437388dc06dc2fbe1 Mon Sep 17 00:00:00 2001 From: Shane A Date: Tue, 9 Sep 2025 20:41:48 +0000 Subject: [PATCH 4/7] Add Olmo3 to test registry and supported models Signed-off-by: Shane A --- docs/models/supported_models.md | 1 + tests/models/registry.py | 1 + 2 files changed, 2 insertions(+) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index d23fdff568fc..cfb48f4295ed 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -388,6 +388,7 @@ th { | `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | ✅︎ | | `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | ✅︎ | ✅︎ | ✅︎ | | `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `OLMo3ForCausalLM` | OLMo3 | TBA | ✅︎ | ✅︎ | ✅︎ | | `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ | ✅︎ | | `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | | ✅︎ | ✅︎ | | `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | ✅︎ | diff --git a/tests/models/registry.py b/tests/models/registry.py index 755a37b109d7..4d941a81496a 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -292,6 +292,7 @@ def check_available_online( trust_remote_code=True), "OlmoForCausalLM": _HfExamplesInfo("allenai/OLMo-1B-hf"), "Olmo2ForCausalLM": _HfExamplesInfo("allenai/OLMo-2-0425-1B"), + "Olmo3ForCausalLM": _HfExamplesInfo("shanearora/2025-sep-a-base-model"), "OlmoeForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924-Instruct"), "OPTForCausalLM": _HfExamplesInfo("facebook/opt-125m", {"1b": "facebook/opt-iml-max-1.3b"}), From ce477b1560cb392952c0ee2b3dbd4e2fb9407778 Mon Sep 17 00:00:00 2001 From: Shane A Date: Tue, 9 Sep 2025 21:07:59 +0000 Subject: [PATCH 5/7] Minor comment updates Signed-off-by: Shane A --- vllm/model_executor/models/olmo3.py | 4 ++-- vllm/transformers_utils/configs/olmo3.py | 2 -- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/olmo3.py b/vllm/model_executor/models/olmo3.py index 0bbd2de6336f..beefbb018c86 100644 --- a/vllm/model_executor/models/olmo3.py +++ b/vllm/model_executor/models/olmo3.py @@ -3,8 +3,8 @@ # Adapted from # https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmo3/modeling_olmo3.py -# Copyright 2024 The vLLM team. -# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# Copyright 2025 The vLLM team. +# Copyright 2025 EleutherAI and the HuggingFace Inc. team. All rights reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX # and OPT implementations in this library. It has been modified from its diff --git a/vllm/transformers_utils/configs/olmo3.py b/vllm/transformers_utils/configs/olmo3.py index b0e926d516ec..ae32c92dc5db 100644 --- a/vllm/transformers_utils/configs/olmo3.py +++ b/vllm/transformers_utils/configs/olmo3.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/faf18023f24b962b32d9f0a2d89e402a8d383a78/deepseek_vl2/models/modeling_deepseek_vl_v2.py#L115-L268 - from transformers.configuration_utils import PretrainedConfig From 61b1b26262a694f16b964e1f29ac8ff86c8e8536 Mon Sep 17 00:00:00 2001 From: Shane A Date: Wed, 10 Sep 2025 20:02:06 +0000 Subject: [PATCH 6/7] Support Olmo3 architecture in Olmo2 implementation Signed-off-by: Shane A --- vllm/model_executor/models/olmo2.py | 42 ++- vllm/model_executor/models/olmo3.py | 439 ----------------------- vllm/model_executor/models/registry.py | 2 +- vllm/transformers_utils/configs/olmo3.py | 7 +- 4 files changed, 35 insertions(+), 455 deletions(-) delete mode 100644 vllm/model_executor/models/olmo3.py diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py index bccd1b87043a..3e4c580a1121 100644 --- a/vllm/model_executor/models/olmo2.py +++ b/vllm/model_executor/models/olmo2.py @@ -52,10 +52,11 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP from vllm.model_executor.models.utils import ( - AutoWeightsLoader, is_pp_missing_parameter, + AutoWeightsLoader, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import Olmo3Config class Olmo2Attention(nn.Module): @@ -68,7 +69,7 @@ class Olmo2Attention(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() self.config = vllm_config.model_config.hf_config - assert isinstance(self.config, Olmo2Config) + assert isinstance(self.config, (Olmo2Config, Olmo3Config)) hidden_size = self.config.hidden_size self.tp_size = get_tensor_model_parallel_world_size() @@ -111,14 +112,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.q_norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) - # Rotary embeddings. - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=self.max_position_embeddings, - base=self.rope_theta, # type: ignore - ) self.scaling = self.head_dim**-0.5 + + layer_idx = extract_layer_index(prefix) + sliding_window = None + if ((layer_types := getattr(self.config, "layer_types", None)) + is not None and layer_types[layer_idx] == "sliding_attention"): + sliding_window = self.config.sliding_window + self.attn = Attention( self.num_heads, self.head_dim, @@ -126,7 +127,20 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): num_kv_heads=self.num_kv_heads, cache_config=vllm_config.cache_config, quant_config=vllm_config.quant_config, - prefix=prefix, + per_layer_sliding_window=sliding_window, + prefix=f"{prefix}.attn", + ) + + # Rotary embeddings. Rope scaling is only applied on full attention + # layers. + self.rope_scaling = (self.config.rope_scaling + if sliding_window is None else None) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, # type: ignore + rope_scaling=self.rope_scaling, ) # Attention output projection. @@ -176,7 +190,7 @@ class Olmo2MLP(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - assert isinstance(config, Olmo2Config) + assert isinstance(config, (Olmo2Config, Olmo3Config)) hidden_size = config.hidden_size intermediate_size = config.intermediate_size @@ -221,7 +235,7 @@ class Olmo2DecoderLayer(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - assert isinstance(config, Olmo2Config) + assert isinstance(config, (Olmo2Config, Olmo3Config)) # Attention block. self.self_attn = Olmo2Attention(vllm_config=vllm_config, prefix=f"{prefix}.self_attn") @@ -261,7 +275,7 @@ class Olmo2Model(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() self.config = vllm_config.model_config.hf_config - assert isinstance(self.config, Olmo2Config) + assert isinstance(self.config, (Olmo2Config, Olmo3Config)) self.embed_tokens = VocabParallelEmbedding( self.config.vocab_size, @@ -376,7 +390,7 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - assert isinstance(config, Olmo2Config) + assert isinstance(config, (Olmo2Config, Olmo3Config)) self.config = config self.model = Olmo2Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) diff --git a/vllm/model_executor/models/olmo3.py b/vllm/model_executor/models/olmo3.py deleted file mode 100644 index beefbb018c86..000000000000 --- a/vllm/model_executor/models/olmo3.py +++ /dev/null @@ -1,439 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# Adapted from -# https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmo3/modeling_olmo3.py -# Copyright 2025 The vLLM team. -# Copyright 2025 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Inference-only Olmo3 model compatible with HuggingFace weights.""" - -from collections.abc import Iterable -from functools import partial -from itertools import islice -from typing import Optional, Union - -import torch -from torch import nn - -from vllm.attention import Attention -from vllm.compilation.decorators import support_torch_compile -from vllm.config import VllmConfig -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm.distributed.communication_op import tensor_model_parallel_all_gather -from vllm.distributed.parallel_state import get_tensor_model_parallel_rank -from vllm.distributed.utils import split_tensor_along_last_dim -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP -from vllm.model_executor.models.utils import ( - AutoWeightsLoader, extract_layer_index, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.configs import Olmo3Config - - -class Olmo3Attention(nn.Module): - """ - This is the attention block where the output is computed as - ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))`` - (plus another skip connection). - """ - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - self.config = vllm_config.model_config.hf_config - assert isinstance(self.config, Olmo3Config) - - hidden_size = self.config.hidden_size - self.tp_size = get_tensor_model_parallel_world_size() - self.total_num_heads = self.config.num_attention_heads - - assert hidden_size % self.total_num_heads == 0 - assert self.total_num_heads % self.tp_size == 0 - - self.num_heads = self.total_num_heads // self.tp_size - self.total_num_kv_heads = (self.config.num_key_value_heads - or self.total_num_heads) - if self.total_num_kv_heads >= self.tp_size: - assert self.total_num_kv_heads % self.tp_size == 0 - else: - assert self.tp_size % self.total_num_kv_heads == 0 - - self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size) - self.head_dim = hidden_size // self.total_num_heads - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - self.max_position_embeddings = self.config.max_position_embeddings - self.rope_theta = self.config.rope_theta - - # Attention input projection. Projects x -> (q, k, v) - self.qkv_proj = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, - bias=False, - quant_config=vllm_config.quant_config, - prefix=f"{prefix}.qkv_proj", - ) - - self.tp_rank = get_tensor_model_parallel_rank() - self.k_norm = RMSNorm( - self.total_num_kv_heads * self.head_dim, - eps=self.config.rms_norm_eps, - ) - self.q_norm = RMSNorm(self.config.hidden_size, - eps=self.config.rms_norm_eps) - - self.scaling = self.head_dim**-0.5 - - layer_idx = extract_layer_index(prefix) - sliding_window = (self.config.sliding_window - if self.config.layer_types[layer_idx] - == "sliding_attention" else None) - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=vllm_config.cache_config, - quant_config=vllm_config.quant_config, - per_layer_sliding_window=sliding_window, - prefix=f"{prefix}.attn", - ) - - # Rotary embeddings. Rope scaling is only applied on full attention - # layers. - self.rope_scaling = (self.config.rope_scaling - if sliding_window is None else None) - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=self.max_position_embeddings, - base=self.rope_theta, # type: ignore - rope_scaling=self.rope_scaling, - ) - - # Attention output projection. - self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, - bias=False, - quant_config=vllm_config.quant_config, - prefix=f"{prefix}.o_proj", - ) - - def _apply_qk_norm(self, q: torch.Tensor, - k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - if self.tp_size > 1: - q = tensor_model_parallel_all_gather(q.contiguous()) - k = tensor_model_parallel_all_gather(k.contiguous()) - q = self.q_norm(q) - k = self.k_norm(k) - if self.tp_size > 1: - splitter = partial(split_tensor_along_last_dim, - num_partitions=self.tp_size) - q = splitter(q)[self.tp_rank] - k = splitter(k)[self.tp_rank] - return q, k - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self._apply_qk_norm(q, k) - q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v) - output, _ = self.o_proj(attn_output) - return output - - -class Olmo3MLP(nn.Module): - """ - This is the MLP block where the output is computed as - ``MLP(x)`` in ``LN(MLP(x + LN(Attention(x))))`` - (plus another skip connection). - """ - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - assert isinstance(config, Olmo3Config) - hidden_size = config.hidden_size - intermediate_size = config.intermediate_size - - # Feed-forward input projection. - self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, - [intermediate_size] * 2, - bias=False, - quant_config=vllm_config.quant_config, - prefix=f"{prefix}.gate_up_proj", - ) - - # Activation function. - self.act_fn = SiluAndMul() - - # Feed-forward output projection. - self.down_proj = RowParallelLinear( - intermediate_size, - hidden_size, - bias=False, - quant_config=vllm_config.quant_config, - prefix=f"{prefix}.down_proj", - ) - - def forward( - self, - x: torch.Tensor, - ) -> torch.Tensor: - gate_up, _ = self.gate_up_proj(x) - x = self.act_fn(gate_up) - x, _ = self.down_proj(x) - return x - - -class Olmo3DecoderLayer(nn.Module): - """ - This is a typical transformer block where the output is - computed as ``MLP(LN(x + Attention(LN(x))))`` - (plus another skip connection). - """ - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - assert isinstance(config, Olmo3Config) - # Attention block. - self.self_attn = Olmo3Attention(vllm_config=vllm_config, - prefix=f"{prefix}.self_attn") - - # MLP block. - self.mlp = Olmo3MLP(vllm_config=vllm_config, prefix=f"{prefix}.mlp") - - # LayerNorm - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - - self.post_feedforward_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - ) -> torch.Tensor: - # Attention block. - residual = hidden_states - hidden_states = self.self_attn(positions, hidden_states) - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = hidden_states + residual - - # MLP block. - residual = hidden_states - hidden_states = self.mlp(hidden_states) - hidden_states = self.post_feedforward_layernorm(hidden_states) - hidden_states = residual + hidden_states - return hidden_states - - -@support_torch_compile -class Olmo3Model(nn.Module): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - self.config = vllm_config.model_config.hf_config - assert isinstance(self.config, Olmo3Config) - - self.embed_tokens = VocabParallelEmbedding( - self.config.vocab_size, - self.config.hidden_size, - prefix=f"{prefix}.embed_tokens", - ) - self.start_layer, self.end_layer, self.layers = make_layers( - self.config.num_hidden_layers, - lambda prefix: Olmo3DecoderLayer(vllm_config=vllm_config, - prefix=prefix), - prefix=f"{prefix}.layers", - ) - self.norm = RMSNorm( - self.config.hidden_size, - eps=self.config.rms_norm_eps, - ) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - self.config.hidden_size)) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - """ - :param input_ids: A tensor of shape `(batch_size, seq_len)`. - """ - if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - # Get embeddings of input. - # shape: (batch_size, seq_len, d_model) - else: - hidden_states = self.embed_tokens(input_ids) - - else: - assert intermediate_tensors is not None - hidden_states = intermediate_tensors["hidden_states"] - assert isinstance(hidden_states, torch.Tensor) - - # Apply blocks one-by-one. - for layer in islice(self.layers, self.start_layer, self.end_layer): - # shape: (batch_size, seq_len, d_model) - hidden_states = layer(positions, hidden_states) - - if not get_pp_group().is_last_rank: - return IntermediateTensors({"hidden_states": hidden_states}) - - # Apply final layer norm. - # shape: (batch_size, seq_len or 1, d_model) - hidden_states = self.norm(hidden_states) - return hidden_states - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - - params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if is_pp_missing_parameter(name, self): - continue - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader # type: ignore - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - - -class Olmo3ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): - """ - Extremely barebones HF model wrapper. - """ - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - } - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - assert isinstance(config, Olmo3Config) - self.config = config - self.model = Olmo3Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - if config.tie_word_embeddings: - self.lm_head = self.model.embed_tokens - else: - self.unpadded_vocab_size = config.vocab_size - self.lm_head = ParallelLMHead( - config.vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - quant_config=vllm_config.quant_config, - prefix=maybe_prefix(prefix, "lm_head"), - ) - self.logits_processor = LogitsProcessor(config.vocab_size) - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - ) - return hidden_states - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): - loader = AutoWeightsLoader( - self, - skip_prefixes=(["lm_head.weight"] - if self.config.tie_word_embeddings else None), - ) - return loader.load_weights(weights) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 490deb416dc5..40d4c735319d 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -118,7 +118,7 @@ "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"), "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"), "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"), - "Olmo3ForCausalLM": ("olmo3", "Olmo3ForCausalLM"), + "Olmo3ForCausalLM": ("olmo3", "Olmo2ForCausalLM"), "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"), "OPTForCausalLM": ("opt", "OPTForCausalLM"), "OrionForCausalLM": ("orion", "OrionForCausalLM"), diff --git a/vllm/transformers_utils/configs/olmo3.py b/vllm/transformers_utils/configs/olmo3.py index ae32c92dc5db..874507db43a7 100644 --- a/vllm/transformers_utils/configs/olmo3.py +++ b/vllm/transformers_utils/configs/olmo3.py @@ -34,8 +34,13 @@ def __init__( layer_types=None, **kwargs, ): + # This model uses Olmo3ForCausalLM in transformers but Olmo2ForCausalLM + # in vLLM. if "architectures" not in kwargs: - kwargs["architectures"] = ["Olmo3ForCausalLM"] + kwargs["architectures"] = ["Olmo2ForCausalLM"] + elif "Olmo3ForCausalLM" in kwargs["architectures"]: + kwargs["architectures"].remove("Olmo3ForCausalLM") + kwargs["architectures"].append("Olmo2ForCausalLM") super().__init__( pad_token_id=pad_token_id, From 9542485af88858be4dd5585f33ab52a9831f9959 Mon Sep 17 00:00:00 2001 From: Shane A Date: Fri, 12 Sep 2025 17:23:11 +0000 Subject: [PATCH 7/7] Fix registry for Olmo3 model Signed-off-by: Shane A --- vllm/model_executor/models/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 44fb592e624a..85759df36985 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -120,7 +120,7 @@ "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"), "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"), "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"), - "Olmo3ForCausalLM": ("olmo3", "Olmo2ForCausalLM"), + "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"), "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"), "OPTForCausalLM": ("opt", "OPTForCausalLM"), "OrionForCausalLM": ("orion", "OrionForCausalLM"),