Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions conversion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
"Ernie4_5_ForCausalLM": "ernie",
"Ernie4_5_MoeForCausalLM": "ernie",
"EuroBertModel": "bert",
"Exaone4_5_ForConditionalGeneration": "exaone",
"Exaone4ForCausalLM": "exaone",
"ExaoneForCausalLM": "exaone",
"ExaoneMoEForCausalLM": "exaone",
Expand Down Expand Up @@ -236,6 +237,7 @@
"CogVLMForCausalLM": "cogvlm",
"DeepseekOCRForCausalLM": "deepseek",
"DotsOCRForCausalLM": "dotsocr",
"Exaone4_5_ForConditionalGeneration": "exaone",
"Gemma3ForConditionalGeneration": "gemma",
"Gemma3nForConditionalGeneration": "gemma",
"Gemma4ForConditionalGeneration": "gemma",
Expand Down
2 changes: 1 addition & 1 deletion conversion/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2455,7 +2455,7 @@ def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> st
# Step3-VL keeps text config under text_config but uses a custom top-level architecture.
# For text conversion we route to a dedicated text-only class.
# TODO: refactor this later to avoid adding exception here
if model_type == ModelType.TEXT and arch in ("StepVLForConditionalGeneration", "Sarashina2VisionForCausalLM"):
if model_type == ModelType.TEXT and arch in ("StepVLForConditionalGeneration", "Sarashina2VisionForCausalLM", "Exaone4_5_ForConditionalGeneration"):
return arch

# if "architectures" is found in the sub-config, use that instead
Expand Down
98 changes: 96 additions & 2 deletions conversion/exaone.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
import math

from pathlib import Path
from typing import Iterable, TYPE_CHECKING
from typing import Callable, Iterable, TYPE_CHECKING

import torch

if TYPE_CHECKING:
from torch import Tensor

from .base import ModelBase, TextModel, gguf
from .base import MmprojModel, ModelBase, TextModel, gguf
from .qwenvl import Qwen2VLVisionModel


@ModelBase.register("ExaoneForCausalLM")
Expand Down Expand Up @@ -208,3 +209,96 @@ def prepare_tensors(self):
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")

@ModelBase.register("Exaone4_5_ForConditionalGeneration")
class Exaone4_5_TextModel(Exaone4Model):
"""Text tower of EXAONE 4.5; Tensors match EXAONE4"""

model_arch = gguf.MODEL_ARCH.EXAONE4

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
n_nextn = int(self.hparams.get("num_nextn_predict_layers", 0) or 0)
if n_nextn > 0:
self.block_count = self.hparams["num_hidden_layers"] + n_nextn
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)

def set_gguf_parameters(self):
super().set_gguf_parameters()
n_nextn = int(self.hparams.get("num_nextn_predict_layers", 0) or 0)
if n_nextn > 0:
self.gguf_writer.add_nextn_predict_layers(n_nextn)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if name.startswith("mtp."):
n_nextn = int(self.hparams.get("num_nextn_predict_layers", 0) or 0)
if n_nextn <= 0:
return
nh = self.hparams["num_hidden_layers"]
if ".layers." in name:
share = self.hparams.get("mtp_share_layers", False)
mtp_bid = bid if bid is not None else 0
if share:
for k in range(n_nextn):
nn = name.replace(f"mtp.layers.{mtp_bid}", f"model.layers.{nh + k}")
yield from super().modify_tensors(data_torch, nn, nh + k)
return
name = name.replace(f"mtp.layers.{mtp_bid}", f"model.layers.{mtp_bid + nh}")
else:
remapper = {
"mtp.fc": gguf.MODEL_TENSOR.NEXTN_EH_PROJ,
"mtp.pre_fc_norm_embedding": gguf.MODEL_TENSOR.NEXTN_ENORM,
"mtp.pre_fc_norm_hidden": gguf.MODEL_TENSOR.NEXTN_HNORM,
"mtp.norm": gguf.MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
}
_n = Path(name)
key = _n.stem
if key not in remapper:
return
for bid_mtp in range(nh, self.block_count):
mapped_name = self.format_tensor_name(remapper[key], bid_mtp, suffix=_n.suffix)
yield from ModelBase.modify_tensors(self, data_torch, mapped_name, bid_mtp)
return

