Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/llama-arch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_QWEN3VL, "qwen3vl" },
{ LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" },
{ LLM_ARCH_QWEN35MOE, "qwen35moe" },
{ LLM_ARCH_QWEN35, "qwen35" },
{ LLM_ARCH_PHI2, "phi2" },
{ LLM_ARCH_PHI3, "phi3" },
{ LLM_ARCH_PLAMO, "plamo" },
Expand Down Expand Up @@ -260,6 +261,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
switch (arch) {
case LLM_ARCH_QWEN3NEXT:
case LLM_ARCH_QWEN35MOE:
case LLM_ARCH_QWEN35:
return true;
default:
return false;
Expand Down
1 change: 1 addition & 0 deletions src/llama-arch.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ enum llm_arch {
LLM_ARCH_QWEN3VL,
LLM_ARCH_QWEN3VLMOE,
LLM_ARCH_QWEN35MOE,
LLM_ARCH_QWEN35,
LLM_ARCH_PHI2,
LLM_ARCH_PHI3,
LLM_ARCH_PLAMO,
Expand Down
134 changes: 134 additions & 0 deletions src/llama-build-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4663,6 +4663,136 @@ ggml_cgraph * llm_build_context::build_qwen35moe() {
return gf;
}

ggml_cgraph * llm_build_context::build_qwen35() {
static constexpr int QWEN3NEXT_CHUNK_SIZE = 64;

struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false);

delta_net delta(lctx, batch);

const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);

int sections[4];
std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);

auto build_layer_attn = [&](ggml_tensor * cur, ggml_tensor * inp_pos, ggml_tensor * KQ_mask, int il) -> ggml_tensor * {

auto Qaux = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
auto Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
auto Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
cb(Qaux, "Qaux", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
ggml_build_forward_expand(gf, Qaux);
ggml_build_forward_expand(gf, Kcur);
ggml_build_forward_expand(gf, Vcur);

Qaux = ggml_reshape_3d(ctx0, Qaux, n_embd_head * 2, n_head, n_tokens);
auto Qcur = ggml_cont(ctx0, ggml_view_3d(ctx0, Qaux, n_embd_head, n_head, n_tokens, Qaux->nb[1], Qaux->nb[2], 0));
auto gate = ggml_cont_2d(ctx0, ggml_view_3d(ctx0, Qaux, n_embd_head, n_head, n_tokens, Qaux->nb[1], Qaux->nb[2], n_embd_head*ggml_element_size(Qaux)), n_embd_head*n_head, n_tokens);
cb(Qcur, "Qcur", il);
cb(gate, "gate", il);

Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, cb, il);
cb(Qcur, "Qcur_normed", il);

Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, cb, il);
cb(Kcur, "Kcur_normed", il);

Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr,
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);

Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr,
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);

cb(Qcur, "Qcur_roped", il);
cb(Kcur, "Kcur_roped", il);

ggml_tensor * attn = llm_build_kv(ctx0, lctx, kv_self, gf, nullptr, nullptr,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv,
hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale, cb, il);
cb(attn, "attn_pregate", il);

gate = ggml_sigmoid(ctx0, gate);
cb(gate, "gate_sigmoid", il);
attn = ggml_mul(ctx0, attn, gate);
cb(attn, "attn_gated", il);

attn = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, attn);
cb(attn, "attn_output", il);

return attn;

};

ggml_tensor * inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
ggml_tensor * inp_pos = build_inp_pos();
ggml_tensor * inp_out_ids = n_tokens > 1 ? build_inp_out_ids() : nullptr;
ggml_tensor * KQ_mask = build_inp_KQ_mask();

lctx.inp_s_seq_qnext = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, 1, n_tokens);
cb(lctx.inp_s_seq_qnext, "inp_s_seq_qnext", -1);
ggml_set_input(lctx.inp_s_seq_qnext);

ggml_tensor * causal_mask = nullptr;
ggml_tensor * identity = nullptr;
ggml_tensor * diag_mask = nullptr;
causal_mask = ggml_tri(ctx0,
ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, QWEN3NEXT_CHUNK_SIZE, QWEN3NEXT_CHUNK_SIZE), 1.0f),
GGML_TRI_TYPE_LOWER);
identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, QWEN3NEXT_CHUNK_SIZE), 1.0f));
diag_mask = ggml_add(ctx0, causal_mask, identity);
ggml_build_forward_expand(gf, causal_mask);
ggml_build_forward_expand(gf, identity);
ggml_build_forward_expand(gf, diag_mask);

ggml_tensor * cur = nullptr;

for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;

cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);

if (hparams.is_recurrent(il)) {
cur = delta.build_layer_attn_linear(ctx0, gf, cur, causal_mask, identity, diag_mask, il, cb);
} else {
cur = build_layer_attn(cur, inp_pos, KQ_mask, il);
}

if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}

cur = ggml_add(ctx0, cur, inpSA);
cb(cur, "attn_residual", il);

cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il, gf, true, false);

cur = lctx.cvec.apply_to(ctx0, cur, il);
cb(cur, "l_out", il);

inpL = cur;
}

cur = build_output(lctx, ctx0, inpL, model.output, model.output_norm, cb);
cb(cur, "result_output", -1);

ggml_build_forward_expand(gf, cur);

return gf;
}

ggml_cgraph * llm_build_context::build_qwen3vl() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model, n_tokens), false);

