Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3728,6 +3728,13 @@ class Ernie4_5Model(TextModel):
def set_vocab(self):
self._set_vocab_sentencepiece()

tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
if tokenizer_config_file.is_file():
with open(tokenizer_config_file, "r", encoding="utf-8") as f:
tokenizer_config_json = json.load(f)
if "add_prefix_space" in tokenizer_config_json:
self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"])

def set_gguf_parameters(self):
super().set_gguf_parameters()

Expand All @@ -3737,6 +3744,10 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
if (head_dim := self.hparams.get("head_dim")) is None:
head_dim = self.hparams["hidden_size"] // num_heads

if "mlp_AR" in name or "vision_model" in name:
# skip vision model and projector tensors
return []

if "ernie." in name:
name = name.replace("ernie.", "model.")
# split the qkv weights
Expand Down Expand Up @@ -3855,6 +3866,49 @@ def prepare_tensors(self):
raise ValueError(f"Unprocessed experts: {experts}")


@ModelBase.register("PaddleOCRVLForConditionalGeneration")
class PaddleOCRModel(Ernie4_5Model):
model_arch = gguf.MODEL_ARCH.PADDLEOCR


@ModelBase.register("PaddleOCRVisionModel")
class PaddleOCRVisionModel(MmprojModel):
# PaddleOCR-VL uses a modified version of Siglip
min_pixels: int = 0
max_pixels: int = 0

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.hparams_vision is not None
self.min_pixels = self.preprocessor_config["min_pixels"]
self.max_pixels = self.preprocessor_config["max_pixels"]
self.hparams_vision["image_size"] = int(math.sqrt(self.max_pixels))

def set_gguf_parameters(self):
super().set_gguf_parameters()
assert self.hparams_vision is not None
hparams = self.hparams_vision
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.PADDLEOCR)
self.gguf_writer.add_vision_max_pixels(self.max_pixels)
self.gguf_writer.add_vision_min_pixels(self.min_pixels)
self.gguf_writer.add_vision_use_gelu(True)
self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("rms_norm_eps", 1e-6))

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
name = name.replace("visual.", "model.")

if "vision_model" in name or "mlp_AR" in name:
if "packing_position_embedding" in name:
return [] # unused
elif "vision_model.head" in name:
# we don't yet support image embeddings for this model
return []
else:
return [(self.map_tensor_name(name), data_torch)]
return [] # skip other tensors


