36 changes: 24 additions & 12 deletions vllm/model_executor/models/gemma2.py
@@ -29,7 +29,7 @@
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
@@ -58,6 +58,14 @@
logger = init_logger(__name__)


def _get_norm_cls(
quant_config: QuantizationConfig | None,
) -> type[nn.Module]:
"""Return RMSNorm for GGUF (weights have +1 baked in), else GemmaRMSNorm."""
quant_name = quant_config.get_name() if quant_config else None
return RMSNorm if quant_name == "gguf" else GemmaRMSNorm


class Gemma2MLP(nn.Module):
def __init__(
self,
@@ -214,16 +222,12 @@ def __init__(
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = GemmaRMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.pre_feedforward_layernorm = GemmaRMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.post_feedforward_layernorm = GemmaRMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
rms_norm_cls = _get_norm_cls(quant_config)
rms_norm_kwargs = dict(hidden_size=config.hidden_size, eps=config.rms_norm_eps)
self.input_layernorm = rms_norm_cls(**rms_norm_kwargs)
self.post_attention_layernorm = rms_norm_cls(**rms_norm_kwargs)
self.pre_feedforward_layernorm = rms_norm_cls(**rms_norm_kwargs)
self.post_feedforward_layernorm = rms_norm_cls(**rms_norm_kwargs)

def forward(
self,
@@ -263,6 +267,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens",
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
@@ -271,7 +277,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
),
prefix=f"{prefix}.layers",
)
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.norm = _get_norm_cls(quant_config)(
config.hidden_size, eps=config.rms_norm_eps
)

# Normalize the embedding by sqrt(hidden_size)
# The normalizer's data type should be downcasted to the model's
@@ -361,6 +369,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
continue
if is_pp_missing_parameter(name, self):
continue
# Skip parameters not in the model (e.g., GGUF quantization
# metadata like qweight_type for embeddings)
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
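For context on the norm swap above, here is a minimal sketch of the +1 identity the _get_norm_cls docstring relies on. It uses plain torch reimplementations rather than vLLM's GemmaRMSNorm/RMSNorm kernels: llama.cpp's converter bakes Gemma's +1 offset into the norm weights, so applying Gemma-style x * (1 + w) scaling to GGUF weights would double-count the offset.

import torch

def _rms(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Shared RMS statistic used by both norm variants.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

x = torch.randn(2, 8, dtype=torch.float64)
w_hf = 0.1 * torch.randn(8, dtype=torch.float64)  # HF-style Gemma norm weight
w_gguf = w_hf + 1.0                               # GGUF weight with +1 baked in

gemma_style = _rms(x) * (1.0 + w_hf)  # GemmaRMSNorm semantics on HF weights
plain_style = _rms(x) * w_gguf        # plain RMSNorm semantics on GGUF weights
assert torch.allclose(gemma_style, plain_style)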
44 changes: 21 additions & 23 deletions vllm/model_executor/models/gemma3.py
@@ -31,7 +31,7 @@
Attention,
EncoderOnlyAttention,
)
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
@@ -64,6 +64,14 @@
logger = init_logger(__name__)


def _get_norm_cls(
quant_config: QuantizationConfig | None,
) -> type[nn.Module]:
"""Return RMSNorm for GGUF (weights have +1 baked in), else GemmaRMSNorm."""
quant_name = quant_config.get_name() if quant_config else None
return RMSNorm if quant_name == "gguf" else GemmaRMSNorm


class Gemma3MLP(nn.Module):
def __init__(
self,
@@ -156,8 +164,9 @@ def __init__(
prefix=f"{prefix}.o_proj",
)

self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.k_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
rms_norm_cls = _get_norm_cls(quant_config)
self.q_norm = rms_norm_cls(self.head_dim, eps=config.rms_norm_eps)
self.k_norm = rms_norm_cls(self.head_dim, eps=config.rms_norm_eps)

layer_idx = extract_layer_index(prefix)
layer_type = config.layer_types[layer_idx]
@@ -261,16 +270,12 @@ def __init__(
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = GemmaRMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.pre_feedforward_layernorm = GemmaRMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.post_feedforward_layernorm = GemmaRMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
rms_norm_cls = _get_norm_cls(quant_config)
rms_norm_kwargs = dict(hidden_size=config.hidden_size, eps=config.rms_norm_eps)
self.input_layernorm = rms_norm_cls(**rms_norm_kwargs)
self.post_attention_layernorm = rms_norm_cls(**rms_norm_kwargs)
self.pre_feedforward_layernorm = rms_norm_cls(**rms_norm_kwargs)
self.post_feedforward_layernorm = rms_norm_cls(**rms_norm_kwargs)

def forward(
self,
@@ -322,7 +327,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
),
prefix=f"{prefix}.layers",
)
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.norm = _get_norm_cls(quant_config)(
config.hidden_size, eps=config.rms_norm_eps
)

# Normalize the embedding by sqrt(hidden_size)
# The normalizer's data type should be downcasted to the model's
@@ -383,15 +390,6 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
# Revert +1 during llama.cpp conversion
# see: https://github.com/ggml-org/llama.cpp/blob/be7c3034108473beda214fd1d7c98fd6a7a3bdf5/convert_hf_to_gguf.py#L3397-L3400
if (
self.quant_config
and self.quant_config.get_name() == "gguf"
and name.endswith("norm.weight")
):
loaded_weight -= 1

if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
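The hunk deleted above was the old workaround: subtract the baked-in +1 at load time so that GemmaRMSNorm could re-add it. A small sketch of the equivalence that makes the workaround unnecessary once plain RMSNorm is selected (simplified tensors, not vLLM's actual weight loader):

import torch

w_gguf = torch.tensor([1.20, 0.95, 1.05])  # norm weight with +1 baked in

# Old path: revert the offset at load time, then GemmaRMSNorm applies (1 + w).
old_effective_scale = 1.0 + (w_gguf - 1.0)
# New path: load the weight untouched; plain RMSNorm applies w directly.
new_effective_scale = w_gguf

assert torch.allclose(old_effective_scale, new_effective_scale)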
12 changes: 12 additions & 0 deletions vllm/tokenizers/hf.py
@@ -7,12 +7,16 @@

from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast

from vllm.logger import init_logger
from vllm.transformers_utils.config import get_sentence_transformer_tokenizer_config
from vllm.transformers_utils.gguf_utils import maybe_patch_gguf_tokenizer

from .protocol import TokenizerLike

HfTokenizer: TypeAlias = PreTrainedTokenizer | PreTrainedTokenizerFast

logger = init_logger(__name__)


def get_cached_tokenizer(tokenizer: HfTokenizer) -> HfTokenizer:
"""
@@ -81,6 +85,9 @@ def from_pretrained(
download_dir: str | None = None,
**kwargs,
) -> HfTokenizer:
# Save gguf_file before AutoTokenizer.from_pretrained() pops it from kwargs
gguf_file = kwargs.get("gguf_file")

try:
tokenizer = AutoTokenizer.from_pretrained(
path_or_repo_id,
@@ -122,4 +129,9 @@ def from_pretrained(
}
tokenizer.add_special_tokens(special_tokens_map)

# Patch tokenizer EOS from GGUF metadata when applicable
# (gguf_file was saved above before AutoTokenizer.from_pretrained()
# popped it from kwargs).
maybe_patch_gguf_tokenizer(tokenizer, path_or_repo_id, gguf_file)

return get_cached_tokenizer(tokenizer)
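A hypothetical sketch of the ordering pitfall the new comments in from_pretrained() guard against: once a loader consumes gguf_file from a shared kwargs dict, the key is gone, so the value must be read up front. The fake_loader below is a stand-in, not the transformers API.

def fake_loader(opts: dict) -> None:
    # Stand-in for a loader that pops recognized keys from a shared dict.
    opts.pop("gguf_file", None)

opts = {"gguf_file": "model.Q4_K_M.gguf"}  # illustrative filename
gguf_file = opts.get("gguf_file")  # read, not pop, before the loader runs
fake_loader(opts)
assert gguf_file == "model.Q4_K_M.gguf" and "gguf_file" not in opts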
97 changes: 97 additions & 0 deletions vllm/transformers_utils/gguf_utils.py
@@ -255,6 +255,103 @@ def extract_vision_config_from_gguf(mmproj_path: str) -> "SiglipVisionConfig | N
return config


def extract_eos_token_id_from_gguf(model: str) -> int | None:
"""Extract EOS token ID from GGUF metadata.

GGUF files store the EOS token ID in the ``tokenizer.ggml.eos_token_id``
field. This may differ from the HuggingFace tokenizer config (e.g., Gemma
models use <end_of_turn> token ID 106 as EOS in GGUF, while the HF
tokenizer reports <eos> token ID 1).

Args:
model: Path to GGUF model file

Returns:
EOS token ID from GGUF metadata, or None if not found
"""
# Note: We don't check for .gguf extension here because HuggingFace Hub
# stores GGUF files as blob hashes without extensions. The caller is
# responsible for ensuring this is a valid GGUF file (via check_gguf_file).
try:
model_path = Path(model)
if not model_path.is_file():
return None

reader = gguf.GGUFReader(str(model_path))

eos_field = reader.get_field(Keys.Tokenizer.EOS_ID)
if eos_field is not None:
eos_token_id = int(eos_field.parts[-1][0])
logger.debug(
"Extracted eos_token_id=%d from GGUF metadata",
eos_token_id,
)
return eos_token_id

return None

except Exception as e:
logger.debug("Error extracting EOS token ID from GGUF: %s", e)
return None
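A hedged usage sketch for the helper above; the path is hypothetical, and callers are expected to have validated the file (e.g., via check_gguf_file):

eos_id = extract_eos_token_id_from_gguf("/models/gemma-2-9b-it.Q4_K_M.gguf")
if eos_id is not None:
    # For Gemma GGUF exports this is typically 106 (<end_of_turn>).
    print("GGUF EOS token id:", eos_id)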


def maybe_patch_gguf_tokenizer(
tokenizer,
path_or_repo_id: str | Path,
gguf_file: str | None,
) -> None:
"""Patch ``tokenizer.eos_token_id`` from GGUF metadata when available.

GGUF files may store a different EOS token ID than the HuggingFace
tokenizer config (e.g., Gemma uses ``<end_of_turn>`` ID 106 as EOS in
GGUF, but HF reports ``<eos>`` ID 1). This helper mutates
``tokenizer.eos_token_id`` in place to match the GGUF metadata.

Accepts either a local directory path or a HuggingFace repo id for
``path_or_repo_id``; remote repo ids resolve to a local file via
``hf_hub_download``.

Safe to call unconditionally: this is a no-op when ``gguf_file`` is
falsy, the file cannot be resolved, or the GGUF metadata has no EOS
token id.
"""
if not gguf_file:
return

candidate = Path(path_or_repo_id) / gguf_file
if candidate.is_file():
gguf_path = candidate
else:
# Treat path_or_repo_id as a HuggingFace repo id
from huggingface_hub import hf_hub_download

try:
gguf_path = Path(
hf_hub_download(repo_id=str(path_or_repo_id), filename=gguf_file)
)
except Exception as e:
logger.debug(
"Could not resolve GGUF file %s from %s: %s",
gguf_file,
path_or_repo_id,
e,
)
return

gguf_eos_id = extract_eos_token_id_from_gguf(str(gguf_path))
if gguf_eos_id is None:
return

hf_eos_id = tokenizer.eos_token_id
if hf_eos_id != gguf_eos_id:
logger.info(
"Patching tokenizer eos_token_id from %d to %d (using GGUF metadata)",
hf_eos_id,
gguf_eos_id,
)
tokenizer.eos_token_id = gguf_eos_id
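A usage sketch mirroring the call site added in vllm/tokenizers/hf.py; the repo id and filename are illustrative only:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")
maybe_patch_gguf_tokenizer(
    tokenizer,
    path_or_repo_id="google/gemma-2-9b-it",
    gguf_file="gemma-2-9b-it.Q4_K_M.gguf",
)
# If the GGUF metadata carries a different EOS id, tokenizer.eos_token_id
# now reflects it (e.g., 106 instead of 1 for Gemma).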


def maybe_patch_hf_config_from_gguf(
model: str,
hf_config: PretrainedConfig,