@@ -2455,6 +2455,7 @@ struct llama_layer {
24552455 // long rope factors
24562456 struct ggml_tensor * rope_long = nullptr;
24572457 struct ggml_tensor * rope_short = nullptr;
2458+ struct ggml_tensor * rope_freqs = nullptr;
24582459
24592460 // bitnet scale
24602461 struct ggml_tensor * wq_scale;
@@ -6055,6 +6056,8 @@ static bool llm_load_tensors(
60556056
60566057 layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
60576058
6059+ layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), { n_embd/n_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
6060+
60586061 if (n_expert == 0) {
60596062 layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
60606063 layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
@@ -8532,6 +8535,10 @@ struct llm_build_context {
85328535 // choose long/short freq factors based on the context size
85338536 const auto n_ctx_pre_seq = cparams.n_ctx / cparams.n_seq_max;
85348537
8538+ if (model.layers[il].rope_freqs != nullptr) {
8539+ return model.layers[il].rope_freqs;
8540+ }
8541+
85358542 if (n_ctx_pre_seq > hparams.n_ctx_orig_yarn) {
85368543 return model.layers[il].rope_long;
85378544 }
@@ -8726,6 +8733,9 @@ struct llm_build_context {
87268733
87278734 // self-attention
87288735 {
8736+ // rope freq factors for llama3; may return nullptr for llama2 and other models
8737+ struct ggml_tensor * rope_factors = build_rope_factors(il);
8738+
87298739 // compute Q and K and RoPE them
87308740 struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
87318741 cb(Qcur, "Qcur", il);
@@ -8749,14 +8759,14 @@ struct llm_build_context {
87498759 }
87508760
87518761 Qcur = ggml_rope_ext(
8752- ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr ,
8762+ ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors ,
87538763 n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
87548764 ext_factor, attn_factor, beta_fast, beta_slow
87558765 );
87568766 cb(Qcur, "Qcur", il);
87578767
87588768 Kcur = ggml_rope_ext(
8759- ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr ,
8769+ ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors ,
87608770 n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
87618771 ext_factor, attn_factor, beta_fast, beta_slow
87628772 );
0 commit comments