311 changes: 310 additions & 1 deletion convert_hf_to_gguf.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions ggml/include/ggml.h
@@ -252,6 +252,7 @@
#define GGML_ROPE_TYPE_MROPE 8
#define GGML_ROPE_TYPE_VISION 24
#define GGML_ROPE_TYPE_IMROPE 40 // binary: 101000
#define GGML_ROPE_TYPE_ERNIE3D 72 // binary: 1001000, ERNIE-VL 3D RoPE (NORMAL rotation + interleaved h/w freq)
Contributor:

The ROPE_TYPE system is quite fragile now, and I think we should always reflect twice before adding a new mode.

It seems like interleaved h/w freq is already supported by the Pixtral model; please verify one more time whether you can reuse the code from Pixtral instead of adding a new rope kernel here.

Author:

Thanks for the heads-up. I completely agree that we should be cautious with the ROPE_TYPE system. I’ll re-examine the Pixtral implementation to see if we can reuse its interleaved frequency logic instead of adding a new kernel.

Author @isLinXu, Feb 10, 2026:

Thanks for the feedback. I’ve conducted a detailed mathematical comparison between Pixtral’s build_rope_2d and the ERNIE implementation. It turns out they are mathematically incompatible, and direct reuse would result in incorrect positional embeddings.

Below is the technical breakdown:

| Feature | Pixtral `build_rope_2d` | ERNIE (Vision / LLM) |
| --- | --- | --- |
| Rotation mode | NORMAL (adjacent pairs) | NEOX (half-dimension offset) |
| Freq. allocation | 2-way interleaved (via `freq_scale_odd`) | Sectional (2D) / 3-way interleaved (3D) |
| Theta accumulation | Continuous across the head | Independent reset per section |
| Dimensionality | 2D (h, w) only | 3D (t, h, w) |
| Implementation | Dual `rope_ext` + concat | `ggml_rope_multi` with mrope 4-slot |

Key Technical Differences:

  1. Mathematical incompatibility: Pixtral uses NORMAL rotation, whereas ERNIE follows the NEOX convention (commonly used in Vision Transformers). Since the pairing of dimensions differs, swapping them would break the model's spatial understanding.
  2. Frequency mapping: Pixtral achieves interleaved frequencies by applying a freq_scale to one half of the dimensions. ERNIE uses sections [20, 20, 0, 0] to partition frequencies into strict blocks, where each section starts its theta accumulation independently from $base^0$.
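The interleaved-slot assignment described above can be sketched in plain Python (toy section sizes in the example; the mapping itself mirrors the even→h, odd→w, trailing→t scheme):

```python
def pos_axis(freq_idx, sections=(20, 20, 0, 0)):
    # Which position axis (t / h / w) a given frequency pair reads under the
    # interleaved h/w + trailing temporal scheme described above.
    n_hw = sections[0] + sections[1]  # number of interleaved h/w frequencies
    if freq_idx < n_hw:
        return "h" if freq_idx % 2 == 0 else "w"  # even -> height, odd -> width
    return "t"  # remaining frequencies carry the temporal position

# toy 2+2 section layout, 6 frequency pairs:
print([pos_axis(i, (2, 2, 0, 0)) for i in range(6)])  # ['h', 'w', 'h', 'w', 't', 't']
```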

Regarding the complexity of the ROPE_TYPE system:

  • Vision Side: We are actually using the existing GGML_ROPE_TYPE_VISION. No new mode is introduced here.
  • LLM Side: The new GGML_ROPE_TYPE_ERNIE3D is a strict requirement to support the Temporal (t) dimension. Current 2D implementations (like Pixtral) cannot handle this 3D mapping.

Conclusion:

To maintain mathematical correctness and support 3D RoPE, we cannot reuse the Pixtral logic. The new ERNIE3D type is the minimum necessary change to support these specific requirements. I will ensure the implementation is as modular as possible to keep the system maintainable.

Contributor:

If the difference is just the NORMAL and NEOX style, you can also permute the Q and K tensors upon converting to GGUF.

Kimi 2.5 also does exactly this; you can copy the conversion code from #19170
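For reference, a change of RoPE pairing can indeed be absorbed into a one-time reordering of the Q/K rows at conversion time. A minimal pure-Python sketch of the idea (the helper name is hypothetical; real conversion code operates on numpy weight tensors, as in the PR referenced above):

```python
def permute_adjacent_to_half(rows, n_head):
    # Reorder per-head rows so that adjacent-pair (NORMAL / GPT-J) RoPE
    # pairing (2i, 2i+1) becomes half-split (NEOX) pairing (i, i + d/2).
    d = len(rows) // n_head  # head dimension
    out = []
    for h in range(n_head):
        head = rows[h * d:(h + 1) * d]
        out.extend(head[0::2] + head[1::2])  # even rows first, then odd rows
    return out

# toy check: one head of dimension 4, rows labelled by index
print(permute_adjacent_to_half([0, 1, 2, 3], n_head=1))  # [0, 2, 1, 3]
```

Applied once to the Q/K projection weights at conversion, this lets the runtime keep a single rotation kernel for both layouts.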

Contributor:

Also, just a friendly reminder: we don't allow replying to human maintainers with AI-generated responses. Please write the response in your own words, to prove that you fully understand your code.


#define GGML_MROPE_SECTIONS 4

48 changes: 47 additions & 1 deletion ggml/src/ggml-cpu/ops.cpp
@@ -5651,6 +5651,43 @@ static void rotate_pairs(const int64_t n, const int64_t n_offset, const float *
}
}

static void ggml_ernie3d_rope_cache_init(
float theta_base_t, float theta_base_h, float theta_base_w,
int sections[4],
float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
float * cache, float sin_sign, float theta_scale) {
// n_hw = sections[0] + sections[1] = total number of interleaved h/w frequencies
int n_hw = sections[0] + sections[1];

float theta_accum = 1.0f; // accumulated theta_scale^freq_idx

for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
int freq_idx = (int)(i0 / 2);
const float ff = freq_factors ? freq_factors[freq_idx] : 1.0f;

float theta;
if (freq_idx < n_hw) {
if (freq_idx % 2 == 0) {
// even freq index -> height position
theta = theta_base_h * theta_accum;
} else {
// odd freq index -> width position
theta = theta_base_w * theta_accum;
}
} else {
// temporal position
theta = theta_base_t * theta_accum;
}

rope_yarn(
theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
);
cache[i0 + 1] *= sin_sign;

theta_accum *= theta_scale;
}
}

template<typename T> //float or ggml_fp16_t
static void ggml_compute_forward_rope_flt(
const ggml_compute_params * params,
@@ -5723,7 +5760,7 @@ static void ggml_compute_forward_rope_flt(
if (is_vision) {
GGML_ASSERT(n_dims == ne0/2);
}

const bool is_ernie3d = mode == GGML_ROPE_TYPE_ERNIE3D;
const float * freq_factors = NULL;
if (src2 != NULL) {
GGML_ASSERT(src2->type == GGML_TYPE_F32);
@@ -5745,6 +5782,14 @@
if (!mrope_used) {
const int64_t p = pos[i2];
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
} else if (is_ernie3d) {
// ERNIE-VL 3D RoPE: interleaved h/w freq with NORMAL rotation
const int64_t p_t = pos[i2];
const int64_t p_h = pos[i2 + ne2];
const int64_t p_w = pos[i2 + ne2 * 2];
ggml_ernie3d_rope_cache_init(
p_t, p_h, p_w, sections,
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
}
else {
const int64_t p_t = pos[i2];
Expand All @@ -5765,6 +5810,7 @@ static void ggml_compute_forward_rope_flt(

switch (mode) {
case GGML_ROPE_TYPE_NORMAL:
case GGML_ROPE_TYPE_ERNIE3D:
rotate_pairs<T>(n_dims, 1, cache, src, dst_data, 1);
break;
case GGML_ROPE_TYPE_NEOX:
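For anyone cross-checking the CPU path above, here is a plain-Python mirror of `ggml_ernie3d_rope_cache_init` for the simple case (`ext_factor = 0`, `mscale = 1`, forward pass, where `rope_yarn` reduces to scaling theta by `freq_scale`) — a sketch for verification, not part of the patch:

```python
import math

def ernie3d_rope_cache(p_t, p_h, p_w, sections, ne0, theta_scale,
                       freq_scale=1.0, freq_factors=None):
    # Mirrors ggml_ernie3d_rope_cache_init with ext_factor == 0, mscale == 1.
    # (sin_sign / the backward pass is omitted for brevity.)
    n_hw = sections[0] + sections[1]  # interleaved h/w frequency count
    cache = [0.0] * ne0
    theta_accum = 1.0                 # theta_scale ** freq_idx, kept as a running product
    for i0 in range(0, ne0, 2):
        freq_idx = i0 // 2
        ff = freq_factors[freq_idx] if freq_factors else 1.0
        if freq_idx < n_hw:
            base = p_h if freq_idx % 2 == 0 else p_w  # even -> height, odd -> width
        else:
            base = p_t                                 # trailing freqs -> temporal
        theta = freq_scale * base * theta_accum / ff
        cache[i0], cache[i0 + 1] = math.cos(theta), math.sin(theta)
        theta_accum *= theta_scale
    return cache
```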
100 changes: 99 additions & 1 deletion ggml/src/ggml-cuda/rope.cu
@@ -264,6 +264,68 @@ static __global__ void rope_multi(const T * x,
dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
}

template<bool forward, bool has_ff, typename T>
static __global__ void rope_ernie3d(
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);

if (i0 >= ne0) {
return;
}

const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;

const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1;

// NORMAL rotation: pair (x[i0], x[i0+1]), stored at adjacent positions
const int idst = row_dst*ne0 + i0;
const int ix = channel_x*s2 + row_x*s1 + i0;

if (i0 >= n_dims) {
dst[idst + 0] = x[ix + 0];
dst[idst + 1] = x[ix + 1];
return;
}

// freq_idx = i0/2 (which frequency pair this is)
const int freq_idx = i0 / 2;
// n_hw = sections[0] + sections[1] = total number of h+w interleaved frequencies
const int n_hw = sections.v[0] + sections.v[1];

// Determine which position slot to use based on interleaved pattern
// Position slots: slot 0 = t_position, slot 1 = h_position, slot 2 = w_position
float theta_base = 0.0f;
if (freq_idx < n_hw) {
if (freq_idx % 2 == 0) {
// even freq index -> height position (slot 1)
theta_base = pos[channel_x + ne2 * 1] * powf(theta_scale, (float)freq_idx);
} else {
// odd freq index -> width position (slot 2)
theta_base = pos[channel_x + ne2 * 2] * powf(theta_scale, (float)freq_idx);
}
} else {
// temporal position (slot 0)
theta_base = pos[channel_x] * powf(theta_scale, (float)freq_idx);
}

const float freq_factor = has_ff ? freq_factors[freq_idx] : 1.0f;

float cos_theta;
float sin_theta;

rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);

// NORMAL (GPT-J) rotation: adjacent pair (x[i0], x[i0+1])
const float x0 = x[ix + 0];
const float x1 = x[ix + 1];

dst[idst + 0] = x0*cos_theta - x1*sin_theta;
dst[idst + 1] = x0*sin_theta + x1*cos_theta;
}

template <bool forward, bool has_ff, typename T>
static __global__ void rope_vision(const T * x,
T * dst,
@@ -453,6 +515,29 @@ static void rope_multi_cuda(const T * x,
}
}

template<bool forward, typename T>
static void rope_ernie3d_cuda(
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nr, n_blocks_x, 1);

const float theta_scale = powf(freq_base, -2.0f/n_dims);

if (freq_factors == nullptr) {
rope_ernie3d<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, sections);
} else {
rope_ernie3d<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, sections);
}
}

template <bool forward, typename T>
static void rope_vision_cuda(const T * x,
T * dst,
@@ -603,7 +688,20 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
set_rows_stride, stream);
} else {
} else if (is_ernie3d) {
if (src0->type == GGML_TYPE_F32) {
rope_ernie3d_cuda<forward>(
(const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
} else if (src0->type == GGML_TYPE_F16) {
rope_ernie3d_cuda<forward>(
(const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
} else {
GGML_ABORT("fatal error");
}
}
else {
GGML_ABORT("fatal error");
}
} else if (is_mrope && !is_vision) {
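A small consistency note between the two backends: the CPU cache accumulates `theta_accum *= theta_scale` once per pair, while the CUDA kernel computes `powf(theta_scale, freq_idx)` independently per thread. Both yield `theta_scale ** freq_idx`; a quick numerical check (illustrative `freq_base`/`n_dims` values):

```python
# CPU path: running product; CUDA path: direct power. They should agree.
freq_base, n_dims = 10000.0, 64  # illustrative values
theta_scale = freq_base ** (-2.0 / n_dims)
acc = 1.0
for freq_idx in range(n_dims // 2):
    assert abs(acc - theta_scale ** freq_idx) < 1e-12
    acc *= theta_scale
print("running product matches direct powers")
```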
47 changes: 47 additions & 0 deletions gguf-py/gguf/constants.py
@@ -110,6 +110,7 @@ class LLM:
LEADING_DENSE_BLOCK_COUNT = "{arch}.leading_dense_block_count"
FEED_FORWARD_LENGTH = "{arch}.feed_forward_length"
EXPERT_FEED_FORWARD_LENGTH = "{arch}.expert_feed_forward_length"
VISION_EXPERT_FEED_FORWARD_LENGTH = "{arch}.vision_expert_feed_forward_length"
EXPERT_SHARED_FEED_FORWARD_LENGTH = "{arch}.expert_shared_feed_forward_length"
EXPERT_CHUNK_FEED_FORWARD_LENGTH = "{arch}.expert_chunk_feed_forward_length"
USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
@@ -447,6 +448,7 @@ class MODEL_ARCH(IntEnum):
AFMOE = auto()
ERNIE4_5 = auto()
ERNIE4_5_MOE = auto()
ERNIE4_5_VL_MOE = auto()
HUNYUAN_MOE = auto()
HUNYUAN_DENSE = auto()
SMOLLM3 = auto()
@@ -723,6 +725,17 @@ class MODEL_TENSOR(IntEnum):
V_DS_NORM = auto() # qwen3vl
V_DS_FC1 = auto() # qwen3vl
V_DS_FC2 = auto() # qwen3vl
V_FFN_GATE_INP = auto() # ernie45vlmoe
V_FFN_UP_EXPS = auto() # ernie45vlmoe
V_FFN_DOWN_EXPS = auto() # ernie45vlmoe
V_FFN_NORM_EXPS = auto() # ernie45vlmoe
V_FFN_GATE_EXPS = auto() # ernie45vlmoe
V_FFN_GATE_SHEXP = auto() # ernie45vlmoe
V_FFN_UP_SHEXP = auto() # ernie45vlmoe
V_FFN_DOWN_SHEXP = auto() # ernie45vlmoe
V_FFN_GATE_INP_SHEXP = auto() # ernie45vlmoe
V_FFN_NORM_SHEXP = auto() # ernie45vlmoe
V_FFN_EXP_PROBS_B = auto() # ernie45vlmoe
V_MM_POST_FC_NORM = auto() # cogvlm
V_MM_UP = auto() # cogvlm
V_MM_DOWN = auto() # cogvlm
@@ -879,6 +892,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.AFMOE: "afmoe",
MODEL_ARCH.ERNIE4_5: "ernie4_5",
MODEL_ARCH.ERNIE4_5_MOE: "ernie4_5-moe",
MODEL_ARCH.ERNIE4_5_VL_MOE: "ernie4_5-vl-moe",
MODEL_ARCH.FALCON_H1: "falcon-h1",
MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe",
MODEL_ARCH.HUNYUAN_DENSE: "hunyuan-dense",
@@ -1159,6 +1173,11 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.V_MM_GATE: "mm.gate",
MODEL_TENSOR.V_TOK_BOI: "v.boi",
MODEL_TENSOR.V_TOK_EOI: "v.eoi",
MODEL_TENSOR.V_FFN_GATE_INP: "blk.{bid}.v_ffn_gate_inp",
MODEL_TENSOR.V_FFN_GATE_EXPS: "blk.{bid}.v_ffn_gate_exps",
MODEL_TENSOR.V_FFN_DOWN_EXPS: "blk.{bid}.v_ffn_down_exps",
MODEL_TENSOR.V_FFN_UP_EXPS: "blk.{bid}.v_ffn_up_exps",
MODEL_TENSOR.V_FFN_EXP_PROBS_B: "blk.{bid}.v_exp_probs_b",
# audio (mtmd)
# note: all audio tensor names must use prefix "a." or "mm.a."
MODEL_TENSOR.A_ENC_EMBD_POS: "a.position_embd",
@@ -2597,6 +2616,33 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_UP_SHEXP,
MODEL_TENSOR.FFN_EXP_PROBS_B,
],
MODEL_ARCH.ERNIE4_5_VL_MOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_EXP_PROBS_B,
MODEL_TENSOR.V_FFN_GATE_INP,
MODEL_TENSOR.V_FFN_GATE_EXPS,
MODEL_TENSOR.V_FFN_DOWN_EXPS,
MODEL_TENSOR.V_FFN_UP_EXPS,
MODEL_TENSOR.V_FFN_EXP_PROBS_B,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
],
MODEL_ARCH.PLM: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT,
@@ -3770,6 +3816,7 @@ class VisionProjectorType:
MUSIC_FLAMINGO = "musicflamingo" # audio
GLM4V = "glm4v"
YOUTUVL = "youtuvl"
ERNIE45VLMOE = "ernie4.5vl_moe"


# Items here are (block size, type size)
3 changes: 3 additions & 0 deletions gguf-py/gguf/gguf_writer.py
@@ -717,6 +717,9 @@ def add_feed_forward_length(self, length: int | Sequence[int]) -> None:
def add_expert_feed_forward_length(self, length: int) -> None:
self.add_uint32(Keys.LLM.EXPERT_FEED_FORWARD_LENGTH.format(arch=self.arch), length)

def add_vision_expert_feed_forward_length(self, length: int) -> None:
self.add_uint32(Keys.LLM.VISION_EXPERT_FEED_FORWARD_LENGTH.format(arch=self.arch), length)

def add_expert_shared_feed_forward_length(self, length: int) -> None:
self.add_uint32(Keys.LLM.EXPERT_SHARED_FEED_FORWARD_LENGTH.format(arch=self.arch), length)

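The new writer method stores its value under the per-architecture templated key added in `constants.py`; a quick sketch of how that key expands for the new architecture name (plain string formatting, mirroring the `Keys.LLM` convention):

```python
# Key template added in gguf-py/gguf/constants.py (reproduced for illustration)
VISION_EXPERT_FEED_FORWARD_LENGTH = "{arch}.vision_expert_feed_forward_length"

key = VISION_EXPERT_FEED_FORWARD_LENGTH.format(arch="ernie4_5-vl-moe")
print(key)  # ernie4_5-vl-moe.vision_expert_feed_forward_length
```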