Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions conversion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
"Gemma3nForConditionalGeneration": "gemma",
"Gemma4ForConditionalGeneration": "gemma",
"Gemma4ForCausalLM": "gemma",
"Gemma4UnifiedForConditionalGeneration": "gemma",
"GemmaForCausalLM": "gemma",
"Glm4ForCausalLM": "glm",
"Glm4MoeForCausalLM": "glm",
Expand Down Expand Up @@ -247,6 +248,7 @@
"Gemma3ForConditionalGeneration": "gemma",
"Gemma3nForConditionalGeneration": "gemma",
"Gemma4ForConditionalGeneration": "gemma",
"Gemma4UnifiedForConditionalGeneration": "gemma",
"Glm4vForConditionalGeneration": "qwen3vl",
"Glm4vMoeForConditionalGeneration": "qwen3vl",
"GlmOcrForConditionalGeneration": "qwen3vl",
Expand Down
79 changes: 78 additions & 1 deletion conversion/gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
import re

from typing import Callable, Iterable, TYPE_CHECKING
from typing import Callable, Iterable, TYPE_CHECKING, Sequence

import torch

Expand Down Expand Up @@ -765,6 +765,26 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
yield from super().modify_tensors(data_torch, name, bid)


@ModelBase.register("Gemma4UnifiedForConditionalGeneration")
class Gemma4UnifiedModel(Gemma4Model):
    """Text-model converter for the "unified" Gemma 4 checkpoint.

    Behaves like ``Gemma4Model`` and additionally forwards the
    ``suppress_tokens`` entry of ``generation_config.json`` (when there is
    one) to the GGUF writer.
    """

    model_arch = gguf.MODEL_ARCH.GEMMA4

    def _get_suppress_tokens(self) -> Sequence[int] | None:
        """Return ``suppress_tokens`` from ``generation_config.json``.

        Returns ``None`` when the file does not exist in the model directory
        or when it has no ``suppress_tokens`` entry.
        """
        config_file = self.dir_model / "generation_config.json"
        if not config_file.is_file():
            return None

        with open(config_file, encoding="utf-8") as fp:
            generation_config = json.load(fp)
        return generation_config.get("suppress_tokens")

    def set_gguf_parameters(self):
        """Write the parent's parameters, then the suppress-token list if any."""
        super().set_gguf_parameters()

        token_ids = self._get_suppress_tokens()
        if token_ids is None:
            # nothing to record: the checkpoint does not ship a suppress list
            return
        self.gguf_writer.add_suppress_tokens(token_ids)


@ModelBase.register("Gemma4ForConditionalGeneration")
class Gemma4VisionAudioModel(MmprojModel):
has_audio_encoder = True
Expand Down Expand Up @@ -839,3 +859,60 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
data_torch = data_torch.permute(0, 3, 1, 2).contiguous()
mapped_name = self.map_tensor_name(name, (".weight", ".bias", ".input_max", ".input_min", ".output_max", ".output_min"))
yield (mapped_name, data_torch)

@ModelBase.register("Gemma4UnifiedForConditionalGeneration")
class Gemma4UnifiedVisionAudioModel(Gemma4VisionAudioModel):
    """Vision + audio converter for the "unified" Gemma 4 checkpoint.

    Extends ``Gemma4VisionAudioModel``: it fills in the hyper-parameters the
    parent expects, tags the output with the GEMMA4UV / GEMMA4UA projector
    types, and renames / re-orders a handful of patch-embedding tensors
    before delegating to the parent's ``modify_tensors``.
    """

    has_audio_encoder = True
    has_vision_encoder = True

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert self.hparams_vision is not None
        assert self.hparams_audio is not None
        # both towers take their hidden_size from the vision config's mm_embed_dim
        text_embd_dim = self.hparams_vision["mm_embed_dim"]
        self.hparams_vision["hidden_size"] = text_embd_dim
        self.hparams_audio["hidden_size"] = text_embd_dim
        # this is a transformer-less vision tower, the params below are redundant but set to avoid error
        for hparams in (self.hparams_vision, self.hparams_audio):
            hparams["intermediate_size"] = 0
            hparams["num_layers"] = 0
            hparams["num_attention_heads"] = 0

    def set_gguf_parameters(self):
        """Write the parent's parameters, then the unified projector types."""
        super().set_gguf_parameters()
        self.gguf_writer.add_clip_vision_projector_type(gguf.VisionProjectorType.GEMMA4UV)
        self.gguf_writer.add_clip_audio_projector_type(gguf.VisionProjectorType.GEMMA4UA)

    def _patch_chw_to_hwc_perm(self) -> torch.Tensor:
        """Return the index tensor mapping CHW patch positions to HWC columns.

        ggml im2col outputs in RR..GG..BB.. (CHW) order, but the checkpoint's
        patch weights expect RGBRGB.. (HWC). ``perm[i]`` is the HWC index that
        belongs at CHW position ``i``, for a patch of
        ``model_patch_size x model_patch_size`` pixels with 3 channels.
        """
        assert self.hparams_vision is not None
        p = self.hparams_vision["model_patch_size"]
        i = torch.arange(p * p * 3)
        ch = i // (p * p)
        row = (i % (p * p)) // p
        col = i % p
        return row * p * 3 + col * 3 + ch

    def modify_tensors(self, data_torch, name, bid):
        """Rename / permute unified-specific tensors, then defer to the parent.

        Returns whatever the parent's ``modify_tensors`` returns for the
        adjusted ``(data_torch, name, bid)``.
        """
        if name.endswith("pos_embedding"):
            # give the bare parameter a ".weight" suffix and swap its first two axes
            name += ".weight"
            data_torch = data_torch.permute(1, 0, 2)
        elif ".pos_norm." in name:
            # rename to patch_ln3 to reuse the tensor name scheme
            name = name.replace(".pos_norm.", ".patch_ln3.")
        elif "patch_dense.weight" in name:
            # permute columns so column i aligns with CHW input position i
            data_torch = data_torch[:, self._patch_chw_to_hwc_perm()]
        elif "patch_ln1.weight" in name or "patch_ln1.bias" in name:
            # same permutation for patch_ln1 as patch_dense to align with CHW input order
            data_torch = data_torch[self._patch_chw_to_hwc_perm()]
        return super().modify_tensors(data_torch, name, bid)