Expand Down Expand Up @@ -9840,6 +9970,10 @@ ggml_cgraph * llm_build_context::llama_build_graph(
{
result = llm.build_qwen35moe();
} break;
case LLM_ARCH_QWEN35:
{
result = llm.build_qwen35();
} break;
case LLM_ARCH_QWEN3VL:
{
result = llm.build_qwen3vl();
Expand Down
2 changes: 2 additions & 0 deletions src/llama-build-context.h
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,8 @@ struct llm_build_context {

ggml_cgraph * build_qwen35moe();

ggml_cgraph * build_qwen35();

ggml_cgraph * build_phi2();

ggml_cgraph * build_phi3();
Expand Down
27 changes: 27 additions & 0 deletions src/llama-hparams.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,33 @@ void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_QWEN35:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true);

// Load linear attention (gated delta net) parameters
ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);

// Mark recurrent layers (linear attention layers)
{
uint32_t full_attn_interval = 4;
ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false);
for (uint32_t i = 0; i < hparams.n_layer; ++i) {
hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0);
}
}

switch (hparams.n_layer) {
case 24: model.type = e_model::MODEL_2B; break;
case 64: model.type = e_model::MODEL_27B; break;
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_QWEN3VLMOE:
{
ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, false);
Expand Down
67 changes: 67 additions & 0 deletions src/llama-load-tensors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ struct create_tensors_helper : public create_tensors_helper_interface {

bool create_qwen35moe_tensors(const LLM_TN & tn);

bool create_qwen35_tensors(const LLM_TN & tn);

bool create_phi2_tensors(const LLM_TN & tn);

bool create_phi3_tensors(const LLM_TN & tn);
Expand Down Expand Up @@ -1465,6 +1467,69 @@ bool create_tensors_helper::create_qwen35moe_tensors(const LLM_TN & tn) {
return use_mmap_buffer;
}

bool create_tensors_helper::create_qwen35_tensors(const LLM_TN & tn) {
LOADING_PRELUDE
model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});

// output
{
model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
if (model.output == NULL) {
model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
}
}

const int64_t head_k_dim = hparams.ssm_d_state;
const int64_t head_v_dim = hparams.ssm_d_state;
const int64_t n_k_heads = hparams.ssm_n_group;
const int64_t n_v_heads = hparams.ssm_dt_rank;
const int64_t key_dim = head_k_dim * n_k_heads;
const int64_t value_dim = head_v_dim * n_v_heads;
const int64_t conv_dim = key_dim * 2 + value_dim;

for (int i = 0; i < n_layer; ++i) {
ggml_context * ctx_split = ctx_for_layer_split(i);

auto & layer = model.layers[i];

layer.attn_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
layer.attn_post_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0);
layer.ffn_norm = layer.attn_post_norm;

if (!hparams.is_recurrent(i)) {
// Attention layers
layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0);
layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0);
layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0);
layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);

// Q/K normalization for attention layers
layer.attn_q_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0);
layer.attn_k_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0);
} else {
// Linear attention (gated delta net) specific tensors
// Create tensors with calculated dimensions
layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.wqkv_gate = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ssm_conv1d = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0);
layer.ssm_dt = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0);
layer.ssm_a = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0);
layer.ssm_beta = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0);
layer.ssm_alpha = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0);
layer.ssm_norm = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0);
layer.ssm_out = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0);
}

layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0);
layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0);

}

return use_mmap_buffer;
}

bool create_tensors_helper::create_mimo2_tensors(const LLM_TN & tn) {
LOADING_PRELUDE

Expand Down Expand Up @@ -3402,6 +3467,8 @@ bool create_tensors_helper::create_tensors() {
use_mmap_buffer = create_qwen3next_tensors(tn); break;
case LLM_ARCH_QWEN35MOE:
use_mmap_buffer = create_qwen35moe_tensors(tn); break;
case LLM_ARCH_QWEN35:
use_mmap_buffer = create_qwen35_tensors(tn); break;
case LLM_ARCH_PHI2:
use_mmap_buffer = create_phi2_tensors(tn); break;
case LLM_ARCH_PHI3:
Expand Down
28 changes: 28 additions & 0 deletions src/llama-model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,34 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
},
},
{
LLM_ARCH_QWEN35,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_GATE, "blk.%d.attn_gate" },
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
{ LLM_TENSOR_SSM_A_NOSCAN, "blk.%d.ssm_a" },
{ LLM_TENSOR_SSM_BETA, "blk.%d.ssm_beta" },
{ LLM_TENSOR_SSM_ALPHA, "blk.%d.ssm_alpha" },
{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_QWEN3VL,
{
Expand Down
2 changes: 1 addition & 1 deletion src/llama-model.h
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ struct llama_model {
size_t max_nodes(int n_tokens) const {
auto n_tensors = tensors_by_name.size();
if (split_mode == LLAMA_SPLIT_MODE_GRAPH && !devices.empty()) n_tensors *= devices.size();
if (arch == LLM_ARCH_QWEN3NEXT || arch == LLM_ARCH_QWEN35MOE) {
if (arch == LLM_ARCH_QWEN3NEXT || arch == LLM_ARCH_QWEN35MOE || arch == LLM_ARCH_QWEN35) {
return std::max<size_t>(n_tokens * 40, 32u * n_tensors);
}
//return std::max<size_t>(1024, 8*n_tensors);
Expand Down
1 change: 1 addition & 0 deletions src/llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5358,6 +5358,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_QWEN3VL:
case LLM_ARCH_QWEN3VLMOE:
case LLM_ARCH_QWEN35MOE:
case LLM_ARCH_QWEN35:
return LLAMA_ROPE_TYPE_IMROPE;

// all model arches should be listed explicitly here
Expand Down