yield from super().modify_tensors(data_torch, name, bid)


@ModelBase.register("Exaone4_5_ForConditionalGeneration")
class Exaone4_5VisionModel(Qwen2VLVisionModel):
"""Vision tower for EXAONE 4.5; Qwen2-VL-style ViT (GQA) + patch merger"""

@classmethod
def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None:
name, gen = item
name = name.replace("model.visual.", "visual.", 1)
return super().filter_tensors((name, gen))

def set_gguf_parameters(self):
MmprojModel.set_gguf_parameters(self)
assert self.hparams_vision is not None
hparams = self.hparams_vision
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.EXAONE4_5)
self.gguf_writer.add_vision_use_silu(True)
self.gguf_writer.add_vision_min_pixels(self.preprocessor_config["min_pixels"])
self.gguf_writer.add_vision_max_pixels(self.preprocessor_config["max_pixels"])
num_kv_head = self.find_vparam(["num_key_value_heads"], optional=True)
if num_kv_head is not None:
self.gguf_writer.add_vision_head_count_kv(num_kv_head)
eps = hparams.get("rms_norm_eps", self.global_config.get("rms_norm_eps", 1e-6))
self.gguf_writer.add_vision_attention_layernorm_eps(eps)
if (window_size := hparams.get("window_size")) is not None:
self.gguf_writer.add_vision_window_size(window_size)
fullatt_block_indexes = hparams.get("fullatt_block_indexes")
if fullatt_block_indexes:
n_wa_pattern = fullatt_block_indexes[0] + 1
for i in range(1, len(fullatt_block_indexes)):
if fullatt_block_indexes[i] - fullatt_block_indexes[i - 1] != n_wa_pattern:
raise ValueError(f"Invalid EXAONE4.5 fullatt_block_indexes: {fullatt_block_indexes}")
self.gguf_writer.add_vision_n_wa_pattern(n_wa_pattern)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if ".qkv." in name:
yield from ModelBase.modify_tensors(self, data_torch, name, bid)
return

yield from Qwen2VLVisionModel.modify_tensors(self, data_torch, name, bid)
34 changes: 34 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,7 @@ class MODEL_ARCH(IntEnum):
EXAONE = auto()
EXAONE4 = auto()
EXAONE_MOE = auto()
EXAONE4_5 = auto()
Comment thread
CISC marked this conversation as resolved.
Outdated
GRANITE = auto()
GRANITE_MOE = auto()
GRANITE_HYBRID = auto()
Expand Down Expand Up @@ -981,6 +982,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.EXAONE: "exaone",
MODEL_ARCH.EXAONE4: "exaone4",
MODEL_ARCH.EXAONE_MOE: "exaone-moe",
MODEL_ARCH.EXAONE4_5: "exaone4_5",
Comment thread
CISC marked this conversation as resolved.
Outdated
MODEL_ARCH.GRANITE: "granite",
MODEL_ARCH.GRANITE_MOE: "granitemoe",
MODEL_ARCH.GRANITE_HYBRID: "granitehybrid",
Expand Down Expand Up @@ -3244,6 +3246,13 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_POST_NORM,
# NextN/MTP tensors - preserved but unused
MODEL_TENSOR.NEXTN_EH_PROJ,
MODEL_TENSOR.NEXTN_EMBED_TOKENS,
MODEL_TENSOR.NEXTN_ENORM,
MODEL_TENSOR.NEXTN_HNORM,
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
],
MODEL_ARCH.EXAONE_MOE: [
MODEL_TENSOR.TOKEN_EMBD,
Expand Down Expand Up @@ -3277,6 +3286,30 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
],
MODEL_ARCH.EXAONE4_5: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_POST_NORM,
# NextN/MTP tensors - preserved but unused
MODEL_TENSOR.NEXTN_EH_PROJ,
MODEL_TENSOR.NEXTN_EMBED_TOKENS,
MODEL_TENSOR.NEXTN_ENORM,
MODEL_TENSOR.NEXTN_HNORM,
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
],
Comment thread
CISC marked this conversation as resolved.
Outdated
MODEL_ARCH.GRANITE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
Expand Down Expand Up @@ -4235,6 +4268,7 @@ class VisionProjectorType:
LLAMA4 = "llama4"
QWEN2VL = "qwen2vl_merger"
QWEN25VL = "qwen2.5vl_merger"
EXAONE4_5 = "exaone4_5"
QWEN3VL = "qwen3vl_merger"
STEP3VL = "step3vl"
ULTRAVOX = "ultravox"
Expand Down
47 changes: 35 additions & 12 deletions src/models/exaone4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ void llama_model_exaone4::load_arch_hparams(llama_model_loader & ml) {

ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
Comment thread
nuxlear marked this conversation as resolved.
GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer");
hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;

switch (hparams.n_layer) {
case 30: type = LLM_TYPE_1_2B; break;
Expand All @@ -38,21 +41,37 @@ void llama_model_exaone4::load_arch_tensors(llama_model_loader &) {
}

for (int i = 0; i < n_layer; ++i) {
const bool is_nextn = hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers;
int flags = 0;
if (is_nextn) {
// NextN/MTP layers are preserved in GGUF but are not executed yet.
flags |= TENSOR_SKIP;
}

auto & layer = layers[i];

create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, flags);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, flags);

if (!is_nextn) {
layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
}

layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, flags);
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, flags);
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, flags);

layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags);
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, flags);

layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
if (is_nextn) {
layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), {2 * n_embd, n_embd}, flags);
layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), {n_embd}, flags);
layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), {n_embd}, flags);
layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), {n_embd}, flags | TENSOR_NOT_REQUIRED);
}
}
}

Expand Down Expand Up @@ -90,7 +109,11 @@ llama_model_exaone4::graph<iswa>::graph(const llama_model & model, const llm_gra
}
ggml_tensor * inp_out_ids = build_inp_out_ids();

for (int il = 0; il < n_layer; ++il) {
// MTP / NextN tail blocks are loaded for compatibility but not executed (same as exaone-moe).
const int n_layer_main = int(n_layer) - int(hparams.nextn_predict_layers);
GGML_ASSERT(n_layer_main > 0);

for (int il = 0; il < n_layer_main; ++il) {
ggml_tensor * inpSA = inpL;

// use RoPE for SWA layers or non-SWA models
Expand Down Expand Up @@ -126,7 +149,7 @@ llama_model_exaone4::graph<iswa>::graph(const llama_model & model, const llm_gra
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
cb(cur, "attn_out", il);
}
if (il == n_layer - 1 && inp_out_ids) {
if (il == n_layer_main - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
Expand Down
1 change: 1 addition & 0 deletions tools/mtmd/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ add_library(mtmd
models/cogvlm.cpp
models/conformer.cpp
models/dotsocr.cpp
models/exaone4_5.cpp
models/gemma4a.cpp
models/gemma4v.cpp
models/glm4v.cpp
Expand Down
2 changes: 2 additions & 0 deletions tools/mtmd/clip-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ enum projector_type {
PROJECTOR_TYPE_KIMIK25,
PROJECTOR_TYPE_NEMOTRON_V2_VL,
PROJECTOR_TYPE_HUNYUANOCR,
PROJECTOR_TYPE_EXAONE4_5,
PROJECTOR_TYPE_HUNYUANVL,
PROJECTOR_TYPE_MINICPMV4_6,
PROJECTOR_TYPE_GRANITE_SPEECH,
Expand Down Expand Up @@ -394,6 +395,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_KIMIK25, "kimik25"},
{ PROJECTOR_TYPE_NEMOTRON_V2_VL, "nemotron_v2_vl"},
{ PROJECTOR_TYPE_HUNYUANOCR, "hunyuanocr"},
{ PROJECTOR_TYPE_EXAONE4_5, "exaone4_5"},
{ PROJECTOR_TYPE_HUNYUANVL, "hunyuanvl"},
{ PROJECTOR_TYPE_MINICPMV4_6, "minicpmv4_6"},
{ PROJECTOR_TYPE_GRANITE_SPEECH, "granite_speech"},
Expand Down
Loading