6 changes: 6 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ class Tokenizer:
ADD_PREFIX = "tokenizer.ggml.add_space_prefix"
REMOVE_EXTRA_WS = "tokenizer.ggml.remove_extra_whitespaces"
PRECOMPILED_CHARSMAP = "tokenizer.ggml.precompiled_charsmap"
SUPPRESS_TOKENS = "tokenizer.ggml.suppress_tokens"
HF_JSON = "tokenizer.huggingface.json"
RWKV = "tokenizer.rwkv.world"
CHAT_TEMPLATE = "tokenizer.chat_template"
Expand Down Expand Up @@ -731,6 +732,7 @@ class MODEL_TENSOR(IntEnum):
V_ENC_EMBD_CLS = auto()
V_ENC_EMBD_PATCH = auto()
V_ENC_EMBD_NORM = auto()
V_ENC_EMBD_PATCH_NORM = auto() # allow multiple norms in the same embd, e.g. for gemma4u
V_ENC_EMBD_POS = auto()
V_ENC_INPUT_NORM = auto()
V_ENC_ATTN_QKV = auto()
Expand Down Expand Up @@ -1250,6 +1252,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.V_ENC_EMBD_CLS: "v.class_embd",
MODEL_TENSOR.V_ENC_EMBD_PATCH: "v.patch_embd",
MODEL_TENSOR.V_ENC_EMBD_NORM: "v.norm_embd",
MODEL_TENSOR.V_ENC_EMBD_PATCH_NORM: "v.patch_norm.{bid}",
MODEL_TENSOR.V_ENC_EMBD_POS: "v.position_embd",
MODEL_TENSOR.V_ENC_ATTN_QKV: "v.blk.{bid}.attn_qkv",
MODEL_TENSOR.V_ENC_ATTN_Q: "v.blk.{bid}.attn_q",
Expand Down Expand Up @@ -1431,6 +1434,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.V_ENC_EMBD_CLS,
MODEL_TENSOR.V_ENC_EMBD_PATCH,
MODEL_TENSOR.V_ENC_EMBD_NORM,
MODEL_TENSOR.V_ENC_EMBD_PATCH_NORM,
MODEL_TENSOR.V_ENC_EMBD_POS,
MODEL_TENSOR.V_ENC_EMBD_IMGNL,
MODEL_TENSOR.V_ENC_EMBD_VSEP,
Expand Down Expand Up @@ -4346,6 +4350,8 @@ class VisionProjectorType:
GEMMA3NA = "gemma3na"
GEMMA4V = "gemma4v"
GEMMA4A = "gemma4a"
GEMMA4UV = "gemma4uv" # "unified" variant
GEMMA4UA = "gemma4ua" # "unified" variant
PHI4 = "phi4"
IDEFICS3 = "idefics3"
PIXTRAL = "pixtral"
Expand Down
3 changes: 3 additions & 0 deletions gguf-py/gguf/gguf_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1113,6 +1113,9 @@ def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None:

self.add_string(Keys.Tokenizer.CHAT_TEMPLATE, value)

def add_suppress_tokens(self, tokens: Sequence[int]) -> None:
    """Store *tokens* as an array under ``Keys.Tokenizer.SUPPRESS_TOKENS``."""
    key = Keys.Tokenizer.SUPPRESS_TOKENS
    self.add_array(key, tokens)

