212 changes: 131 additions & 81 deletions src/llama-build-context.cpp
@@ -692,9 +692,6 @@ static inline ggml_tensor * do_split_norm(ggml_context * ctx, ggml_tensor * cur,
if (the_norm && the_norm->extra) {
auto norm = (ggml_split_tensor_t *)the_norm->extra;
GGML_ASSERT(norm->splits[id]);
//if (cur->type != GGML_TYPE_F16 && cur->type != GGML_TYPE_F32) {
// cur = ggml_cast(ctx, cur, GGML_TYPE_F32);
//}
if (is_norm) {
cur = ggml_fused_norm(ctx, cur, norm->splits[id], hparams.f_norm_eps);
} else {
@@ -749,6 +746,9 @@ ggml_tensor * llm_build_context::llm_build_ffn(
if (!split_u) continue;
auto cur = get_input_tensor_sm_graph(ctx, input, id);
cur = do_split_norm(ctx, cur, ffn_norm, lctx.model.hparams, cb, id, il_cb, is_norm);
if (input->op != GGML_OP_REDUCE) {
cur->op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t) - 1] = 0xff;
}
cur = ggml_fused_up_gate(ctx, split_u, split_g, cur, unary_op);
cb(cur, "ffn_up_gate", il_cb);
if (lctx.model.arch == LLM_ARCH_STEP35) {
@@ -1305,13 +1305,29 @@ llm_expert_gating_func_type gating_op,
(!split_up_shexp->splits[id] && !split_gate_shexp->splits[id] && !split_down_shexp->splits[id]));
if (!split_up_shexp->splits[id]) continue;
auto the_ffn_norm = ffn_norm ? ffn_norm->extra ? ((ggml_split_tensor_t *)ffn_norm->extra)->splits[id] : ffn_norm : nullptr;
auto shared_out = llm_build_ffn(ctx, lctx, the_ffn_norm, input,
auto this_input = input;
if (the_ffn_norm) {
this_input = llm_build_norm(ctx, input, lctx.model.hparams, the_ffn_norm, nullptr, LLM_NORM_RMS, cb, il);
}
auto shared_out = llm_build_ffn(ctx, lctx, nullptr, this_input,
split_up_shexp->splits[id], split_up_b_shexp ? split_up_b_shexp->splits[id] : nullptr, nullptr,
split_gate_shexp->splits[id], split_gate_b_shexp ? split_gate_b_shexp->splits[id] : nullptr, nullptr,
split_down_shexp->splits[id], !down_bias_added && split_down_b_shexp ? split_down_b_shexp->splits[id] : nullptr, nullptr,
nullptr, type_op_shexp, LLM_FFN_PAR, cb, il, graph, false, false,
id == id_add_routed ? routed_out : nullptr);
cb(shared_out, "ffn_shexp_out", il_cb);
if (shexp_gate) {
auto split_shexp_gate = (ggml_split_tensor_t *)shexp_gate->extra;
GGML_ASSERT(split_shexp_gate && split_shexp_gate->splits[id]);
auto gate = llm_build_lora_mm(lctx, ctx, split_shexp_gate->splits[id], this_input);
if (gate->ne[1] == 1) {
shared_out = ggml_fused_mul_unary(ctx, gate, shared_out, GGML_UNARY_OP_SIGMOID);
} else {
gate = ggml_sigmoid(ctx, gate);
shared_out = ggml_mul(ctx, shared_out, gate);
}
cb(shared_out, "ffn_shexp_gated", il_cb);
}
if (shared_out->ne[1] > 32 && lctx.cparams.reduce_type != GGML_TYPE_F32) {
shared_out = ggml_cast(ctx, shared_out, lctx.cparams.reduce_type);
}
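Note on the shared-expert gating introduced here (the same pattern recurs in the next hunk): when `shexp_gate` is present, the shared-expert FFN output is scaled elementwise by `sigmoid(x · W_gate)`. The single-row case (`gate->ne[1] == 1`, single-token decode) takes the fused `ggml_fused_mul_unary` path, while batched rows use explicit `ggml_sigmoid` + `ggml_mul`. A plain-C++ sketch of what both paths compute (names are mine, for illustration only):

```cpp
#include <cmath>
#include <cstddef>

// shared_out[i] *= sigmoid(gate[i]) -- the effect of both the fused and
// the unfused gating paths above.
static void apply_sigmoid_gate(float * shared_out, const float * gate, size_t n) {
    for (size_t i = 0; i < n; ++i) {
        shared_out[i] *= 1.0f / (1.0f + std::exp(-gate[i]));
    }
}
```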
@@ -1374,6 +1390,9 @@ llm_expert_gating_func_type gating_op,
int il_cb = 1000*(id + 1) + il;
auto cur = get_input_tensor_sm_graph(ctx, input, id);
cur = do_split_norm(ctx, cur, ffn_norm, lctx.model.hparams, cb, id, il_cb, false);
if (cur->op != GGML_OP_REDUCE) {
cur->op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t) - 1] = 0xff;
}
GGML_ASSERT(!split_gate_inp_b || split_gate_inp_b->splits[id]);
GGML_ASSERT(!split_exps_down_b || split_exps_down_b->splits[id]);
GGML_ASSERT(!split_exps_gate_b || split_exps_gate_b->splits[id]);
@@ -1399,6 +1418,18 @@ llm_expert_gating_func_type gating_op,
split_down_shexp->splits[id], !down_bias_added && split_down_b_shexp ? split_down_b_shexp->splits[id] : nullptr, nullptr,
nullptr, type_op_shexp, LLM_FFN_PAR, cb, il);
cb(shared_out, "ffn_shexp_out", il_cb);
if (shexp_gate) {
auto split_shexp_gate = (ggml_split_tensor_t *)shexp_gate->extra;
GGML_ASSERT(split_shexp_gate && split_shexp_gate->splits[id]);
auto gate = llm_build_lora_mm(lctx, ctx, split_shexp_gate->splits[id], cur);
if (gate->ne[1] == 1) {
shared_out = ggml_fused_mul_unary(ctx, gate, shared_out, GGML_UNARY_OP_SIGMOID);
} else {
gate = ggml_sigmoid(ctx, gate);
shared_out = ggml_mul(ctx, shared_out, gate);
}
cb(shared_out, "ffn_shexp_gated", il_cb);
}

cur = ggml_add(ctx, routed_out, shared_out);
cb(cur, "ffn_out", il_cb);
@@ -1748,6 +1779,38 @@ std::tuple<ggml_tensor*, ggml_tensor*, ggml_tensor*> llm_build_context::llm_buil
return {Qcur, Kcur, Vcur};
}

std::tuple<ggml_tensor*, ggml_tensor*, ggml_tensor*, ggml_tensor*> llm_build_context::llm_build_mul_mat_qkv_gated(ggml_cgraph * gf, ggml_tensor * cur,
ggml_tensor * wq, ggml_tensor * wk, ggml_tensor * wv, ggml_tensor * q_norm, ggml_tensor * k_norm, int il) const {
auto Qaux = llm_build_lora_mm(lctx, ctx0, wq, cur);
cb(Qaux, "Qaux", il);
auto Kcur = llm_build_lora_mm(lctx, ctx0, wk, cur);
cb(Kcur, "Kcur", il);
auto Vcur = llm_build_lora_mm(lctx, ctx0, wv, cur);
cb(Vcur, "Vcur", il);
ggml_build_forward_expand(gf, Qaux);
ggml_build_forward_expand(gf, Kcur);
ggml_build_forward_expand(gf, Vcur);
auto row_size = ggml_row_size(Qaux->type, n_embd_head_k);
// TODO: check why CUDA performance suffers so much if we don't make these two tensors contiguous
auto Qcur = ggml_cont(ctx0, ggml_view_3d(ctx0, Qaux, n_embd_head_k, Qaux->ne[0]/(2*n_embd_head_k), n_tokens, 2*row_size, Qaux->nb[1], 0));
auto gate = ggml_cont_2d(ctx0, ggml_view_3d(ctx0, Qaux, n_embd_head_k, Qaux->ne[0]/(2*n_embd_head_k), n_tokens, 2*row_size, Qaux->nb[1], row_size), Qaux->ne[0]/2, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, Kcur->ne[0]/n_embd_head_k, n_tokens);
if (q_norm) {
Qcur = llm_build_norm(ctx0, Qcur, hparams, q_norm, NULL, LLM_NORM_RMS, cb, il);
cb(Qcur, "Qcur_normed", il);
ggml_build_forward_expand(gf, Qcur);
}
if (k_norm) {
Kcur = llm_build_norm(ctx0, Kcur, hparams, k_norm, NULL, LLM_NORM_RMS, cb, il);
cb(Kcur, "Kcur_normed", il);
ggml_build_forward_expand(gf, Kcur);
}
gate = ggml_sigmoid(ctx0, gate);
//gate = ggml_reshape_2d(ctx0, gate, gate->ne[0]*gate->ne[1], gate->ne[2]);
cb(gate, "gate", il);
return {Qcur, Kcur, Vcur, gate};
}
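The two views above rely on the fused Q projection storing, for each head, `n_embd_head_k` query values immediately followed by `n_embd_head_k` gate values — hence the `2*row_size` row stride and the `row_size` byte offset for the gate view. A standalone sketch of that de-interleaving on a plain buffer (illustrative only; names are mine):

```cpp
#include <cstddef>
#include <vector>

// One token's fused projection row of size 2*d*n_head, laid out per head as
// [q_0 .. q_{d-1} | g_0 .. g_{d-1}]; split into q and gate, each d*n_head.
static void split_q_and_gate(const std::vector<float> & qaux, size_t d, size_t n_head,
                             std::vector<float> & q, std::vector<float> & gate) {
    q.resize(d * n_head);
    gate.resize(d * n_head);
    for (size_t h = 0; h < n_head; ++h) {
        const float * src = qaux.data() + h * 2 * d; // per-head stride == 2*row_size
        for (size_t i = 0; i < d; ++i) {
            q   [h * d + i] = src[i];     // byte offset 0        -> Qcur view
            gate[h * d + i] = src[d + i]; // byte offset row_size -> gate view
        }
    }
}
```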

std::tuple<ggml_tensor*, ggml_tensor*, ggml_tensor*> llm_build_context::llm_build_mul_mat_qkv(ggml_cgraph * gf, ggml_tensor * cur,
ggml_tensor * wqkv, ggml_tensor * bqkv,
ggml_tensor * wqk, ggml_tensor * bqk,
@@ -4329,59 +4392,6 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);

auto build_layer_attn = [&](ggml_tensor * cur, ggml_tensor * inp_pos, ggml_tensor * KQ_mask, int il) -> ggml_tensor * {

auto Qaux = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
auto Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
auto Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
cb(Qaux, "Qaux", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
ggml_build_forward_expand(gf, Qaux);
ggml_build_forward_expand(gf, Kcur);
ggml_build_forward_expand(gf, Vcur);

Qaux = ggml_reshape_3d(ctx0, Qaux, n_embd_head * 2, n_head, n_tokens);
auto Qcur = ggml_cont(ctx0, ggml_view_3d(ctx0, Qaux, n_embd_head, n_head, n_tokens, Qaux->nb[1], Qaux->nb[2], 0));
auto gate = ggml_cont_2d(ctx0, ggml_view_3d(ctx0, Qaux, n_embd_head, n_head, n_tokens, Qaux->nb[1], Qaux->nb[2], n_embd_head*ggml_element_size(Qaux)), n_embd_head*n_head, n_tokens);
cb(Qcur, "Qcur", il);
cb(gate, "gate", il);

Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, cb, il);
cb(Qcur, "Qcur_normed", il);

Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, cb, il);
cb(Kcur, "Kcur_normed", il);

Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);

Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);

cb(Qcur, "Qcur_roped", il);
cb(Kcur, "Kcur_roped", il);

ggml_tensor * attn = llm_build_kv(ctx0, lctx, kv_self, gf, nullptr, nullptr,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv,
hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale, cb, il);
cb(attn, "attn_pregate", il);

gate = ggml_sigmoid(ctx0, gate);
cb(gate, "gate_sigmoid", il);
attn = ggml_mul(ctx0, attn, gate);
cb(attn, "attn_gated", il);

attn = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, attn);
cb(attn, "attn_output", il);

return attn;

};

ggml_tensor * inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
ggml_tensor * inp_pos = build_inp_pos();
ggml_tensor * inp_out_ids = n_tokens > 1 ? build_inp_out_ids() : nullptr;
@@ -4403,6 +4413,8 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
ggml_build_forward_expand(gf, identity);
ggml_build_forward_expand(gf, diag_mask);

float KQ_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;

ggml_tensor * cur = nullptr;

for (int il = 0; il < n_layer; ++il) {
@@ -4422,23 +4434,29 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
GGML_ASSERT(model.layers[il].ffn_down_exps != nullptr);
}

cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);

if (hparams.is_recurrent(il)) {
if (inpL->op == GGML_OP_REDUCE && inpL->src[model.default_layer_device[il]]) {
inpL->view_src = inpL->src[model.default_layer_device[il]];
//printf("Using reduce result on device %d\n", model.default_layer_device[il]);
//inpL = inpL->src[model.default_layer_device[il]];
}
auto norm = model.layers[il].attn_norm->extra ? ((ggml_split_tensor_t *)model.layers[il].attn_norm->extra)->splits[model.default_layer_device[il]] : model.layers[il].attn_norm;
cur = llm_build_norm(ctx0, inpL, hparams, norm, nullptr, LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
cur = delta.build_layer_attn_linear(ctx0, gf, cur, causal_mask, identity, diag_mask, il, cb);
if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
cur = ggml_add(ctx0, cur, inpSA);
cb(cur, "attn_residual", il);
} else {
cur = build_layer_attn(cur, inp_pos, KQ_mask, il);
}

if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
//cur = build_layer_attn(cur, inp_pos, KQ_mask, il);
cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, il == n_layer - 1 ? inp_out_ids : nullptr, nullptr,
KQ_mask, nullptr, nullptr, KQ_scale, 0.0f, 0, il, true, false, true, false, false);
}

cur = ggml_add(ctx0, cur, inpSA);
cb(cur, "attn_residual", il);

if (!model.layers[il].ffn_gate_inp) {
// dense FFN
cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, cur,
@@ -9927,11 +9945,19 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
((ggml_split_tensor_t *)model.layers[il].attn_q_norm->extra)->splits[id] : model.layers[il].attn_q_norm : nullptr;
auto the_k_norm = model.layers[il].attn_k_norm ? model.layers[il].attn_k_norm->extra ?
((ggml_split_tensor_t *)model.layers[il].attn_k_norm->extra)->splits[id] : model.layers[il].attn_k_norm : nullptr;
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, nullptr, nullptr, nullptr, nullptr,
split_wq, bq ? bq->splits[id] : nullptr,
split_wk, bk ? bk->splits[id] : nullptr,
split_wv, bv ? bv->splits[id] : nullptr,
the_q_norm, the_k_norm, f_attn_scale, il, add_graph_split);
ggml_tensor *Qcur, *Kcur, *Vcur, *gate = nullptr;
if (model.arch == LLM_ARCH_QWEN3NEXT) {
auto [Q, K, V, G] = llm_build_mul_mat_qkv_gated(gf, cur, split_wq, split_wk, split_wv,
the_q_norm, the_k_norm, il);
Qcur = Q; Kcur = K; Vcur = V; gate = G;
} else {
auto [Q, K, V] = llm_build_mul_mat_qkv(gf, cur, nullptr, nullptr, nullptr, nullptr,
split_wq, bq ? bq->splits[id] : nullptr,
split_wk, bk ? bk->splits[id] : nullptr,
split_wv, bv ? bv->splits[id] : nullptr,
the_q_norm, the_k_norm, f_attn_scale, il, add_graph_split);
Qcur = Q; Kcur = K; Vcur = V;
}
auto rope_factors = rope_factors_in;
if (rope_factors) {
GGML_ASSERT(rope_factors->extra);
@@ -10058,6 +10084,10 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens

cur = ggml_reshape_2d(ctx0, cur, split_wo->ne[0], n_tokens);
cb(cur, "flash_attn_reshaped", il_cb);
if (gate) {
cur = ggml_mul(ctx0, cur, gate);
cb(cur, "qkv_gated", il_cb);
}

if (inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
@@ -10106,11 +10136,19 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
}
auto input_normed = cur;

auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
model.layers[il].wqkv, model.layers[il].bqkv,
model.layers[il].wqk, model.layers[il].bqk,
model.layers[il].wq, model.layers[il].bq, model.layers[il].wk, model.layers[il].bk, model.layers[il].wv, model.layers[il].bv,
model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, f_attn_scale, il);
ggml_tensor *Qcur, *Kcur, *Vcur, *gate = nullptr;
if (model.arch == LLM_ARCH_QWEN3NEXT) {
auto [Q, K, V, G] = llm_build_mul_mat_qkv_gated(gf, cur, model.layers[il].wq, model.layers[il].wk, model.layers[il].wv,
model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, il);
Qcur = Q; Kcur = K; Vcur = V; gate = G;
} else {
auto [Q, K, V] = llm_build_mul_mat_qkv(gf, cur,
model.layers[il].wqkv, model.layers[il].bqkv,
model.layers[il].wqk, model.layers[il].bqk,
model.layers[il].wq, model.layers[il].bq, model.layers[il].wk, model.layers[il].bk, model.layers[il].wv, model.layers[il].bv,
model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, f_attn_scale, il);
Qcur = Q; Kcur = K; Vcur = V;
}

if (do_rope) {
if (is_multi) {
Expand Down Expand Up @@ -10157,9 +10195,21 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
}
cb(cur, "attn_out", il);
} else {
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, KQ_scale, cb, il, sinks, n_swa);
if (gate) {
cur = llm_build_kv(ctx0, lctx, kv_self, gf, nullptr, nullptr,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, KQ_scale, cb, il, sinks, n_swa);
cur = ggml_mul(ctx0, cur, gate);
cb(cur, "qkv_gated", il);
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
if (model.layers[il].bo) {
cur = ggml_add(ctx0, cur, model.layers[il].bo);
}
cb(cur, "attn_out", il);
} else {
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, KQ_scale, cb, il, sinks, n_swa);
}
}

if (inp_out_ids) {
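Taken together, the QWEN3NEXT changes in this file implement output-gated attention: `llm_build_mul_mat_qkv_gated` splits the fused Q projection into queries and a per-head gate, and the gate is applied to the attention output just before the output projection. In math (notation mine, not from the source):

```latex
\mathrm{gate} = \sigma(x W_g), \qquad
o = W_o \left( \mathrm{gate} \odot \operatorname{Attn}(Q, K, V) \right)
```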
3 changes: 3 additions & 0 deletions src/llama-build-context.h
@@ -160,6 +160,9 @@ struct llm_build_context {
ggml_tensor * wv, ggml_tensor * bv,
ggml_tensor * q_norm, ggml_tensor * k_norm, float attention_scale, int il, bool add_graph_split = false) const;

std::tuple<ggml_tensor*, ggml_tensor*, ggml_tensor*, ggml_tensor*> llm_build_mul_mat_qkv_gated(ggml_cgraph * gf, ggml_tensor * cur,
ggml_tensor * wq, ggml_tensor * wk, ggml_tensor * wv, ggml_tensor * q_norm, ggml_tensor * k_norm, int il) const;

ggml_cgraph * build_llama();

ggml_cgraph * build_mistral3();
4 changes: 4 additions & 0 deletions src/llama-delta-net.cpp
@@ -504,6 +504,8 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_
}
cb(beta, "beta", il);
cb(alpha, "alpha", il);
ggml_build_forward_expand(gf, beta);
ggml_build_forward_expand(gf, alpha);

ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);
ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased);
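For reference, softplus here is the standard `softplus(x) = log(1 + e^x)`, applied to `alpha + ssm_dt` to form the delta-net decay rate. A numerically stable scalar sketch (illustrative only, assuming the usual definition):

```cpp
#include <cmath>

// softplus(x) = log(1 + exp(x)); for large x, log1p(exp(x)) equals x to
// float precision, so clamp to avoid overflow in exp.
static float softplus(float x) {
    return x > 20.0f ? x : std::log1p(std::exp(x));
}
```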
@@ -529,6 +531,7 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_
}
if (reset_state_local) {
state_f32 = ggml_scale(ctx0, state_f32, 0.0f);
cb(state_f32, "state_reset", il);
}

ggml_tensor * conv_state_flat = ggml_view_2d(ctx0, state_f32, conv_state_dim, 1, state_f32->nb[1], 0);
@@ -539,6 +542,7 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_
ggml_tensor * state = ggml_reshape_4d(ctx0, ssm_state_flat, head_v_dim, head_v_dim, num_v_heads, 1);
cb(conv_states, "conv_states", il);
cb(state, "state_predelta", il);
ggml_build_forward_expand(gf, state);

ggml_tensor * conv_output_raw = ggml_ssm_conv(ctx0, conv_states, qkv_mixed, model.layers[il].ssm_conv1d, inp_s_seq_qnext);
cb(conv_output_raw, "conv_output_raw", il);