@@ -1505,6 +1505,14 @@ void llama_model::load_hparams(llama_model_loader & ml) {
15051505 default: type = LLM_TYPE_UNKNOWN;
15061506 }
15071507 } break;
1508+ case LLM_ARCH_ERNIE4_5:
1509+ {
1510+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
1511+ switch (hparams.n_layer) {
1512+ case 18: type = LLM_TYPE_0_3B; break;
1513+ default: type = LLM_TYPE_UNKNOWN;
1514+ }
1515+ } break;
15081516 default: throw std::runtime_error("unsupported model architecture");
15091517 }
15101518
@@ -4345,6 +4353,40 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
43454353
43464354 layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
43474355
4356+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
4357+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
4358+ }
4359+ } break;
4360+ case LLM_ARCH_ERNIE4_5:
4361+ {
4362+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
4363+
4364+ // output
4365+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
4366+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
4367+ // if output is NULL, init from the input tok embed
4368+ if (output == NULL) {
4369+ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
4370+ }
4371+
4372+ for (int i = 0; i < n_layer; ++i) {
4373+ auto & layer = layers[i];
4374+
4375+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
4376+
4377+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
4378+ layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
4379+ layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
4380+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
4381+
4382+ // optional bias tensors
4383+ layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
4384+ layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
4385+ layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
4386+ layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
4387+
4388+ layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
4389+ layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
43484390 layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
43494391 layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
43504392 }
@@ -14126,6 +14168,136 @@ struct llm_build_dots1 : public llm_graph_context {
1412614168 }
1412714169};
1412814170
14171+ struct llm_build_ernie4_5 : public llm_graph_context {
14172+ llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
14173+ const int64_t n_embd_head = hparams.n_embd_head_v;
14174+
14175+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
14176+ GGML_ASSERT(n_embd_head == hparams.n_rot);
14177+
14178+ ggml_tensor * cur;
14179+ ggml_tensor * inpL;
14180+
14181+ inpL = build_inp_embd(model.tok_embd);
14182+
14183+ // inp_pos - contains the positions
14184+ ggml_tensor * inp_pos = build_inp_pos();
14185+
14186+ auto * inp_attn = build_attn_inp_kv_unified();
14187+
14188+ for (int il = 0; il < n_layer; ++il) {
14189+ ggml_tensor * inpSA = inpL;
14190+
14191+ // norm
14192+ {
14193+ cur = build_norm(inpL,
14194+ model.layers[il].attn_norm, NULL,
14195+ LLM_NORM_RMS, il);
14196+ cb(cur, "attn_norm", il);
14197+ }
14198+
14199+ // self-attention
14200+ {
14201+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
14202+ cb(Qcur, "Qcur", il);
14203+ if (model.layers[il].bq) {
14204+ Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
14205+ cb(Qcur, "Qcur", il);
14206+ }
14207+
14208+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
14209+ cb(Kcur, "Kcur", il);
14210+ if (model.layers[il].bk) {
14211+ Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
14212+ cb(Kcur, "Kcur", il);
14213+ }
14214+
14215+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
14216+ cb(Vcur, "Vcur", il);
14217+ if (model.layers[il].bv) {
14218+ Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
14219+ cb(Vcur, "Vcur", il);
14220+ }
14221+
14222+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
14223+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
14224+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
14225+
14226+ Qcur = ggml_rope_ext(
14227+ ctx0, Qcur, inp_pos, nullptr,
14228+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
14229+ ext_factor, attn_factor, beta_fast, beta_slow
14230+ );
14231+
14232+ Kcur = ggml_rope_ext(
14233+ ctx0, Kcur, inp_pos, nullptr,
14234+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
14235+ ext_factor, attn_factor, beta_fast, beta_slow
14236+ );
14237+
14238+ cb(Qcur, "Qcur", il);
14239+ cb(Kcur, "Kcur", il);
14240+ cb(Vcur, "Vcur", il);
14241+
14242+ cur = build_attn(inp_attn, gf,
14243+ model.layers[il].wo, NULL,
14244+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
14245+ }
14246+
14247+ if (il == n_layer - 1) {
14248+ // skip computing output for unused tokens
14249+ ggml_tensor * inp_out_ids = build_inp_out_ids();
14250+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
14251+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
14252+ }
14253+
14254+ ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
14255+ cb(ffn_inp, "ffn_inp", il);
14256+
14257+ // feed-forward network
14258+ {
14259+ cur = build_norm(ffn_inp,
14260+ model.layers[il].ffn_norm, NULL,
14261+ LLM_NORM_RMS, il);
14262+ cb(cur, "ffn_norm", il);
14263+
14264+ cur = build_ffn(cur,
14265+ model.layers[il].ffn_up, NULL, NULL,
14266+ model.layers[il].ffn_gate, NULL, NULL,
14267+ model.layers[il].ffn_down, NULL, NULL,
14268+ NULL,
14269+ LLM_FFN_SILU, LLM_FFN_PAR, il);
14270+ cb(cur, "ffn_out", il);
14271+ }
14272+
14273+ cur = ggml_add(ctx0, cur, ffn_inp);
14274+
14275+ cur = build_cvec(cur, il);
14276+ cb(cur, "l_out", il);
14277+
14278+ // input for next layer
14279+ inpL = cur;
14280+ }
14281+
14282+ cur = inpL;
14283+
14284+ cur = build_norm(cur,
14285+ model.output_norm, NULL,
14286+ LLM_NORM_RMS, -1);
14287+
14288+ cb(cur, "result_norm", -1);
14289+ res->t_embd = cur;
14290+
14291+ // lm_head
14292+ cur = build_lora_mm(model.output, cur);
14293+
14294+ cb(cur, "result_output", -1);
14295+ res->t_logits = cur;
14296+
14297+ ggml_build_forward_expand(gf, cur);
14298+ }
14299+ };
14300+
1412914301struct llm_build_arcee : public llm_graph_context {
1413014302 llm_build_arcee(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
1413114303 const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -14636,6 +14808,10 @@ llm_graph_result_ptr llama_model::build_graph(
1463614808 {
1463714809 llm = std::make_unique<llm_build_arcee>(*this, params, gf);
1463814810 } break;
14811+ case LLM_ARCH_ERNIE4_5:
14812+ {
14813+ llm = std::make_unique<llm_build_ernie4_5>(*this, params, gf);
14814+ } break;
1463914815 default:
1464014816 GGML_ABORT("fatal error");
1464114817 }
@@ -14787,6 +14963,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
1478714963 case LLM_ARCH_BAILINGMOE:
1478814964 case LLM_ARCH_NEO_BERT:
1478914965 case LLM_ARCH_ARCEE:
14966+ case LLM_ARCH_ERNIE4_5:
1479014967 return LLAMA_ROPE_TYPE_NORM;
1479114968
1479214969 // the pairs of head values are offset by n_rot/2
0 commit comments