@@ -281,19 +281,22 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
-    if (self_kq_mask) {
-        mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
-    }
+    mctx->set_input_k_idxs(self_k_idxs, ubatch);
+    mctx->set_input_v_idxs(self_v_idxs, ubatch);
+
+    mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
 }
 
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
-    if (self_kq_mask) {
-        mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
-    }
+    mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
+    mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
 
-    if (self_kq_mask_swa) {
-        mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
-    }
+    mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+
+    mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
+    mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
+
+    mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
 }
 
 void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
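Note: the masks are now set unconditionally, and each set_input() first uploads per-token destination indices for the K and V caches. For orientation, a minimal sketch of what such a setter could look like, assuming the memory context already knows which cache cell each ubatch token maps to (the cells argument is a hypothetical stand-in for that state, not the actual llama.cpp member):

    // hypothetical sketch, not the actual implementation: store the
    // destination cache cell of every ubatch token into the I64 input
    // tensor that the cpy_k()/cpy_v() graph ops will consume
    static void set_input_k_idxs_sketch(ggml_tensor * dst, const llama_ubatch * ubatch,
                                        const std::vector<uint32_t> & cells) {
        GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));

        int64_t * data = (int64_t *) dst->data;
        for (uint32_t i = 0; i < ubatch->n_tokens; ++i) {
            data[i] = cells[i]; // row in the cache that token i is written to
        }
    }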
@@ -333,9 +336,10 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
-    if (self_kq_mask) {
-        mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
-    }
+    mctx->get_attn()->set_input_k_idxs(self_k_idxs, ubatch);
+    mctx->get_attn()->set_input_v_idxs(self_v_idxs, ubatch);
+
+    mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
 
     const int64_t n_rs = mctx->get_recr()->get_n_rs();
 
@@ -350,7 +354,8 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-void llm_graph_input_one::set_input(const llama_ubatch *) {
+void llm_graph_input_one::set_input(const llama_ubatch * ubatch) {
+    GGML_UNUSED(ubatch);
     GGML_ASSERT(one && ggml_nelements(one) == 1);
     float f_one = 1.0f;
     ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));
@@ -997,6 +1002,9 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
 
     const auto n_kv = inp->mctx->get_attn()->get_n_kv();
 
+    inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch);
+    inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch);
+
     inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
     //cb(inp->self_kq_mask, "KQ_mask", -1);
     ggml_set_input(inp->self_kq_mask);
@@ -1198,8 +1206,10 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
 
     const auto n_kv = mctx_cur->get_n_kv();
 
+    inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
+    inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
+
     inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-    //cb(inp->self_kq_mask, "KQ_mask", -1);
     ggml_set_input(inp->self_kq_mask);
 
     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
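build_input_k_idxs()/build_input_v_idxs() presumably just allocate the I64 index tensors and flag them as graph inputs, mirroring how self_kq_mask is created next to them. A sketch under that assumption (one index per ubatch token; the real helpers may size or shape the tensors differently):

    // sketch only: allocate an I64 graph input with one cache-cell index
    // per token; it is filled on the host later by set_input_k_idxs()
    static ggml_tensor * build_input_k_idxs_sketch(ggml_context * ctx, const llama_ubatch & ubatch) {
        ggml_tensor * idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, ubatch.n_tokens);
        ggml_set_input(idxs);
        return idxs;
    }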
@@ -1230,8 +1240,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+        const auto & k_idxs = inp->get_k_idxs();
+        const auto & v_idxs = inp->get_v_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
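With the index tensors threaded through, the new cpy_k()/cpy_v() overloads can scatter rows to arbitrary cache cells instead of copying into a single contiguous view. Assuming they are built on ggml_set_rows() (which writes each row of a source tensor into a destination at the positions given by an I64 index tensor), the K path could look roughly like the following; k_layer and the 2-D reshapes are illustrative, not the actual cpy_k() body:

    // rough sketch of an index-based K store on top of ggml_set_rows()
    static ggml_tensor * cpy_k_sketch(ggml_context * ctx, ggml_tensor * k_layer,
                                      ggml_tensor * k_cur, ggml_tensor * k_idxs) {
        const int64_t n_tokens = k_idxs->ne[0];

        // cache viewed as one row per cell: [n_embd_k, n_cells]
        ggml_tensor * k_all = ggml_reshape_2d(ctx, k_layer,
                k_layer->ne[0], ggml_nelements(k_layer)/k_layer->ne[0]);

        // current K as one row per token: [n_embd_k, n_tokens]
        k_cur = ggml_reshape_2d(ctx, k_cur, ggml_nelements(k_cur)/n_tokens, n_tokens);

        // scatter: row i of k_cur lands in row k_idxs[i] of the cache
        return ggml_set_rows(ctx, k_all, k_cur, k_idxs);
    }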
@@ -1290,11 +1303,15 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // optionally store to KV cache
     if (k_cur) {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+        const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
     }
 
     if (v_cur) {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+        const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
@@ -1398,8 +1415,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+        const auto & k_idxs = inp->get_k_idxs();
+        const auto & v_idxs = inp->get_v_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
@@ -1434,8 +1454,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa()
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
+        inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
+        inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
+
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask, "KQ_mask", -1);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1446,8 +1468,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa()
 
         const auto n_kv = mctx_cur->get_swa()->get_n_kv();
 
+        inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
+        inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
+
         inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
         ggml_set_input(inp->self_kq_mask_swa);
 
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;