@ModelBase.register(
"Qwen2VLModel",
"Qwen2VLForConditionalGeneration",
Expand Down
19 changes: 19 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,8 @@ class Clip:
class ClipVision:
PROJECTOR_TYPE = "clip.vision.projector_type" # for mixed modality models
IMAGE_SIZE = "clip.vision.image_size"
MAX_PIXELS = "clip.vision.max_pixels"
MIN_PIXELS = "clip.vision.min_pixels"
PREPROC_IMAGE_SIZE = "clip.vision.preproc_image_size"
PATCH_SIZE = "clip.vision.patch_size"
EMBEDDING_LENGTH = "clip.vision.embedding_length"
Expand Down Expand Up @@ -456,6 +458,7 @@ class MODEL_ARCH(IntEnum):
RND1 = auto()
PANGU_EMBED = auto()
MISTRAL3 = auto()
PADDLEOCR = auto()
MIMO2 = auto()
LLAMA_EMBED = auto()
MAINCODER = auto()
Expand Down Expand Up @@ -877,6 +880,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.RND1: "rnd1",
MODEL_ARCH.PANGU_EMBED: "pangu-embedded",
MODEL_ARCH.MISTRAL3: "mistral3",
MODEL_ARCH.PADDLEOCR: "paddleocr",
MODEL_ARCH.MIMO2: "mimo2",
MODEL_ARCH.LLAMA_EMBED: "llama-embed",
MODEL_ARCH.MAINCODER: "maincoder",
Expand Down Expand Up @@ -3016,6 +3020,20 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.PADDLEOCR: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.FALCON_H1: [
# Token embedding
MODEL_TENSOR.TOKEN_EMBD,
Expand Down Expand Up @@ -3610,6 +3628,7 @@ class VisionProjectorType:
VOXTRAL = "voxtral"
LFM2 = "lfm2"
KIMIVL = "kimivl"
PADDLEOCR = "paddleocr"
LIGHTONOCR = "lightonocr"
COGVLM = "cogvlm"
JANUS_PRO = "janus_pro"
Expand Down
6 changes: 6 additions & 0 deletions gguf-py/gguf/gguf_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1098,6 +1098,12 @@ def add_vision_patch_size(self, value: int) -> None:
def add_vision_embedding_length(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.EMBEDDING_LENGTH, value)

def add_vision_max_pixels(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.MAX_PIXELS, value)

def add_vision_min_pixels(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.MIN_PIXELS, value)

def add_vision_feed_forward_length(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.FEED_FORWARD_LENGTH, value)

Expand Down
3 changes: 3 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -1256,6 +1256,7 @@ class TensorNameMap:
MODEL_TENSOR.V_MMPROJ: (
"multi_modal_projector.linear_{bid}",
"visual.merger.mlp.{bid}", # qwen2vl
"mlp_AR.linear_{bid}", # PaddleOCR-VL
"merger.mlp.{bid}",
),

Expand Down Expand Up @@ -1492,6 +1493,7 @@ class TensorNameMap:
"multi_modal_projector.pre_norm",
"pre_mm_projector_norm",
"model.vision.linear_proj.norm1", # cogvlm
"mlp_AR.pre_norm", # PaddleOCR-VL
"merger.ln_q",
),

Expand All @@ -1517,6 +1519,7 @@ class TensorNameMap:

MODEL_TENSOR.V_RESMPL_ATTN_OUT: (
"resampler.attn.out_proj",
"model.vision_model.head.attention.out_proj",
),

MODEL_TENSOR.V_RESMPL_KV: (
Expand Down
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ add_library(llama
models/dream.cpp
models/ernie4-5-moe.cpp
models/ernie4-5.cpp
models/paddleocr.cpp
models/exaone.cpp
models/exaone4.cpp
models/exaone-moe.cpp
Expand Down
2 changes: 2 additions & 0 deletions src/llama-arch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_RND1, "rnd1" },
{ LLM_ARCH_PANGU_EMBED, "pangu-embedded" },
{ LLM_ARCH_MISTRAL3, "mistral3" },
{ LLM_ARCH_PADDLEOCR, "paddleocr" },
{ LLM_ARCH_MIMO2, "mimo2" },
{ LLM_ARCH_LLAMA_EMBED, "llama-embed" },
{ LLM_ARCH_MAINCODER, "maincoder" },
Expand Down Expand Up @@ -710,6 +711,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
case LLM_ARCH_INTERNLM2:
case LLM_ARCH_GRANITE:
case LLM_ARCH_ERNIE4_5:
case LLM_ARCH_PADDLEOCR:
case LLM_ARCH_SMOLLM3:
case LLM_ARCH_DREAM:
case LLM_ARCH_LLADA:
Expand Down
1 change: 1 addition & 0 deletions src/llama-arch.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ enum llm_arch {
LLM_ARCH_RND1,
LLM_ARCH_PANGU_EMBED,
LLM_ARCH_MISTRAL3,
LLM_ARCH_PADDLEOCR,
LLM_ARCH_MIMO2,
LLM_ARCH_LLAMA_EMBED,
LLM_ARCH_MAINCODER,
Expand Down
10 changes: 10 additions & 0 deletions src/llama-model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2167,7 +2167,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
} break;
case LLM_ARCH_ERNIE4_5:
case LLM_ARCH_ERNIE4_5_MOE:
case LLM_ARCH_PADDLEOCR:
{
// paddleocr need mrope_section
ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false);

ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
if (arch == LLM_ARCH_ERNIE4_5_MOE) {
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
Expand Down Expand Up @@ -6271,6 +6275,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
} break;
case LLM_ARCH_ERNIE4_5:
case LLM_ARCH_ERNIE4_5_MOE:
case LLM_ARCH_PADDLEOCR:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

Expand Down Expand Up @@ -7995,6 +8000,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
{
llm = std::make_unique<llm_build_ernie4_5_moe>(*this, params);
} break;
case LLM_ARCH_PADDLEOCR:
{
llm = std::make_unique<llm_build_paddleocr>(*this, params);
} break;
case LLM_ARCH_HUNYUAN_MOE:
{
llm = std::make_unique<llm_build_hunyuan_moe>(*this, params);
Expand Down Expand Up @@ -8307,6 +8316,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
return LLAMA_ROPE_TYPE_NEOX;

case LLM_ARCH_QWEN2VL:
case LLM_ARCH_PADDLEOCR:
return LLAMA_ROPE_TYPE_MROPE;
case LLM_ARCH_QWEN3VL:
case LLM_ARCH_QWEN3VLMOE:
Expand Down
1 change: 1 addition & 0 deletions src/llama-vocab.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2418,6 +2418,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|| t.first == "<|calls|>" // solar-open
|| t.first == "<end_of_turn>"
|| t.first == "<|endoftext|>"
|| t.first == "</s>" // paddleocr
|| t.first == "<|eom_id|>"
|| t.first == "<EOT>"
|| t.first == "_<EOT>"
Expand Down
4 changes: 4 additions & 0 deletions src/models/models.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,10 @@ struct llm_build_ernie4_5_moe : public llm_graph_context {
llm_build_ernie4_5_moe(const llama_model & model, const llm_graph_params & params);
};

struct llm_build_paddleocr : public llm_graph_context {
llm_build_paddleocr(const llama_model & model, const llm_graph_params & params);
};

template <bool iswa>
struct llm_build_exaone4 : public llm_graph_context {
llm_build_exaone4(const llama_model & model, const llm_graph_params & params);
Expand Down
119 changes: 119 additions & 0 deletions src/models/paddleocr.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
#include "models.h"

llm_build_paddleocr::llm_build_paddleocr(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;

GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);

ggml_tensor * cur;
ggml_tensor * inpL;

inpL = build_inp_embd(model.tok_embd);

int sections[4];
std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);

// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;

// norm
{
cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
}
// self-attention
{
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
cb(Qcur, "Qcur", il);
}
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
cb(Kcur, "Kcur", il);
}
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
cb(Vcur, "Vcur", il);
}
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

Qcur = ggml_rope_multi(
ctx0, Qcur, inp_pos, nullptr,
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);

Kcur = ggml_rope_multi(
ctx0, Kcur, inp_pos, nullptr,
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);

cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);

