Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 74 additions & 1 deletion convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -11855,7 +11855,7 @@ def prepare_tensors(self):
raise ValueError(f"Unprocessed experts: {experts}")


@ModelBase.register("HunYuanDenseV1ForCausalLM", "HunYuanVLForConditionalGeneration")

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will this break the conversion of hunyuan-ocr ?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hopefully handled here, but needs testing:

llama.cpp/convert_hf_to_gguf.py

Lines 12071 to 12073 in 0a5a97c

@ModelBase.register("HunYuanVLForConditionalGeneration")
class HunyuanVLTextModel(HunYuanModel):
model_arch = gguf.MODEL_ARCH.HUNYUAN_VL

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ngxson @CISC Addressed, tested locally: both HunyuanOCR and HunyuanVL convert to GGUF successfully and produce correct inference output on Metal (F16 / Q8_0). The only difference between OCR and VL is the projection dim (vision_config.out_hidden_size: 1024 for OCR)

@ModelBase.register("HunYuanDenseV1ForCausalLM")
class HunYuanModel(TextModel):
model_arch = gguf.MODEL_ARCH.HUNYUAN_DENSE

Expand Down Expand Up @@ -12028,6 +12028,79 @@ def tensor_force_quant(self, name, new_name, bid, n_dims):
return super().tensor_force_quant(name, new_name, bid, n_dims)


@ModelBase.register("HunYuanVLForConditionalGeneration")
class HunyuanVLVisionModel(MmprojModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.hparams_vision is not None
# Compute image_size from max_image_size if not explicitly set
if "image_size" not in self.hparams_vision:
self.hparams_vision["image_size"] = self.hparams_vision.get("max_image_size", 2048)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Skip text-model tensors (they go into the LLM gguf file)
if not name.startswith("vit."):
return
# strip CLS token (row 0) from position embeddings so resize_position_embeddings works
if "position_embedding" in name:
data_torch = data_torch[1:] # [n_patches+1, n_embd] -> [n_patches, n_embd]
yield from super().modify_tensors(data_torch, name, bid)

def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int):
# Keep the final linear projection (mm.mlp.weight) in F16 to preserve precision
if new_name == "mm.mlp.weight":
return gguf.GGMLQuantizationType.F16
if ("mm.proj." in new_name) and new_name.endswith(".weight"):
return gguf.GGMLQuantizationType.F16 if self.ftype == gguf.LlamaFileType.MOSTLY_F16 else gguf.GGMLQuantizationType.F32
return super().tensor_force_quant(name, new_name, bid, n_dims)

def set_gguf_parameters(self):
super().set_gguf_parameters()
assert self.hparams_vision is not None
hparams = self.hparams_vision

self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.HUNYUANVL)
self.gguf_writer.add_vision_use_gelu(True)

if (rms_norm_eps := hparams.get("rms_norm_eps")) is not None:
self.gguf_writer.add_vision_attention_layernorm_eps(rms_norm_eps)
if (merge_size := hparams.get("spatial_merge_size")) is not None:
self.gguf_writer.add_vision_spatial_merge_size(int(merge_size))


@ModelBase.register("HunYuanVLForConditionalGeneration")
class HunyuanVLTextModel(HunYuanModel):
model_arch = gguf.MODEL_ARCH.HUNYUAN_VL

def set_gguf_parameters(self):
super().set_gguf_parameters()

if self.rope_parameters.get("rope_type") == "xdrope":
alpha = float(self.rope_parameters.get("alpha", 50))
base = float(self.rope_parameters.get("rope_theta", 10000.0))

# Write raw values; C++ computes: freq_base = base * alpha^(dim/(dim-2))
self.gguf_writer.add_rope_freq_base(base)
self.gguf_writer.add_rope_scaling_alpha(alpha)
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
self.gguf_writer.add_rope_scaling_factor(1)
self.gguf_writer.add_rope_scaling_orig_ctx_len(256 * 1024)
self.gguf_writer.add_context_length(256 * 1024)

# xdrope_section defines which head-dim slices use each positional axis
# Reuse the M-RoPE rope_dimension_sections mechanism
xdrope_section = list(self.rope_parameters.get("xdrope_section", []))
while len(xdrope_section) < 4:
xdrope_section.append(0)
self.gguf_writer.add_rope_dimension_sections(xdrope_section[:4])

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Skip vision tensors — they are written by HunyuanVLVisionModel
if name.startswith("vit."):
return
yield from super().modify_tensors(data_torch, name, bid)


