[model] add support for ERNIE 4.5 VL MoE #19475
isLinXu wants to merge 8 commits into ggml-org:master
Conversation
```cpp
#include <map>
#include <memory>
#include <string>
#include <set>
```
Please cleanup the llama-arch.h changes - remove all the whitespace modifications.
```cpp
#define GGML_ROPE_TYPE_MROPE   8
#define GGML_ROPE_TYPE_VISION  24
#define GGML_ROPE_TYPE_IMROPE  40 // binary: 101000
#define GGML_ROPE_TYPE_ERNIE3D 72 // binary: 1001000, ERNIE-VL 3D RoPE (NORMAL rotation + interleaved h/w freq)
```
the ROPE_TYPE system is quite fragile now and I think we should always reflect twice before adding a new mode.
It seems like interleaved h/w freq is already supported by Pixtral model, please verify one more time if you can reuse the code from Pixtral instead of adding a new rope kernel here.
Thanks for the heads-up. I completely agree that we should be cautious with the ROPE_TYPE system. I’ll re-examine the Pixtral implementation to see if we can reuse its interleaved frequency logic instead of adding a new kernel.
Thanks for the feedback. I’ve conducted a detailed mathematical comparison between Pixtral’s build_rope_2d and the ERNIE implementation. It turns out they are mathematically incompatible, and direct reuse would result in incorrect positional embeddings.
Below is the technical breakdown:
| Feature | Pixtral `build_rope_2d` | ERNIE (Vision / LLM) |
|---|---|---|
| Rotation mode | NORMAL (adjacent pairs) | NEOX (half-dimension offset) |
| Freq. allocation | 2-way interleaved (via `freq_scale_odd`) | Sectional (2D) / 3-way interleaved (3D) |
| Theta accumulation | Continuous across the head | Independent reset per section |
| Dimensionality | 2D (h, w) only | 3D (t, h, w) |
| Implementation | Dual `rope_ext` + concat | `ggml_rope_multi` with mrope 4-slot sections |
Key technical differences:

- Mathematical incompatibility: Pixtral uses NORMAL rotation, whereas ERNIE follows the NEOX convention (commonly used in Vision Transformers). Since the pairing of dimensions differs, swapping them would break the model's spatial understanding.
- Frequency mapping: Pixtral achieves interleaved frequencies by applying a `freq_scale` to one half of the dimensions. ERNIE uses `sections [20, 20, 0, 0]` to strictly block frequencies, where each section starts its theta accumulation independently from $base^0$.
Regarding the complexity of the ROPE_TYPE system:

- Vision side: we are actually using the existing `GGML_ROPE_TYPE_VISION`; no new mode is introduced here.
- LLM side: the new `GGML_ROPE_TYPE_ERNIE3D` is a strict requirement to support the temporal (t) dimension. Current 2D implementations (like Pixtral) cannot handle this 3D mapping.
Conclusion:
To maintain mathematical correctness and support 3D RoPE, we cannot reuse the Pixtral logic. The new ERNIE3D type is the minimum necessary change to support these specific requirements. I will ensure the implementation is as modular as possible to keep the system maintainable.
If the difference is just the NORMAL vs NEOX style, you can also permute the Q and K tensors upon converting to GGUF.
Kimi 2.5 also does exactly this; you can copy the conversion code from #19170.
Also, just a friendly reminder: we don't allow replying to human maintainers with AI-generated responses. Please write the response in your own words, to prove that you fully understand your code.
tools/mtmd/clip-model.h
Outdated
```cpp
// ernie4.5-vl-moe
ggml_tensor * mm_spatial_0_w    = nullptr;
ggml_tensor * mm_spatial_0_b    = nullptr;
ggml_tensor * mm_spatial_2_w    = nullptr;
ggml_tensor * mm_spatial_2_b    = nullptr;
ggml_tensor * mm_spatial_norm_w = nullptr;
ggml_tensor * mm_spatial_norm_b = nullptr;
ggml_tensor * mm_temp_0_w       = nullptr;
ggml_tensor * mm_temp_0_b       = nullptr;
ggml_tensor * mm_temp_2_w       = nullptr;
ggml_tensor * mm_temp_2_b       = nullptr;
ggml_tensor * mm_temp_norm_w    = nullptr;
ggml_tensor * mm_temp_norm_b    = nullptr;
ggml_tensor * mm_mlp_w          = nullptr;
ggml_tensor * mm_mlp_b          = nullptr;
ggml_tensor * mm_after_norm_w   = nullptr;
```
I don't think adding these tensors is needed.
Spatial patch merge is nothing new; we already support many models using the same strategy. Please reuse the existing tensor naming and code infrastructure.
tools/mtmd/models/ernie45vlmoe.cpp
Outdated
```cpp
ggml_tensor * spatial_0_w = ggml_cont(ctx0, ggml_transpose(ctx0, model.mm_0_w));
spatial_out = ggml_mul_mat(ctx0, spatial_0_w, spatial_out);
spatial_out = ggml_add(ctx0, spatial_out, model.mm_0_b);
cb(spatial_out, "spatial_linear_0", -1);

// GELU
spatial_out = ggml_gelu(ctx0, spatial_out);
cb(spatial_out, "spatial_gelu", -1);

// Second linear
ggml_tensor * spatial_2_w = ggml_cont(ctx0, ggml_transpose(ctx0, model.mm_2_w));
spatial_out = ggml_mul_mat(ctx0, spatial_2_w, spatial_out);
spatial_out = ggml_add(ctx0, spatial_out, model.mm_2_b);
cb(spatial_out, "spatial_linear_2", -1);
```
this can be reduced to build_ffn
tools/mtmd/models/ernie45vlmoe.cpp
Outdated
```cpp
ggml_tensor * spatial_out = embeddings;

// First linear
ggml_tensor * spatial_0_w = ggml_cont(ctx0, ggml_transpose(ctx0, model.mm_0_w));
```
any transposes to the weight must be done upon conversion
tools/mtmd/models/ernie45vlmoe.cpp
Outdated
```cpp
resampler_out = ggml_concat(ctx0, resampler_out, resampler_out, 0);

// Temporal linear path: Linear -> GELU -> Linear -> LayerNorm
// Weights were transposed (.t()) during GGUF conversion, undo with ggml_transpose
```
hmm, why transpose it during conversion, then have to transpose it again here?
Unless my math is broken somehow, transpose(transpose(A)) is identical to just using A without any transposes.
tools/mtmd/models/ernie45vlmoe.cpp
Outdated
```cpp
ggml_tensor * temp_0_w = ggml_cont(ctx0, ggml_transpose(ctx0, model.mm_1_w));
resampler_out = ggml_mul_mat(ctx0, temp_0_w, resampler_out);
resampler_out = ggml_add(ctx0, resampler_out, model.mm_1_b);
cb(resampler_out, "temporal_linear_0", -1);

// GELU
resampler_out = ggml_gelu(ctx0, resampler_out);
cb(resampler_out, "temporal_gelu", -1);

// Second temporal linear
ggml_tensor * temp_2_w = ggml_cont(ctx0, ggml_transpose(ctx0, model.mm_3_w));
resampler_out = ggml_mul_mat(ctx0, temp_2_w, resampler_out);
resampler_out = ggml_add(ctx0, resampler_out, model.mm_3_b);
cb(resampler_out, "temporal_linear_2", -1);
```
```cpp
ggml_tensor * moe_out = nullptr;

// Use vision experts for vision tokens, text experts for text tokens
if (ubatch.embd) {
```
This may make the graph non-static, thus reducing overall performance.
Instead, you should follow the same optimization with `ggml_build_forward_select`; see the example in `llm_graph_context::build_inp_embd`.
```cpp
if (!ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false)) {
    hparams.rope_sections[0] = 22;
    hparams.rope_sections[1] = 22;
    hparams.rope_sections[2] = 20;
    hparams.rope_sections[3] = 0;
}
```
Since we are not trying to be backward-compatible here, I think it's better not to hard-code any default values. The metadata `LLM_KV_ROPE_DIMENSION_SECTIONS` must be a requirement.
```cpp
    hparams.rope_sections[3] = 0;
}

LLAMA_LOG_INFO("%s: ERNIE-VL rope_sections=[%d,%d,%d,%d]\n", __func__,
```
logging must not be done here
```cpp
case LLM_ARCH_ERNIE4_5_VL_MOE:
    return LLAMA_ROPE_TYPE_ERNIE3D;
```
I wonder how this differs from Qwen's IMROPE (the interleaved mrope version added in Qwen 3).
In any case, I'm pretty sure that we don't need to add any new rope modes. Any more complex mode can be implemented in 2 ways (or a combination of the 2):

- Permute the Q and K tensors upon conversion
- Frequency per-dimension can be controlled by the custom `freq_factors`, i.e. the `c` tensor in `ggml_rope_ext`
Permute Q and K to the Qwen2 ordering, then use `freq_factors` to correct the theta_base.
I don't think we can allow adding yet another rope mode here; this part of the code is already fragile enough.
This PR adds model support for the ERNIE 4.5 family of models from Baidu, including Dense (ernie4_5), MoE (ernie4_5-moe), and Vision-Language MoE (ernie4_5-vl-moe) variants. It has been verified in both vision and pure-text modes.
Dual MoE (Text + Vision Experts): The core innovation of ERNIE 4.5 VL MoE is its dual expert system within the same LLM backbone. MoE layers maintain two separate sets of experts — one for text tokens and one for vision tokens — dynamically routed based on input modality. Vision experts use a significantly smaller FFN intermediate size (default 512) compared to text experts, reflecting a compact representation design for visual features.
Interleaved Dense/MoE layers: Controlled by n_layer_dense_lead and n_moe_layer_step, the first few layers are Dense, and MoE layers are interleaved at a configurable step interval. A shared expert (SwiGLU FFN) is added on top of MoE output for both modalities.
ERNIE3D RoPE: A new RoPE type (GGML_ROPE_TYPE_ERNIE3D = 72) designed for multimodal use, with an interleaved 3D frequency layout encoding height/width/temporal dimensions (sections [22, 22, 20, 0]), distinct from the contiguous segmentation used by standard M-RoPE.
Vision Encoder: Standard ViT with 2D M-RoPE (no learned positional embeddings), using SwiGLU FFN in each transformer layer.
Vision Projector: A spatial + temporal resampler pipeline:

- 2×2 spatial patch merging (4× token reduction)
- Spatial linear path (Linear → GELU → Linear → LayerNorm)
- Temporal path (optional, for video frames; single images use self-concatenation)
- Final MLP + RMS Norm projection to LLM embedding space