cur = build_attn(inp_attn,
model.layers[il].wo, NULL,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);

// feed-forward network
{
cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);

cur = build_ffn(cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
}
cur = ggml_add(ctx0, cur, ffn_inp);

cur = build_cvec(cur, il);
cb(cur, "l_out", il);

// input for next layer
inpL = cur;
}
cur = inpL;

cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);

cb(cur, "result_norm", -1);
res->t_embd = cur;

// lm_head
cur = build_lora_mm(model.output, cur);

cb(cur, "result_output", -1);
res->t_logits = cur;

ggml_build_forward_expand(gf, cur);
}
1 change: 1 addition & 0 deletions tools/mtmd/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ add_library(mtmd
models/llama4.cpp
models/llava.cpp
models/minicpmv.cpp
models/paddleocr.cpp
models/pixtral.cpp
models/qwen2vl.cpp
models/qwen3vl.cpp
Expand Down
2 changes: 2 additions & 0 deletions tools/mtmd/clip-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ enum projector_type {
PROJECTOR_TYPE_MUSIC_FLAMINGO,
PROJECTOR_TYPE_LFM2,
PROJECTOR_TYPE_KIMIVL,
PROJECTOR_TYPE_PADDLEOCR,
PROJECTOR_TYPE_LIGHTONOCR,
PROJECTOR_TYPE_COGVLM,
PROJECTOR_TYPE_JANUS_PRO,
Expand Down Expand Up @@ -260,6 +261,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_MUSIC_FLAMINGO, "musicflamingo"},
{ PROJECTOR_TYPE_LFM2, "lfm2"},
{ PROJECTOR_TYPE_KIMIVL, "kimivl"},
{ PROJECTOR_TYPE_PADDLEOCR, "paddleocr"},
{ PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"},
{ PROJECTOR_TYPE_COGVLM, "cogvlm"},
{ PROJECTOR_TYPE_JANUS_PRO, "janus_pro"},
Expand Down
Loading