@ModelBase.register("SmolLM3ForCausalLM")
class SmolLM3Model(LlamaModel):
model_arch = gguf.MODEL_ARCH.SMOLLM3
Expand Down
20 changes: 20 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ class Rope:
FREQ_BASE_SWA = "{arch}.rope.freq_base_swa"
SCALING_TYPE = "{arch}.rope.scaling.type"
SCALING_FACTOR = "{arch}.rope.scaling.factor"
SCALING_ALPHA = "{arch}.rope.scaling.alpha"
SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor"
SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"
Expand Down Expand Up @@ -471,6 +472,7 @@ class MODEL_ARCH(IntEnum):
ERNIE4_5_MOE = auto()
HUNYUAN_MOE = auto()
HUNYUAN_DENSE = auto()
HUNYUAN_VL = auto()
SMOLLM3 = auto()
GPT_OSS = auto()
LFM2 = auto()
Expand Down Expand Up @@ -957,6 +959,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.FALCON_H1: "falcon-h1",
MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe",
MODEL_ARCH.HUNYUAN_DENSE: "hunyuan-dense",
MODEL_ARCH.HUNYUAN_VL: "hunyuan_vl",
MODEL_ARCH.SMOLLM3: "smollm3",
MODEL_ARCH.GPT_OSS: "gpt-oss",
MODEL_ARCH.LFM2: "lfm2",
Expand Down Expand Up @@ -3489,6 +3492,22 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.HUNYUAN_VL: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.SMOLLM3: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
Expand Down Expand Up @@ -4138,6 +4157,7 @@ class VisionProjectorType:
YOUTUVL = "youtuvl"
NEMOTRON_V2_VL = "nemotron_v2_vl"
HUNYUANOCR = "hunyuanocr"
HUNYUANVL = "hunyuanvl_merger"

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
HUNYUANVL = "hunyuanvl_merger"
HUNYUANVL = "hunyuanvl"

remove _merger to make it shorter



# Items here are (block size, type size)
Expand Down
3 changes: 3 additions & 0 deletions gguf-py/gguf/gguf_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -973,6 +973,9 @@ def add_rope_scaling_type(self, value: RopeScalingType) -> None:
def add_rope_scaling_factor(self, value: float) -> None:
self.add_float32(Keys.Rope.SCALING_FACTOR.format(arch=self.arch), value)

def add_rope_scaling_alpha(self, value: float) -> None:
self.add_float32(Keys.Rope.SCALING_ALPHA.format(arch=self.arch), value)

def add_rope_scaling_attn_factors(self, value: float) -> None:
self.add_float32(Keys.Rope.SCALING_ATTN_FACTOR.format(arch=self.arch), value)

