Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
84 commits
Select commit Hold shift + click to select a range
27baad4
kimi linear model implementation
ymcki Dec 2, 2025
84f822c
kimi linear convert_hf_to_gguf
ymcki Dec 2, 2025
57cca52
kimi linear constants.py tensor_mapping.py
ymcki Dec 2, 2025
6167f39
Kimi Linear ggml.h
ymcki Dec 2, 2025
26a6553
kimi linear ggml-cpu
ymcki Dec 2, 2025
bf42bc0
Kimi Linear ggml-cuda
ymcki Dec 2, 2025
d73d3e5
Kimi Linear ggml.c
ymcki Dec 2, 2025
e308026
kimi linear src/llama
ymcki Dec 2, 2025
139548d
remove "const int64_t n_seq_tokens = q->ne[2];" to get rid of unused …
ymcki Dec 2, 2025
83d328d
remove type mismatch warning
ymcki Dec 2, 2025
772ca88
read MoE params
ymcki Dec 2, 2025
9f1265f
removed some hard coded code
ymcki Dec 5, 2025
a0269af
removed all hard code
ymcki Dec 6, 2025
ef5bc30
use DeepseekV2 tokenizer
ymcki Dec 14, 2025
ae9771d
removed unnecessary internal methods called by the old set_vocab of K…
ymcki Dec 18, 2025
f9a11d7
rewrite get_vocab for KimiLinear. Removed all kda_scan code
ymcki Dec 18, 2025
776294c
removed all traces of kda_scan
ymcki Dec 18, 2025
f67a42d
reduce OP count by 1 due to removal of kda_scan
ymcki Dec 18, 2025
f85e5c7
Move KIMI_LINEAR to llm_arch_is_hybrid to enable KV cache
ymcki Jan 2, 2026
8bd617e
set n_embd_head_k/v to ensure kv cache works
ymcki Jan 3, 2026
a4020d8
don't quantize conv1d of Kimi Linear
ymcki Jan 3, 2026
66c0c5d
Kimi Linear backend agnostic
ymcki Jan 5, 2026
aba181e
removed LOG_INFO
ymcki Jan 5, 2026
cfed14e
naive chunking form implemented
ymcki Jan 6, 2026
e3542ff
fixed some comments
ymcki Jan 6, 2026
67bee56
add Kimi-K2 specific tokens to be recognized as EOG
ymcki Jan 6, 2026
30d883c
sync fork from b7240 to b7243
ymcki Jan 6, 2026
40f6118
Merge branch 'ggml-org:master' into Kimi-Linear
ymcki Jan 7, 2026
1099cbf
build_kda_autoregressive is implemented to replace build_kda_recurren…
ymcki Jan 7, 2026
f99913d
replaced Akk and Aqk with mul_mat and clamp
ymcki Jan 8, 2026
6977ddb
Merge branch 'ggml-org:master' into Kimi-Linear
ymcki Jan 9, 2026
6150bb7
no clamp version
ymcki Jan 9, 2026
d26fe50
Moved Aqk computation out of the loop
ymcki Jan 10, 2026
dce064c
fixed typo and split wkv_b into wk_b and wv_b
ymcki Jan 10, 2026
426a82d
Merge branch 'ggml-org:master' into Kimi-Linear
ymcki Jan 11, 2026
b9360c7
MLA KV cache support
ymcki Jan 11, 2026
5f2b8dd
Merge branch 'master' of github.com:ymcki/llama.cpp into Kimi-Linear
ymcki Jan 11, 2026
10be797
Merge branch 'Kimi-Linear' of github.com:ymcki/llama.cpp into Kimi-Li…
ymcki Jan 11, 2026
6ae66fc
fix trailing spaces
ymcki Jan 11, 2026
93afbed
moved const llama_model & model; around to follow qwen3next format an…
ymcki Jan 11, 2026
59182f5
fix trailing whitespace
ymcki Jan 11, 2026
58d1ee5
removed traling whitespaces in empty line + make sure indentation is …
ymcki Jan 11, 2026
4f6ef2c
try to make lint happy
ymcki Jan 11, 2026
719d374
remove blank lines to make lint happy
ymcki Jan 11, 2026
ac85cb1
removed at least blank line containing white space
ymcki Jan 12, 2026
4faf26c
fixed flake8 complaints locally
ymcki Jan 12, 2026
22bc582
return ggml_tensor * pair in kda_autoregressive and kda_chunking as i…
ymcki Jan 12, 2026
217e7ce
Merge branch 'ggml-org:master' into Kimi-Linear
ymcki Jan 13, 2026
6ba78d1
removed Kimi-Linear specific change that causes failure at server-win…
ymcki Jan 13, 2026
fe9d248
removed private: from kimi_linear to make build checks happy
ymcki Jan 13, 2026
18ae7f4
removed unnecessary ggml_cont before ggml_reshape
ymcki Jan 13, 2026
2882915
created static function causal_conv1d to abtract similar code for q/k/v
ymcki Jan 14, 2026
c163dff
sync fork and comment fixing in kimi-linear.cpp
ymcki Jan 14, 2026
0aea18e
merged dt_bias to SSM_DT. Do -exp(log_A) in convert_hf_to_gguf.py.
ymcki Jan 16, 2026
f3d118d
reverted to original
ymcki Jan 16, 2026
c26c121
Merge branch 'ggml-org:master' into Kimi-Linear
ymcki Jan 16, 2026
e87ac9b
Merge branch 'master' of github.com:ymcki/llama.cpp into Kimi-Linear
ymcki Jan 16, 2026
0298731
Merge branch 'ggml-org:master' into Kimi-Linear
ymcki Jan 21, 2026
e55caf5
Merge branch 'master' of github.com:ymcki/llama.cpp into Kimi-Linear
ymcki Jan 21, 2026
560190a
fixed find_hparam calls. Fixed e_score_correction_bias to use bias in…
ymcki Jan 21, 2026
a8147a1
Merge branch 'Kimi-Linear' of github.com:ymcki/llama.cpp into Kimi-Li…
ymcki Jan 21, 2026
ae8d710
remove DT_B from constants.py. remove one comment line in llama-model…
ymcki Jan 21, 2026
38c6f5e
Merge branch 'ggml-org:master' into Kimi-Linear
ymcki Jan 25, 2026
92f4949
Merge branch 'ggml-org:master' into Kimi-Linear
ymcki Jan 26, 2026
7fb54dd
Merge branch 'ggml-org:master' into Kimi-Linear
ymcki Jan 26, 2026
bb02b5d
Merge branch 'ggml-org:master' into Kimi-Linear
ymcki Jan 26, 2026
f1525b3
new class llm_graph_input_mem_hybrid_k to get around the new MLA chan…
ymcki Jan 27, 2026
0de4680
remove ssm_o_norm_b
ymcki Jan 27, 2026
0444a4f
remove ssm_o_norm_b
ymcki Jan 27, 2026
a6b2c45
changed hparams.kda_head_dim to hparams.n_embd_head_kda. added TODO c…
ymcki Jan 29, 2026
6216273
removed all ggml_cont b4 ggml_reshape_4d
ymcki Jan 29, 2026
005c340
Whitespace
pwilkin Jan 30, 2026
aaf05bd
replaced all hparams.get with find_hparams
ymcki Jan 31, 2026
2a62df6
Merge branch 'Kimi-Linear' of github.com:ymcki/llama.cpp into Kimi-Li…
ymcki Jan 31, 2026
2c8cd84
added new names for n_experts, n_experts_used and score_func in TextM…
ymcki Feb 1, 2026
11282a0
use is_mla to switch between different mem_hybrid types
ymcki Feb 1, 2026
4bb4286
fixed logical errors in convert_hf_to_gguf.py pointed out by CISC
ymcki Feb 3, 2026
07f9979
Merge branch 'ggml-org:master' into Kimi-Linear
ymcki Feb 3, 2026
efaea45
removed if else for required parameters kv_lora_rank and qk_rope_head…
ymcki Feb 3, 2026
000fded
add back ggml_cont for Vcur
ymcki Feb 3, 2026
8ec5b08
minor changes
ymcki Feb 3, 2026
82215a0
removed extra line in llama-vocab.cpp. Added back the comment in llam…
ymcki Feb 3, 2026
a82103e
f16 gguf cannot run without context length
ymcki Feb 4, 2026
6456393
made a mistake of adding back n_ctx parsing
ymcki Feb 5, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
225 changes: 222 additions & 3 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,10 @@ def prepare_tensors(self):
gguf.MODEL_TENSOR.A_ENC_EMBD_POS,
gguf.MODEL_TENSOR.ALTUP_CORRECT_COEF,
gguf.MODEL_TENSOR.ALTUP_PREDICT_COEF,
# Kimi KDA conv weights should be F32
gguf.MODEL_TENSOR.SSM_CONV1D_Q,
gguf.MODEL_TENSOR.SSM_CONV1D_K,
gguf.MODEL_TENSOR.SSM_CONV1D_V,
)
)
or new_name[-7:] not in (".weight", ".lora_a", ".lora_b")
Expand Down Expand Up @@ -903,10 +907,10 @@ def set_gguf_parameters(self):
if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)) is not None:
self.gguf_writer.add_layer_norm_eps(f_norm_eps)
logger.info(f"gguf: layer norm epsilon = {f_norm_eps}")
if (n_experts := self.hparams.get("num_local_experts")) is not None:
if (n_experts := self.find_hparam(["num_local_experts", "num_experts"], optional=True)) is not None:
self.gguf_writer.add_expert_count(n_experts)
logger.info(f"gguf: expert count = {n_experts}")
if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
if (n_experts_used := self.find_hparam(["num_experts_per_tok", "num_experts_per_token"], optional=True)) is not None:
self.gguf_writer.add_expert_used_count(n_experts_used)
logger.info(f"gguf: experts used count = {n_experts_used}")
if (n_expert_groups := self.hparams.get("n_group")) is not None:
Expand All @@ -916,7 +920,7 @@ def set_gguf_parameters(self):
self.gguf_writer.add_expert_group_used_count(n_group_used)
logger.info(f"gguf: expert groups used count = {n_group_used}")