def add_normalizer_lowercase(self, value: bool) -> None:
self.add_bool(Keys.Tokenizer.NORMALIZER_LOWERCASE, value)

Expand Down
6 changes: 6 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -1426,13 +1426,18 @@ class TensorNameMap:
"model.vision_tower.patch_embedder.input_proj", # gemma4
"vision_tower.patch_embed.patchifier.proj", # dots.ocr
"vision_model.conv1", # Step3-VL
"model.vision_embedder.patch_dense", # gemma4 unified
),

MODEL_TENSOR.V_ENC_EMBD_NORM: (
"visual.post_conv_layernorm", # glm4v
"vision_tower.patch_embed.patchifier.norm", # dots.ocr
),

MODEL_TENSOR.V_ENC_EMBD_PATCH_NORM: (
"model.vision_embedder.patch_ln{bid}", # gemma4 unified
),

MODEL_TENSOR.V_ENC_EMBD_POS: (
"vision_tower.vision_model.embeddings.position_embedding",
"model.vision_tower.embeddings.position_embedding", # minicpmv4_6
Expand All @@ -1448,6 +1453,7 @@ class TensorNameMap:
"vision_model.radio_model.model.patch_generator.pos_embed", # Nemotron Nano v2 VL
"model.vision_tower.patch_embedder.position_embedding_table", # gemma4
"vision_model.positional_embedding", # Step3-VL
"model.vision_embedder.pos_embedding", # gemma4 unified
),

MODEL_TENSOR.V_ENC_EMBD_IMGNL: (
Expand Down
1 change: 1 addition & 0 deletions src/llama-arch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_TOKENIZER_FIM_PAD_ID, "tokenizer.ggml.fim_pad_token_id" },
{ LLM_KV_TOKENIZER_FIM_REP_ID, "tokenizer.ggml.fim_rep_token_id" },
{ LLM_KV_TOKENIZER_FIM_SEP_ID, "tokenizer.ggml.fim_sep_token_id" },
{ LLM_KV_TOKENIZER_SUPPRESS_TOKENS, "tokenizer.ggml.suppress_tokens" },

{ LLM_KV_ADAPTER_TYPE, "adapter.type" },
{ LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" },
Expand Down
1 change: 1 addition & 0 deletions src/llama-arch.h
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ enum llm_kv {
LLM_KV_TOKENIZER_FIM_PAD_ID,
LLM_KV_TOKENIZER_FIM_REP_ID,
LLM_KV_TOKENIZER_FIM_SEP_ID,
LLM_KV_TOKENIZER_SUPPRESS_TOKENS,

LLM_KV_ADAPTER_TYPE,
LLM_KV_ADAPTER_LORA_ALPHA,
Expand Down
16 changes: 16 additions & 0 deletions src/llama-vocab.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1815,6 +1815,8 @@ struct llama_vocab::impl {
// set of all tokens that cause "end of generation"
std::set<llama_token> special_eog_ids;

std::vector<llama_token> suppress_tokens;

std::unique_ptr<llm_tokenizer> tokenizer;

std::vector<char> precompiled_charsmap;
Expand Down Expand Up @@ -2533,6 +2535,16 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
// Lowercase normalizer flag (consulted by WPM / whitespace BPE)
ml.get_key(LLM_KV_TOKENIZER_NORMALIZER_LOWERCASE, normalizer_lowercase, false);

// suppress tokens
{
const int suppress_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SUPPRESS_TOKENS).c_str());
if (suppress_idx != -1) {
const int n = gguf_get_arr_n(ctx, suppress_idx);
const int32_t * data = (const int32_t *) gguf_get_arr_data(ctx, suppress_idx);
suppress_tokens.assign(data, data + n);
}
}

// auto-detect special tokens by text
// TODO: convert scripts should provide these tokens through the KV metadata LLM_KV_TOKENIZER_...
// for now, we apply this workaround to find the tokens based on their text
Expand Down Expand Up @@ -3961,6 +3973,10 @@ bool llama_vocab::get_normalizer_lowercase() const {
return pimpl->normalizer_lowercase;
}

// Returns a reference to the suppress-token list stored in the vocab implementation.
// The list is empty unless it was filled elsewhere (e.g. while loading the vocab).
const std::vector<llama_token> & llama_vocab::get_suppress_tokens() const {
return pimpl->suppress_tokens;
}

int llama_vocab::max_token_len() const {
return pimpl->max_token_len;
}
Expand Down
2 changes: 2 additions & 0 deletions src/llama-vocab.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ struct llama_vocab {
bool get_treat_whitespace_as_suffix() const;
bool get_normalizer_lowercase () const;

const std::vector<llama_token> & get_suppress_tokens() const;

int max_token_len() const;

int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
Expand Down
35 changes: 35 additions & 0 deletions src/models/gemma4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,31 @@ static ggml_tensor * ggml_view_2d_slice(ggml_context * ctx0, ggml_tensor * x, in
idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
}

// TODO @ngxson : maybe improve this in the future
class llm_graph_input_logits_bias : public llm_graph_input_i {
public:
llm_graph_input_logits_bias(const llama_vocab & vocab) {
arr.resize(vocab.n_tokens(), 0.0f);
for (llama_token id : vocab.get_suppress_tokens()) {
if (0 <= id && id < (int32_t)vocab.n_tokens()) {
arr[id] = -INFINITY;
}
}
}
virtual ~llm_graph_input_logits_bias() = default;

void set_input(const llama_ubatch *) override {
const int64_t n_vocab = arr.size();
ggml_backend_tensor_set(logits_bias, arr.data(), 0, n_vocab*ggml_element_size(logits_bias));
}

// bool can_reuse(const llm_graph_params & params) override;

ggml_tensor * logits_bias = nullptr; // F32 [n_vocab]

std::vector<float> arr;
};

llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params),
model(model),
Expand Down Expand Up @@ -388,6 +413,16 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para
cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
}

