78 changes: 50 additions & 28 deletions convert_hf_to_gguf.py
@@ -4102,39 +4102,27 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# process the experts separately
name = name.replace("language_model.", "") # InternVL

# handle aggregated expert tensors
# GGUF stores dimensions reversed from PyTorch, so:
# PyTorch (A,B,C) -> GGUF writes [C,B,A] -> GGML reads ne={C,B,A}
# Input shapes from HF: (n_expert, n_ff_exp, n_embd) or (n_expert, n_embd, n_ff_exp)
# Expected GGML ne: {n_embd, n_ff_exp, n_expert} for gate/up, {n_ff_exp, n_embd, n_expert} for down
# handle pre-packed expert tensors (e.g. Qwen3.5 MoE, Qwen3Next)
# HF stores these using nn.Linear convention: [n_expert, out_features, in_features]
# This matches the individual expert stacking path below (which stacks
# per-expert [out, in] weights into [n_expert, out, in]), so no permute is needed.
if name.endswith("mlp.experts.down_proj") or name.endswith("mlp.experts.down_proj.weight"):
mapped = f"{name}.weight" if not name.endswith(".weight") else name
# Input: (n_expert=128, n_ff_exp=768, n_embd=2048)
# Want GGML ne: {n_ff_exp, n_embd, n_expert} = {768, 2048, 128}
# Need PyTorch: (128, 2048, 768) [reversed of GGML]
# So: permute(0, 2, 1): (128, 768, 2048) -> (128, 2048, 768)
permuted = data_torch.permute(0, 2, 1).contiguous()
yield from super().modify_tensors(permuted, mapped, bid)
# HF: [n_expert, n_embd, n_ff] → GGML: {n_ff, n_embd, n_expert} ✓
yield from super().modify_tensors(data_torch, mapped, bid)
return

if name.endswith("mlp.experts.gate_up_proj") or name.endswith("mlp.experts.gate_up_proj.weight"):
if data_torch.ndim < 3 or data_torch.shape[-1] % 2 != 0:
raise ValueError(f"Unexpected gate_up_proj shape for {name}: {tuple(data_torch.shape)}")
split_dim = data_torch.shape[-1] // 2
gate = data_torch[..., :split_dim].contiguous()
up = data_torch[..., split_dim:].contiguous()
# Input gate/up: (n_expert=128, n_embd=2048, n_ff_exp=768)
# Want GGML ne: {n_embd, n_ff_exp, n_expert} = {2048, 768, 128}
# Need PyTorch: (128, 768, 2048) [reversed of GGML]
# So: permute(0, 2, 1): (128, 2048, 768) -> (128, 768, 2048)
base_name = name.removesuffix(".weight")
base = base_name.rsplit('.', 1)[0]
mapped_gate = f"{base}.gate_proj.weight"
mapped_up = f"{base}.up_proj.weight"
perm_gate = gate.permute(0, 2, 1).contiguous()
perm_up = up.permute(0, 2, 1).contiguous()
yield from super().modify_tensors(perm_gate, mapped_gate, bid)
yield from super().modify_tensors(perm_up, mapped_up, bid)
# HF: [n_expert, 2*n_ff, n_embd] → split on dim=1
n_ff = data_torch.shape[1] // 2
gate = data_torch[:, :n_ff, :].contiguous()
up = data_torch[:, n_ff:, :].contiguous()
# gate/up: [n_expert, n_ff, n_embd] → GGML: {n_embd, n_ff, n_expert} ✓
base_name = name.removesuffix(".weight").removesuffix(".gate_up_proj")
mapped_gate = f"{base_name}.gate_proj.weight"
mapped_up = f"{base_name}.up_proj.weight"
yield from super().modify_tensors(gate, mapped_gate, bid)
yield from super().modify_tensors(up, mapped_up, bid)
return

if name.startswith("mlp") or name.startswith("vision_model") or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector") or name.startswith("model.visual"):
@@ -4344,6 +4332,40 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
yield from super().modify_tensors(data_torch, name, bid)


@ModelBase.register("Qwen3_5ForCausalLM", "Qwen3_5TextForCausalLM")
class Qwen3_5Model(Qwen3NextModel):
model_arch = gguf.MODEL_ARCH.QWEN3_5

# Stores whichever of in_proj_a/in_proj_b is seen first, keyed by layer
_pending_ba: dict[int | None, tuple[str, Tensor]] = {}

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Handle split in_proj_b + in_proj_a → concatenated SSM_BETA_ALPHA
# safetensors sorts alphabetically so in_proj_a arrives before in_proj_b
if "in_proj_a.weight" in name or "in_proj_b.weight" in name:
which = "a" if "in_proj_a" in name else "b"
if bid not in self._pending_ba:
self._pending_ba[bid] = (which, data_torch)
return
prev_which, prev_tensor = self._pending_ba.pop(bid)
assert prev_which != which, f"duplicate in_proj_{which} for layer {bid}"
b_tensor = prev_tensor if prev_which == "b" else data_torch
a_tensor = prev_tensor if prev_which == "a" else data_torch
ba_combined = torch.cat([b_tensor, a_tensor], dim=0)
yield (self.format_tensor_name(gguf.MODEL_TENSOR.SSM_BETA_ALPHA, bid, ".weight"), ba_combined)
return
else:
# Qwen3Next uses a .qkvz tensor, so defer to the superclass to get the other
# functionality (norm correction, A_log to A, etc.) for free.
# Qwen2Moe already handles the gate_up conversion properly, so just use that.
yield from super().modify_tensors(data_torch, name, bid)


@ModelBase.register("Qwen3_5MoeForCausalLM", "Qwen3_5MoeTextForCausalLM")
class Qwen3_5MoeModel(Qwen3_5Model):
model_arch = gguf.MODEL_ARCH.QWEN3_5_MOE


@ModelBase.register("RND1")
class RND1Model(Qwen2MoeModel):
model_arch = gguf.MODEL_ARCH.RND1
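For reference, a minimal standalone sketch of the shape handling the new comments describe. The concrete sizes (n_expert=128, n_ff_exp=768, n_embd=2048) are illustrative, carried over from the comments the PR removes; the sketch itself is not part of the diff.

```python
# Sketch: pre-packed expert tensors in HF nn.Linear convention, assuming
# illustrative sizes n_expert=128, n_ff=768, n_embd=2048.
import torch

n_expert, n_ff, n_embd = 128, 768, 2048

# down_proj arrives as [n_expert, out=n_embd, in=n_ff], which already matches
# the stacked per-expert layout, so it is forwarded without a permute.
down_proj = torch.randn(n_expert, n_embd, n_ff)

# gate_up_proj packs gate and up along dim=1: [n_expert, 2*n_ff, n_embd].
gate_up_proj = torch.randn(n_expert, 2 * n_ff, n_embd)
gate = gate_up_proj[:, :n_ff, :].contiguous()  # [n_expert, n_ff, n_embd]
up   = gate_up_proj[:, n_ff:, :].contiguous()  # [n_expert, n_ff, n_embd]

assert gate.shape == up.shape == (n_expert, n_ff, n_embd)
# GGUF stores dimensions reversed from PyTorch, so [n_expert, n_ff, n_embd]
# is read back by GGML as ne = {n_embd, n_ff, n_expert} -- exactly the layout
# the gate/up expert tensors need, hence no permute in the new code path.
```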
59 changes: 59 additions & 0 deletions gguf-py/gguf/constants.py
@@ -382,6 +382,8 @@ class MODEL_ARCH(IntEnum):
QWEN3 = auto()
QWEN3MOE = auto()
QWEN3NEXT = auto()
QWEN3_5 = auto()
QWEN3_5_MOE = auto()
QWEN3VL = auto()
QWEN3VLMOE = auto()
PHI2 = auto()
@@ -812,6 +814,8 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.QWEN3: "qwen3",
MODEL_ARCH.QWEN3MOE: "qwen3moe",
MODEL_ARCH.QWEN3NEXT: "qwen3next",
MODEL_ARCH.QWEN3_5: "qwen3_5",
MODEL_ARCH.QWEN3_5_MOE: "qwen3_5moe",
MODEL_ARCH.QWEN3VL: "qwen3vl",
MODEL_ARCH.QWEN3VLMOE: "qwen3vlmoe",
MODEL_ARCH.PHI2: "phi2",
@@ -1784,6 +1788,61 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.SSM_BETA_ALPHA,
MODEL_TENSOR.SSM_OUT
],
MODEL_ARCH.QWEN3_5: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.ATTN_GATE,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.SSM_A,
MODEL_TENSOR.SSM_CONV1D,
MODEL_TENSOR.SSM_DT,
MODEL_TENSOR.SSM_NORM,
MODEL_TENSOR.SSM_IN,
MODEL_TENSOR.SSM_BETA_ALPHA,
MODEL_TENSOR.SSM_OUT,
],
MODEL_ARCH.QWEN3_5_MOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.ATTN_GATE,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_INP_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.SSM_A,
MODEL_TENSOR.SSM_CONV1D,
MODEL_TENSOR.SSM_DT,
MODEL_TENSOR.SSM_NORM,
MODEL_TENSOR.SSM_IN,
MODEL_TENSOR.SSM_BETA_ALPHA,
MODEL_TENSOR.SSM_OUT,
],
MODEL_ARCH.QWEN3VL: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
6 changes: 4 additions & 2 deletions gguf-py/gguf/tensor_mapping.py
@@ -228,6 +228,7 @@ class TensorNameMap:
"transformer_encoder.{bid}.qkv", # neobert
"layers.{bid}.attn.Wqkv", # modern-bert
"model.layers.{bid}.self_attn.language_expert_query_key_value", # cogvlm
"model.layers.{bid}.linear_attn.in_proj_qkv", # qwen3.5
),

# Attention query
@@ -358,8 +359,9 @@ class TensorNameMap:
),

MODEL_TENSOR.ATTN_GATE: (
"model.layers.{bid}.self_attn.gate_proj", # afmoe
"model.layers.{bid}.self_attn.g_proj", # step3.5 head-wise attention gate
"model.layers.{bid}.self_attn.gate_proj", # afmoe
"model.layers.{bid}.self_attn.g_proj", # step3.5 head-wise attention gate
"model.layers.{bid}.linear_attn.in_proj_z", # qwen3.5
),

# Feed-forward norm
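A hedged usage sketch of how the two new mappings resolve through gguf-py. It assumes a gguf package that already contains this PR's QWEN3_5 additions; the block count of 32 is arbitrary and only for illustration.

```python
# Sketch: resolve the new qwen3.5 HF tensor names via TensorNameMap.
# Assumes gguf-py includes MODEL_ARCH.QWEN3_5 and its tensor list from this PR.
import gguf

tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.QWEN3_5, n_blocks=32)

for hf_name in (
    "model.layers.0.linear_attn.in_proj_qkv.weight",  # expected: ATTN_QKV
    "model.layers.0.linear_attn.in_proj_z.weight",    # expected: ATTN_GATE
):
    resolved = tmap.get_type_and_name(hf_name, try_suffixes=(".weight", ".bias"))
    print(hf_name, "->", resolved)
```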
3 changes: 3 additions & 0 deletions src/CMakeLists.txt
@@ -57,6 +57,7 @@ add_library(llama
models/deci.cpp
models/deepseek.cpp
models/deepseek2.cpp
models/delta.cpp
models/dots1.cpp
models/dream.cpp
models/ernie4-5-moe.cpp
@@ -122,6 +123,8 @@ add_library(llama
models/qwen3vl-moe.cpp
models/qwen3moe.cpp
models/qwen3next.cpp
models/qwen3-5.cpp
models/qwen3-5moe.cpp
models/refact.cpp
models/rnd1.cpp
models/rwkv6-base.cpp
61 changes: 61 additions & 0 deletions src/llama-arch.cpp
@@ -35,6 +35,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_QWEN3, "qwen3" },
{ LLM_ARCH_QWEN3MOE, "qwen3moe" },
{ LLM_ARCH_QWEN3NEXT, "qwen3next" },
{ LLM_ARCH_QWEN3_5, "qwen3_5" },
{ LLM_ARCH_QWEN3_5_MOE, "qwen3_5moe" },
{ LLM_ARCH_QWEN3VL, "qwen3vl" },
{ LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" },
{ LLM_ARCH_PHI2, "phi2" },
@@ -985,6 +987,63 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_SSM_NORM,
LLM_TENSOR_SSM_OUT,
};
case LLM_ARCH_QWEN3_5:
return {
LLM_TENSOR_TOKEN_EMBD,
LLM_TENSOR_OUTPUT_NORM,
LLM_TENSOR_OUTPUT,
LLM_TENSOR_ATTN_NORM,
LLM_TENSOR_ATTN_POST_NORM,
LLM_TENSOR_ATTN_Q,
LLM_TENSOR_ATTN_Q_NORM,
LLM_TENSOR_ATTN_K,
LLM_TENSOR_ATTN_K_NORM,
LLM_TENSOR_ATTN_V,
LLM_TENSOR_ATTN_OUT,
LLM_TENSOR_ATTN_QKV,
LLM_TENSOR_ATTN_GATE,
LLM_TENSOR_FFN_GATE,
LLM_TENSOR_FFN_DOWN,
LLM_TENSOR_FFN_UP,
LLM_TENSOR_SSM_A_NOSCAN,
LLM_TENSOR_SSM_CONV1D,
LLM_TENSOR_SSM_DT,
LLM_TENSOR_SSM_BETA_ALPHA,
LLM_TENSOR_SSM_IN,
LLM_TENSOR_SSM_NORM,
LLM_TENSOR_SSM_OUT,
};
case LLM_ARCH_QWEN3_5_MOE:
return {
LLM_TENSOR_TOKEN_EMBD,
LLM_TENSOR_OUTPUT_NORM,
LLM_TENSOR_OUTPUT,
LLM_TENSOR_ATTN_NORM,
LLM_TENSOR_ATTN_POST_NORM,
LLM_TENSOR_ATTN_Q,
LLM_TENSOR_ATTN_Q_NORM,
LLM_TENSOR_ATTN_K,
LLM_TENSOR_ATTN_K_NORM,
LLM_TENSOR_ATTN_V,
LLM_TENSOR_ATTN_OUT,
LLM_TENSOR_ATTN_QKV,
LLM_TENSOR_ATTN_GATE,
LLM_TENSOR_FFN_GATE_INP,
LLM_TENSOR_FFN_GATE_EXPS,
LLM_TENSOR_FFN_DOWN_EXPS,
LLM_TENSOR_FFN_UP_EXPS,
LLM_TENSOR_FFN_GATE_INP_SHEXP,
LLM_TENSOR_FFN_GATE_SHEXP,
LLM_TENSOR_FFN_DOWN_SHEXP,
LLM_TENSOR_FFN_UP_SHEXP,
LLM_TENSOR_SSM_A_NOSCAN,
LLM_TENSOR_SSM_CONV1D,
LLM_TENSOR_SSM_DT,
LLM_TENSOR_SSM_BETA_ALPHA,
LLM_TENSOR_SSM_IN,
LLM_TENSOR_SSM_NORM,
LLM_TENSOR_SSM_OUT,
};
case LLM_ARCH_QWEN3VL:
case LLM_ARCH_CHAMELEON:
case LLM_ARCH_HUNYUAN_DENSE:
@@ -2674,6 +2733,8 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
case LLM_ARCH_NEMOTRON_H:
case LLM_ARCH_NEMOTRON_H_MOE:
case LLM_ARCH_QWEN3NEXT:
case LLM_ARCH_QWEN3_5:
case LLM_ARCH_QWEN3_5_MOE:
case LLM_ARCH_KIMI_LINEAR:
return true;
default:
2 changes: 2 additions & 0 deletions src/llama-arch.h
@@ -39,6 +39,8 @@ enum llm_arch {
LLM_ARCH_QWEN3,
LLM_ARCH_QWEN3MOE,
LLM_ARCH_QWEN3NEXT,
LLM_ARCH_QWEN3_5,
LLM_ARCH_QWEN3_5_MOE,
LLM_ARCH_QWEN3VL,
LLM_ARCH_QWEN3VLMOE,
LLM_ARCH_PHI2,
2 changes: 1 addition & 1 deletion src/llama-context.cpp
@@ -2013,7 +2013,7 @@ void llama_context::output_reorder() {
//

uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const {
if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_KIMI_LINEAR) {
if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN3_5 || model.arch == LLM_ARCH_QWEN3_5_MOE || model.arch == LLM_ARCH_KIMI_LINEAR) {
return std::max<uint32_t>(n_tokens * 40, 32u * model.n_tensors());
}
uint32_t res = std::max<uint32_t>(1024u, 8u*model.n_tensors());