Expand Down
2 changes: 2 additions & 0 deletions src/llama-arch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" },
{ LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
{ LLM_ARCH_HUNYUAN_DENSE, "hunyuan-dense" },
{ LLM_ARCH_HUNYUAN_VL, "hunyuan_vl" },
{ LLM_ARCH_SMOLLM3, "smollm3" },
{ LLM_ARCH_OPENAI_MOE, "gpt-oss" },
{ LLM_ARCH_LFM2, "lfm2" },
Expand Down Expand Up @@ -250,6 +251,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" },
{ LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" },
{ LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" },
{ LLM_KV_ROPE_SCALING_ALPHA, "%s.rope.scaling.alpha" },
{ LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" },
{ LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" },
{ LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" },
Expand Down
2 changes: 2 additions & 0 deletions src/llama-arch.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ enum llm_arch {
LLM_ARCH_ERNIE4_5_MOE,
LLM_ARCH_HUNYUAN_MOE,
LLM_ARCH_HUNYUAN_DENSE,
LLM_ARCH_HUNYUAN_VL,
LLM_ARCH_SMOLLM3,
LLM_ARCH_OPENAI_MOE,
LLM_ARCH_LFM2,
Expand Down Expand Up @@ -254,6 +255,7 @@ enum llm_kv {
LLM_KV_ROPE_SCALE_LINEAR,
LLM_KV_ROPE_SCALING_TYPE,
LLM_KV_ROPE_SCALING_FACTOR,
LLM_KV_ROPE_SCALING_ALPHA,
LLM_KV_ROPE_SCALING_ATTN_FACTOR,
LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
LLM_KV_ROPE_SCALING_FINETUNED,
Expand Down
1 change: 1 addition & 0 deletions src/llama-hparams.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ struct llama_hparams {
float rope_freq_base_train_swa = 10000.0f;
float rope_freq_scale_train;
float rope_freq_scale_train_swa = 1.0f;
float rope_scaling_alpha = 0.0f; // NTK-aware alpha for XDRoPE

uint32_t n_ctx_orig_yarn;
float rope_yarn_log_mul = 0.0f;
Expand Down
22 changes: 22 additions & 0 deletions src/llama-model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,13 @@ void llama_model::load_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_EXPERT_GROUP_COUNT, hparams.n_expert_groups, false);
ml.get_key(LLM_KV_EXPERT_GROUP_USED_COUNT, hparams.n_group_used, false);

if (arch == LLM_ARCH_HUNYUAN_VL || arch == LLM_ARCH_HUNYUAN_DENSE) {
if (hparams.n_expert <= 1) {
hparams.n_expert = 0;
hparams.n_expert_used = 0;
}
}

if (arch == LLM_ARCH_WAVTOKENIZER_DEC) {
ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd);
ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd_out_impl);
Expand Down Expand Up @@ -814,6 +821,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale;

ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false);
ml.get_key(LLM_KV_ROPE_SCALING_ALPHA, hparams.rope_scaling_alpha, false);

// non-transformer models do not have attention heads
if (hparams.n_head() > 0) {
Expand Down Expand Up @@ -2591,9 +2599,18 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_HUNYUAN_VL:
case LLM_ARCH_HUNYUAN_DENSE:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false);

// XDRoPE / NTK-aware scaling: base = rope_theta * alpha^(dim / (dim - 2))
if (hparams.rope_scaling_alpha > 0.0f) {
const int dim = hparams.n_embd_head_k();
hparams.rope_freq_base_train = hparams.rope_freq_base_train
* powf(hparams.rope_scaling_alpha, (float)dim / (float)(dim - 2));
}

switch (hparams.n_embd) {
case 1024: type = LLM_TYPE_0_5B; break;
Expand Down Expand Up @@ -6946,6 +6963,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0);
}
} break;
case LLM_ARCH_HUNYUAN_VL:
case LLM_ARCH_HUNYUAN_DENSE:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
Expand Down Expand Up @@ -8966,6 +8984,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
{
llm = std::make_unique<llm_build_hunyuan_moe>(*this, params);
} break;
case LLM_ARCH_HUNYUAN_VL:
case LLM_ARCH_HUNYUAN_DENSE:
{
llm = std::make_unique<llm_build_hunyuan_dense>(*this, params);
Expand Down Expand Up @@ -9315,6 +9334,9 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_GLM4_MOE:
return model->hparams.use_mrope() ? LLAMA_ROPE_TYPE_MROPE : LLAMA_ROPE_TYPE_NEOX;

case LLM_ARCH_HUNYUAN_VL:
return model->hparams.use_mrope() ? LLAMA_ROPE_TYPE_MROPE : LLAMA_ROPE_TYPE_NEOX;

// all model arches should be listed explicitly here
case LLM_ARCH_UNKNOWN:
GGML_ABORT("unknown architecture");
Expand Down
41 changes: 30 additions & 11 deletions src/models/hunyuan-dense.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ llm_build_hunyuan_dense::llm_build_hunyuan_dense(const llama_model & model, cons
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
GGML_ASSERT(n_embd_head == n_rot);

const bool use_mrope = hparams.use_mrope();

int sections[4];
std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);

ggml_tensor * cur;
ggml_tensor * inpL;

Expand Down Expand Up @@ -37,22 +42,36 @@ llm_build_hunyuan_dense::llm_build_hunyuan_dense(const llama_model & model, cons
auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur,
n_embd_head, n_head, n_head_kv, il);

Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
if (use_mrope) {
Qcur = ggml_rope_multi(
ctx0, Qcur, inp_pos, rope_factors,
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);

Kcur = ggml_rope_multi(
ctx0, Kcur, inp_pos, rope_factors,
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
} else {
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);

Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
}

cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);

Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);

Kcur = build_norm(Kcur,
model.layers[il].attn_k_norm, nullptr,
LLM_NORM_RMS, il);
Expand Down
4 changes: 3 additions & 1 deletion tools/mtmd/clip-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@
#define TN_TOK_BOI "v.boi"
#define TN_TOK_EOI "v.eoi"

// hunyuanocr
// hunyuanocr / hunyuanvl (shared GGUF tensor names)
#define TN_MM_PRE_NORM "mm.pre_norm.%s"
#define TN_TOK_IMG_BEGIN "mm.image_begin"
#define TN_TOK_IMG_END "mm.image_end"
Expand Down Expand Up @@ -293,6 +293,7 @@ enum projector_type {
PROJECTOR_TYPE_KIMIK25,
PROJECTOR_TYPE_NEMOTRON_V2_VL,
PROJECTOR_TYPE_HUNYUANOCR,
PROJECTOR_TYPE_HUNYUANVL,
PROJECTOR_TYPE_UNKNOWN,
};

Expand Down Expand Up @@ -338,6 +339,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_KIMIK25, "kimik25"},
{ PROJECTOR_TYPE_NEMOTRON_V2_VL, "nemotron_v2_vl"},
{ PROJECTOR_TYPE_HUNYUANOCR, "hunyuanocr"},
{ PROJECTOR_TYPE_HUNYUANVL, "hunyuanvl_merger"},

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
{ PROJECTOR_TYPE_HUNYUANVL, "hunyuanvl_merger"},
{ PROJECTOR_TYPE_HUNYUANVL, "hunyuanvl"},

};

static projector_type clip_projector_type_from_string(const std::string & str) {
Expand Down
Loading