// apply logits bias if needed (e.g. for gemma4_unified patch)
// this mirrors the suppress_tokens patch in transformers, which prevents the model from outputting <image|> and <audio|> tokens (a known issue with the checkpoint)
// TODO: maybe handle this inside the sampling system in the future
if (!model.vocab.get_suppress_tokens().empty()) {
auto inp_bias = std::make_unique<llm_graph_input_logits_bias>(model.vocab);
inp_bias->logits_bias = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, inp_bias->arr.size());
cur = ggml_add(ctx0, cur, inp_bias->logits_bias);
res->add_input(std::move(inp_bias));
}

cb(cur, "result_output", -1);
res->t_logits = cur;

Expand Down
2 changes: 2 additions & 0 deletions tools/mtmd/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ add_library(mtmd
models/exaone4_5.cpp
models/gemma4a.cpp
models/gemma4v.cpp
models/gemma4ua.cpp
models/gemma4uv.cpp
models/glm4v.cpp
models/granite-speech.cpp
models/hunyuanvl.cpp
Expand Down
5 changes: 5 additions & 0 deletions tools/mtmd/clip-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
#define TN_PATCH_EMBD_1 "v.patch_embd.weight.1"
#define TN_PATCH_BIAS "v.patch_embd.bias"
#define TN_NORM_EMBD "v.norm_embd.%s"
#define TN_PATCH_NORM "v.patch_norm.%d.%s"
#define TN_ATTN_QKV "%s.blk.%d.attn_qkv.%s"
#define TN_ATTN_K "%s.blk.%d.attn_k.%s"
#define TN_ATTN_Q "%s.blk.%d.attn_q.%s"
Expand Down Expand Up @@ -317,6 +318,8 @@ enum projector_type {
PROJECTOR_TYPE_GEMMA3NA,
PROJECTOR_TYPE_GEMMA4V,
PROJECTOR_TYPE_GEMMA4A,
PROJECTOR_TYPE_GEMMA4UV,
PROJECTOR_TYPE_GEMMA4UA,
PROJECTOR_TYPE_PHI4,
PROJECTOR_TYPE_IDEFICS3,
PROJECTOR_TYPE_PIXTRAL,
Expand Down Expand Up @@ -369,6 +372,8 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_GEMMA3NA, "gemma3na"},
{ PROJECTOR_TYPE_GEMMA4V, "gemma4v"},
{ PROJECTOR_TYPE_GEMMA4A, "gemma4a"},
{ PROJECTOR_TYPE_GEMMA4UV, "gemma4uv"},
{ PROJECTOR_TYPE_GEMMA4UA, "gemma4ua"},
{ PROJECTOR_TYPE_PHI4, "phi4"},
{ PROJECTOR_TYPE_IDEFICS3, "idefics3"},
{ PROJECTOR_TYPE_PIXTRAL, "pixtral"},
Expand Down
8 changes: 8 additions & 0 deletions tools/mtmd/clip-model.h
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,14 @@ struct clip_model {
ggml_tensor * norm_embd_w = nullptr;
ggml_tensor * norm_embd_b = nullptr;

// "indexed" patch embedding norms
ggml_tensor * patch_norm_1_w = nullptr;
ggml_tensor * patch_norm_1_b = nullptr;
ggml_tensor * patch_norm_2_w = nullptr;
ggml_tensor * patch_norm_2_b = nullptr;
ggml_tensor * patch_norm_3_w = nullptr;
ggml_tensor * patch_norm_3_b = nullptr;

ggml_tensor * pre_ln_w = nullptr;
ggml_tensor * pre_ln_b = nullptr;

Expand Down
Loading
Loading