if (score_func := self.find_hparam(["score_function", "scoring_func", "score_func"], optional=True)) is not None:
if (score_func := self.find_hparam(["score_function", "scoring_func", "score_func", "moe_router_activation_func"], optional=True)) is not None:
if score_func == "sigmoid":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
elif score_func == "softmax":
Expand Down Expand Up @@ -5013,6 +5017,221 @@ def set_gguf_parameters(self):
self.gguf_writer.add_rope_scaling_factor(1.0)


@ModelBase.register("KimiLinearModel", "KimiLinearForCausalLM")
class KimiLinearModel(TextModel):
    """Kimi-Linear converter: hybrid MLA (full attention) + KDA (linear attention).

    KDA layers are marked with a per-layer KV head count of 0 in the GGUF
    ``head_count_kv`` array; MLA layers carry the real KV head count.
    """
    model_arch = gguf.MODEL_ARCH.KIMI_LINEAR

    # Per-layer staging area for MoE expert tensors: each dict accumulates
    # individual expert weights until all n_experts * 3 (gate/down/up) have
    # been seen, at which point they are stacked into 3D tensors.
    _experts: list[dict[str, Tensor]] | None = None

    def set_vocab(self):
        """Write the tokenizer to GGUF.

        Tries the standard GPT-2/BPE path first; on failure falls back to
        rebuilding vocab and merges from the tiktoken-style Kimi-K2 tokenizer.
        """
        try:
            self._set_vocab_gpt2()
            return
        except Exception:
            # fall through to the tiktoken-based reconstruction below
            pass

        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
        tokpre = self.get_vocab_base_pre(tokenizer)

        if tokpre == "kimi-k2":
            # Build merges list using the approach similar to HunYuanMoE
            merges = []
            vocab = {}
            # NOTE(review): relies on the private tiktoken attribute
            # `_mergeable_ranks` of the remote-code tokenizer — confirm it is
            # stable across tokenizer revisions.
            mergeable_ranks = tokenizer.model._mergeable_ranks
            for token, rank in mergeable_ranks.items():
                vocab[QwenModel.token_bytes_to_string(token)] = rank
                if len(token) == 1:
                    # single-byte tokens have no merge pair
                    continue
                merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
                if len(merged) == 2:
                    merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged)))
            # Build token list
            vocab_size = self.hparams["vocab_size"]
            special_tokens = tokenizer.special_tokens
            reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()}
            tokens: list[str] = []
            toktypes: list[int] = []

            for i in range(vocab_size):
                if i not in reverse_vocab:
                    # fill holes in the id space so the token list stays dense
                    tokens.append(f"[PAD{i}]")
                    toktypes.append(gguf.TokenType.UNUSED)
                else:
                    token = reverse_vocab[i]
                    tokens.append(token)
                    if i in special_tokens.values():
                        toktypes.append(gguf.TokenType.CONTROL)
                    else:
                        toktypes.append(gguf.TokenType.NORMAL)

            self.gguf_writer.add_tokenizer_model("gpt2")
            self.gguf_writer.add_tokenizer_pre(tokpre)
            self.gguf_writer.add_token_list(tokens)
            self.gguf_writer.add_token_types(toktypes)
            self.gguf_writer.add_token_merges(merges)

            special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
            special_vocab.add_to_gguf(self.gguf_writer)
            # override eos id in config.json with tiktoken eos id
            self.gguf_writer.add_eos_token_id(tokenizer.eos_id)
        else:
            raise NotImplementedError(f"Deepseek pre-tokenizer {tokpre!r} is not supported yet!")

    def set_gguf_parameters(self):
        """Write architecture hyperparameters (KDA, MLA, and MoE) to GGUF."""
        # note: To enable MLA KV cache, attention needs to be converted into MQA (ie: GQA with 1 group)
        # NOTE(review): this mutation is visible to everything that reads
        # self.hparams afterwards, including modify_tensors — see kv_b_proj split.
        self.hparams["num_key_value_heads"] = 1

        super().set_gguf_parameters()
        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])

        # KDA & MLA params
        # Get ssm_d_conv from linear_attn_config.short_conv_kernel_size or ssm_d_conv
        linear_attn_config = self.hparams["linear_attn_config"]
        # n_head == 0 for KDA layers, n_head > 0 for MLA layers
        # full_attention_layers list will be used to distinguish layer type
        _num_kv_heads = list()
        _full_attn_layers = linear_attn_config["full_attn_layers"]
        for il in range(self.hparams["num_hidden_layers"]):
            # NOTE(review): `il + 1` implies full_attn_layers is 1-indexed in
            # the HF config — confirm against an actual config.json.
            if il + 1 in _full_attn_layers:
                _num_kv_heads.append(self.hparams["num_key_value_heads"])
            else:
                _num_kv_heads.append(0)
        assert len(_num_kv_heads) == self.hparams["num_hidden_layers"]
        self.gguf_writer.add_head_count_kv(_num_kv_heads)

        if (ssm_d_conv := linear_attn_config.get("short_conv_kernel_size")) is not None:
            self.gguf_writer.add_ssm_conv_kernel(ssm_d_conv)
        if (kda_head_dim := linear_attn_config.get("head_dim")) is not None:
            self.gguf_writer.add_kda_head_dim(kda_head_dim)

        # MLA params - use add_* methods that handle arch substitution
        # Support both HuggingFace naming (q_lora_rank, kv_lora_rank) and internal naming (n_lora_q, n_lora_kv)
        if (q_lora_rank := self.find_hparam(["q_lora_rank", "n_lora_q"], optional=True)) is not None:
            self.gguf_writer.add_q_lora_rank(q_lora_rank)
        # To enable MLA KV cache, MLA needs to be converted into MQA with larger heads, then decompresses to MHA
        kv_lora_rank = self.find_hparam(["kv_lora_rank", "n_lora_kv"], optional=False)
        self.gguf_writer.add_kv_lora_rank(kv_lora_rank)

        # MLA head dimensions
        # Support HuggingFace naming: qk_nope_head_dim, qk_rope_head_dim, v_head_dim
        qk_nope_head_dim = self.hparams.get("qk_nope_head_dim")
        # Rotation - use qk_rope_head_dim for Kimi
        qk_rope_head_dim = self.find_hparam(["qk_rope_head_dim", "n_rot"], optional=False)
        self.gguf_writer.add_rope_dimension_count(qk_rope_head_dim)
        # MQA view: key length = compressed KV rank + RoPE dims
        self.gguf_writer.add_key_length(kv_lora_rank + qk_rope_head_dim)
        v_head_dim = self.hparams.get("v_head_dim")

        # Calculate n_embd_head_k_mla = qk_nope_head_dim + qk_rope_head_dim
        if (n_embd_head_k_mla := self.find_hparam(["n_embd_head_k_mla"], optional=True)) is not None:
            self.gguf_writer.add_key_length_mla(n_embd_head_k_mla)
        elif qk_nope_head_dim is not None:
            n_embd_head_k_mla = qk_nope_head_dim + qk_rope_head_dim
            self.gguf_writer.add_key_length_mla(n_embd_head_k_mla)

        # n_embd_head_v_mla = v_head_dim
        if (n_embd_head_v_mla := self.hparams.get("n_embd_head_v_mla")) is not None:
            self.gguf_writer.add_value_length_mla(n_embd_head_v_mla)
        elif v_head_dim is not None:
            self.gguf_writer.add_value_length_mla(v_head_dim)

        # moe_intermediate_size (1024 for Kimi)
        self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"])
        # num_shared_experts (1 for Kimi)
        self.gguf_writer.add_expert_shared_count(self.hparams["num_shared_experts"])
        # first_k_dense_replace (1 for Kimi - first layer uses dense MLP)
        self.gguf_writer.add_leading_dense_block_count(self.hparams["first_k_dense_replace"])
        # Routed scaling factor (expert_weights_scale = 2.446 for Kimi)
        self.gguf_writer.add_expert_weights_scale(self.hparams["routed_scaling_factor"])

    def prepare_tensors(self):
        """Run the base conversion, then verify every expert tensor was consumed."""
        super().prepare_tensors()
        if self._experts is not None:
            # flatten any leftover (unmerged) expert tensor names
            experts = [k for d in self._experts for k in d.keys()]
            if len(experts) > 0:
                raise ValueError(f"Unprocessed experts: {experts}")

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        """Rename/reshape HF tensors into the GGUF layout.

        Handles: KDA conv1d reshape, e_score bias rename, A_log sign/exp fold,
        dt_bias rename, MoE expert stacking, and the MLA kv_b_proj split.
        """
        logger.info(f"Processing {name}: shape before = {tuple(data_torch.shape)}")

        # Handle KDA conv1d weights
        # HuggingFace/vLLM stores as [d_inner, d_conv] (2D), memory layout: conv_step changes fastest
        # llama.cpp expects ggml ne = [d_conv, 1, d_inner, 1], memory layout: ne[0]=d_conv changes fastest
        # GGUF reverses numpy shape when writing, so numpy (1, d_inner, 1, d_conv) -> ggml ne = [d_conv, 1, d_inner, 1]
        # Memory layouts match: both have conv_step (d_conv) changing fastest
        if name.endswith((".q_conv1d.weight", ".k_conv1d.weight", ".v_conv1d.weight")):
            # HF shape: [d_inner, d_conv] e.g. [4096, 4]
            # Target numpy shape: (1, d_inner, 1, d_conv) -> ggml ne = [d_conv, 1, d_inner, 1]
            if data_torch.ndim == 2:
                d_inner, d_conv = data_torch.shape
                # Reshape to (1, d_inner, 1, d_conv) - memory layout preserved (d_conv fastest)
                data_torch = data_torch.reshape(1, d_inner, 1, d_conv)
                logger.info(f"Reshaped conv1d weight {name}: [d_inner={d_inner}, d_conv={d_conv}] -> numpy {tuple(data_torch.shape)} -> ggml ne=[{d_conv}, 1, {d_inner}, 1]")
            elif data_torch.ndim == 3:
                # Already 3D [d_inner, 1, d_conv] from unsqueeze
                d_inner, _, d_conv = data_torch.shape
                data_torch = data_torch.reshape(1, d_inner, 1, d_conv)
                logger.info(f"Reshaped conv1d weight {name}: [d_inner={d_inner}, 1, d_conv={d_conv}] -> numpy {tuple(data_torch.shape)} -> ggml ne=[{d_conv}, 1, {d_inner}, 1]")

        # Kimi specific bias
        if name.endswith("e_score_correction_bias"):
            name = name.replace("e_score_correction_bias", "e_score_correction.bias")

        # Handle A_log: HF stores as [1, 1, num_heads, 1]
        # llama.cpp expects ggml ne = [1, num_heads, 1, 1]
        # GGUF reverses numpy shape: numpy (1, 1, num_heads, 1) -> ggml ne = [1, num_heads, 1, 1]
        if name.endswith(".A_log"):
            # fold A = -exp(A_log) at conversion time so runtime skips the exp
            data_torch = -torch.exp(data_torch)
        if name.endswith(".dt_bias"):
            name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
            logger.info("Changed dt_bias to dt_proj.bias")

        # process the experts separately
        if name.find("block_sparse_moe.experts") != -1:
            n_experts = self.find_hparam(["num_local_experts", "num_experts"], optional=False)
            assert bid is not None

            if self._experts is None:
                self._experts = [{} for _ in range(self.block_count)]

            self._experts[bid][name] = data_torch

            if len(self._experts[bid]) >= n_experts * 3:
                # merge the experts into a single 3d tensor
                # w1: gate, w2: down, w3: up
                for wid, tname in [("w1", gguf.MODEL_TENSOR.FFN_GATE_EXP),
                                   ("w2", gguf.MODEL_TENSOR.FFN_DOWN_EXP),
                                   ("w3", gguf.MODEL_TENSOR.FFN_UP_EXP)]:
                    datas: list[Tensor] = []
                    for xid in range(n_experts):
                        ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{wid}.weight"
                        datas.append(self._experts[bid][ename])
                        del self._experts[bid][ename]
                    data_torch = torch.stack(datas, dim=0)
                    new_name = self.format_tensor_name(tname, bid)
                    yield from super().modify_tensors(data_torch, new_name, bid)
            # nothing yielded until the full expert set for this layer arrives
            return

        # note: MLA with the absorption optimization, needs these two split and k_b_proj transposed
        if name.endswith("kv_b_proj.weight"):
            name_kb = name.replace("kv_b_proj", "k_b_proj")
            name_vb = name.replace("kv_b_proj", "v_b_proj")
            # NOTE(review): set_gguf_parameters forces num_key_value_heads to 1;
            # if that ran first, this reads 1 here — the assert below is the
            # guard. Confirm the intended head count for the split.
            n_head_kv = self.hparams["num_key_value_heads"]
            v_head_dim = self.find_hparam(["n_embd_head_v_mla", "v_head_dim"], optional=False)
            qk_nope_head_dim = self.hparams["qk_nope_head_dim"]
            logger.info("Split kv_b n_head_kv %d\n" % n_head_kv)
            assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim)
            kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1])
            k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)
            # transpose k_b for the absorbed-attention matmul layout
            k_b = k_b.transpose(1, 2)
            yield from super().modify_tensors(k_b, name_kb, bid)
            yield from super().modify_tensors(v_b, name_vb, bid)
            return

        yield from super().modify_tensors(data_torch, name, bid)


