
Commit 7841fc7

llama : Add Gemma 3 support (+ experimental vision capability) (ggml-org#12343)
* llama : Add Gemma 3 text-only support
* fix python coding style
* fix compile on ubuntu
* python: fix style
* fix ubuntu compile
* fix build on ubuntu (again)
* fix ubuntu build, finally
* clip : Experimental support for Gemma 3 vision (ggml-org#12344)
* clip : Experimental support for Gemma 3 vision
* fix build
* PRId64
1 parent bf69cfe commit 7841fc7

11 files changed: +1202 −10 lines

convert_hf_to_gguf.py

+80 lines
```diff
@@ -861,6 +861,9 @@ def _create_vocab_sentencepiece(self):
         for token_id, token_data in added_tokens_decoder.items():
             token_id = int(token_id)
             token: str = token_data["content"]
+            if token_id >= vocab_size:
+                logger.warning(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}')
+                continue
             if toktypes[token_id] != SentencePieceTokenTypes.UNUSED:
                 if tokens[token_id] != token.encode("utf-8"):
                     logger.warning(f'replacing token {token_id}: {tokens[token_id].decode("utf-8")!r} -> {token!r}')
```
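The new guard skips added tokens whose IDs fall outside the declared SentencePiece vocabulary instead of indexing past the end of the token list. A minimal standalone sketch of the same pattern, using hypothetical toy values rather than Gemma 3's real tokenizer data:

```python
# Sketch of the guard above; vocab_size, tokens and added_tokens_decoder
# are hypothetical toy values, not the actual Gemma 3 vocabulary.
vocab_size = 4
tokens = [b"<pad>", b"<bos>", b"<eos>", b"hello"]
added_tokens_decoder = {"2": {"content": "<eos>"}, "7": {"content": "<image>"}}

for token_id, token_data in added_tokens_decoder.items():
    token_id = int(token_id)
    token = token_data["content"]
    if token_id >= vocab_size:
        # out-of-range added token: warn and skip instead of crashing
        print(f"ignore token {token_id}: id is out of range, max={vocab_size - 1}")
        continue
    tokens[token_id] = token.encode("utf-8")

print(tokens)  # in-range added tokens applied, out-of-range ones skipped
```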
```diff
@@ -3322,6 +3325,83 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         return [(self.map_tensor_name(name), data_torch)]
 
 
+@Model.register("Gemma3ForCausalLM", "Gemma3ForConditionalGeneration")
+class Gemma3Model(Model):
+    model_arch = gguf.MODEL_ARCH.GEMMA3
+    has_vision: bool = False
+
+    # we need to merge the text_config into the root level of hparams
+    def __init__(self, *args, **kwargs):
+        hparams = Model.load_hparams(kwargs["dir_model"])
+        if "text_config" in hparams:
+            hparams = {**hparams, **hparams["text_config"]}
+        kwargs["hparams"] = hparams
+        super().__init__(*args, **kwargs)
+        if "vision_config" in hparams:
+            logger.info("Has vision encoder, but it will be ignored")
+            self.has_vision = True
+
+    def write(self):
+        super().write()
+        if self.has_vision:
+            logger.info("NOTE: this script only convert the language model to GGUF")
+            logger.info("      for the vision model, please use gemma3_convert_encoder_to_gguf.py")
+
+    def set_vocab(self):
+        self._set_vocab_sentencepiece()
+
+        self.gguf_writer.add_add_space_prefix(False)
+
+    def set_gguf_parameters(self):
+        hparams = self.hparams
+        block_count = hparams["num_hidden_layers"]
+
+        # some default values are not specified in the hparams
+        self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 131072))
+        self.gguf_writer.add_embedding_length(hparams["hidden_size"])
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
+        self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 8))
+        self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-6))
+        self.gguf_writer.add_key_length(hparams.get("head_dim", 256))
+        self.gguf_writer.add_value_length(hparams.get("head_dim", 256))
+        self.gguf_writer.add_file_type(self.ftype)
+        self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1_000_000.0)) # for global layers
+        # both attn_logit_softcapping and final_logit_softcapping are removed in Gemma3
+        assert hparams.get("attn_logit_softcapping") is None
+        assert hparams.get("final_logit_softcapping") is None
+        self.gguf_writer.add_sliding_window(hparams["sliding_window"])
+        self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4))
+        if hparams.get("rope_scaling") is not None:
+            assert hparams["rope_scaling"]["rope_type"] == "linear"
+            # important: this rope_scaling is only applied for global layers, and not used by 1B model
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+            self.gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"])
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid # unused
+
+        if name.startswith("language_model."):
+            name = name.replace("language_model.", "")
+        elif name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \
+            or name.startswith("multimodal_projector.") or name.startswith("vision_model."): # this is for old HF model, should be removed later
+            # ignore vision tensors
+            return []
+
+        # remove OOV (out-of-vocabulary) rows in token_embd
+        if "embed_tokens.weight" in name:
+            vocab = self._create_vocab_sentencepiece()
+            tokens = vocab[0]
+            data_torch = data_torch[:len(tokens)]
+
+        # ref code in Gemma3RMSNorm
+        # output = output * (1.0 + self.weight.float())
+        if name.endswith("norm.weight"):
+            data_torch = data_torch + 1
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+
 @Model.register("Starcoder2ForCausalLM")
 class StarCoder2Model(Model):
     model_arch = gguf.MODEL_ARCH.STARCODER2
```
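The `+ 1` applied to `*norm.weight` tensors mirrors the `Gemma3RMSNorm` reference quoted in the comment: the HF checkpoint stores the norm weight as an offset from 1 (`output * (1.0 + weight)`), so folding the offset into the exported tensor lets a plain `output * weight` normalization reproduce the same result. A small numerical sketch of that equivalence (the `rmsnorm` helper here is a simplified stand-in, not llama.cpp's actual kernel):

```python
import torch

def rmsnorm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # simplified RMS normalization, used only for this comparison
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

x = torch.randn(4, 8)
w = torch.randn(8)  # norm weight as stored in the HF checkpoint

hf_style = rmsnorm(x) * (1.0 + w)  # HF Gemma3RMSNorm: output * (1.0 + weight)
converted = rmsnorm(x) * (w + 1)   # plain output * weight, with the exported (w + 1) tensor

print(torch.allclose(hf_style, converted))  # True: folding +1 preserves the math
```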

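The `__init__` override, further up in the same class, flattens `text_config` into the top-level hparams so the rest of the converter reads fields such as `hidden_size` the same way for both `Gemma3ForCausalLM` and `Gemma3ForConditionalGeneration` checkpoints. A minimal sketch of that merge, using hypothetical config values:

```python
# Hypothetical config values, only to show the dict merge; real checkpoints
# carry many more fields in both text_config and vision_config.
hparams = {
    "architectures": ["Gemma3ForConditionalGeneration"],
    "text_config": {"hidden_size": 2560, "num_hidden_layers": 34},
    "vision_config": {"hidden_size": 1152},
}

if "text_config" in hparams:
    # later keys win, so text_config values end up at the root level
    hparams = {**hparams, **hparams["text_config"]}

print(hparams["hidden_size"])      # 2560, pulled up from text_config
print("vision_config" in hparams)  # True: a vision encoder is detected but skipped here
```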
examples/llava/CMakeLists.txt

+7 lines
```diff
@@ -51,6 +51,13 @@ install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_17)
 
+set(TARGET llama-gemma3-cli)
+add_executable(${TARGET} gemma3-cli.cpp)
+set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-gemma3-cli)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_17)
+
 set(TARGET llama-llava-clip-quantize-cli)
 add_executable(${TARGET} clip-quantize-cli.cpp)
 set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-llava-clip-quantize-cli)
```

examples/llava/README-gemma3.md

+30 lines (new file)
# Gemma 3 vision

> [!IMPORTANT]
>
> This is very experimental and intended for demo purposes only.

## How to get mmproj.gguf?

```bash
cd gemma-3-4b-it
python ../llama.cpp/examples/llava/gemma3_convert_encoder_to_gguf.py .

# output file is mmproj.gguf
```

## How to run it?

What you need:
- The text model in GGUF format, which can be converted using `convert_hf_to_gguf.py`
- The mmproj file from the step above
- An image file

```bash
# build
cmake -B build
cmake --build build --target llama-gemma3-cli

# run it
./build/bin/llama-gemma3-cli -m {text_model}.gguf --mmproj mmproj.gguf --image your_image.jpg
```