@ModelBase.register("InternLM2ForCausalLM")
class InternLM2Model(TextModel):
model_arch = gguf.MODEL_ARCH.INTERNLM2
Expand Down
65 changes: 65 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,9 @@ class SSM:
GROUP_COUNT = "{arch}.ssm.group_count"
DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms"

class KDA:
HEAD_DIM = "{arch}.kda.head_dim"

class WKV:
HEAD_SIZE = "{arch}.wkv.head_size"

Expand Down Expand Up @@ -461,6 +464,7 @@ class MODEL_ARCH(IntEnum):
MIMO2 = auto()
LLAMA_EMBED = auto()
MAINCODER = auto()
KIMI_LINEAR = auto()


class VISION_PROJECTOR_TYPE(IntEnum):
Expand Down Expand Up @@ -551,6 +555,14 @@ class MODEL_TENSOR(IntEnum):
SSM_NORM = auto()
SSM_OUT = auto()
SSM_BETA_ALPHA = auto() # qwen3next
SSM_CONV1D_Q = auto() # Kimi Linear
SSM_CONV1D_K = auto() # Kimi Linear
SSM_CONV1D_V = auto() # Kimi Linear
SSM_F_A = auto() # Kimi Linear
SSM_F_B = auto() # Kimi Linear
SSM_BETA = auto() # Kimi Linear
SSM_G_A = auto() # Kimi Linear
SSM_G_B = auto() # Kimi Linear
TIME_MIX_W0 = auto()
TIME_MIX_W1 = auto()
TIME_MIX_W2 = auto()
Expand Down Expand Up @@ -882,6 +894,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.MIMO2: "mimo2",
MODEL_ARCH.LLAMA_EMBED: "llama-embed",
MODEL_ARCH.MAINCODER: "maincoder",
MODEL_ARCH.KIMI_LINEAR: "kimi-linear",
}

VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
Expand Down Expand Up @@ -969,6 +982,14 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm",
MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
MODEL_TENSOR.SSM_BETA_ALPHA: "blk.{bid}.ssm_ba",
MODEL_TENSOR.SSM_CONV1D_Q: "blk.{bid}.ssm_conv1d_q", # Kimi Linear
MODEL_TENSOR.SSM_CONV1D_K: "blk.{bid}.ssm_conv1d_k", # Kimi Linear
MODEL_TENSOR.SSM_CONV1D_V: "blk.{bid}.ssm_conv1d_v", # Kimi Linear
MODEL_TENSOR.SSM_F_A: "blk.{bid}.ssm_f_a", # Kimi Linear
MODEL_TENSOR.SSM_F_B: "blk.{bid}.ssm_f_b", # Kimi Linear
MODEL_TENSOR.SSM_BETA: "blk.{bid}.ssm_beta", # Kimi Linear
MODEL_TENSOR.SSM_G_A: "blk.{bid}.ssm_g_a", # Kimi Linear
MODEL_TENSOR.SSM_G_B: "blk.{bid}.ssm_g_b", # Kimi Linear
MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0",
MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1",
MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2",
Expand Down Expand Up @@ -3379,6 +3400,47 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.KIMI_LINEAR: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_Q_A,
MODEL_TENSOR.ATTN_Q_B,
MODEL_TENSOR.ATTN_KV_A_MQA,
MODEL_TENSOR.ATTN_KV_B,
MODEL_TENSOR.ATTN_K_B,
MODEL_TENSOR.ATTN_V_B,
MODEL_TENSOR.ATTN_Q_A_NORM,
MODEL_TENSOR.ATTN_KV_A_NORM,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.SSM_CONV1D_Q,
MODEL_TENSOR.SSM_CONV1D_K,
MODEL_TENSOR.SSM_CONV1D_V,
MODEL_TENSOR.SSM_F_A,
MODEL_TENSOR.SSM_F_B,
MODEL_TENSOR.SSM_BETA,
MODEL_TENSOR.SSM_A,
MODEL_TENSOR.SSM_G_A,
MODEL_TENSOR.SSM_G_B,
MODEL_TENSOR.SSM_DT,
MODEL_TENSOR.SSM_NORM,
MODEL_TENSOR.FFN_EXP_PROBS_B,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
],
# TODO
}

Expand Down Expand Up @@ -3706,6 +3768,9 @@ class VisionProjectorType:
KEY_SSM_GROUP_COUNT = Keys.SSM.GROUP_COUNT
KEY_SSM_DT_B_C_RMS = Keys.SSM.DT_B_C_RMS

# KDA
KEY_KDA_HEAD_DIM = Keys.KDA.HEAD_DIM

# tokenization
KEY_TOKENIZER_MODEL = Keys.Tokenizer.MODEL
KEY_TOKENIZER_PRE = Keys.Tokenizer.PRE
Expand Down
3 changes: 3 additions & 0 deletions gguf-py/gguf/gguf_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,6 +980,9 @@ def add_ssm_group_count(self, value: int) -> None:
def add_ssm_dt_b_c_rms(self, value: bool) -> None:
self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)

def add_kda_head_dim(self, value: int) -> None:
self.add_uint32(Keys.KDA.HEAD_DIM.format(arch=self.arch), value)

def add_tokenizer_model(self, model: str) -> None:
self.add_string(Keys.Tokenizer.MODEL, model)

Expand Down
Loading
Loading