@@ -281,19 +281,22 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
281281}
282282
283283void  llm_graph_input_attn_kv_unified::set_input (const  llama_ubatch * ubatch) {
284-     if  (self_kq_mask) {
285-         mctx->set_input_kq_mask (self_kq_mask, ubatch, cparams.causal_attn );
286-     }
284+     mctx->set_input_k_idxs (self_k_idxs, ubatch);
285+     mctx->set_input_v_idxs (self_v_idxs, ubatch);
286+ 
287+     mctx->set_input_kq_mask (self_kq_mask, ubatch, cparams.causal_attn );
287288}
288289
289290void  llm_graph_input_attn_kv_unified_iswa::set_input (const  llama_ubatch * ubatch) {
290-     if  (self_kq_mask) {
291-         mctx->get_base ()->set_input_kq_mask (self_kq_mask, ubatch, cparams.causal_attn );
292-     }
291+     mctx->get_base ()->set_input_k_idxs (self_k_idxs, ubatch);
292+     mctx->get_base ()->set_input_v_idxs (self_v_idxs, ubatch);
293293
294-     if  (self_kq_mask_swa) {
295-         mctx->get_swa ()->set_input_kq_mask (self_kq_mask_swa, ubatch, cparams.causal_attn );
296-     }
294+     mctx->get_base ()->set_input_kq_mask (self_kq_mask, ubatch, cparams.causal_attn );
295+ 
296+     mctx->get_swa ()->set_input_k_idxs (self_k_idxs_swa, ubatch);
297+     mctx->get_swa ()->set_input_v_idxs (self_v_idxs_swa, ubatch);
298+ 
299+     mctx->get_swa ()->set_input_kq_mask (self_kq_mask_swa, ubatch, cparams.causal_attn );
297300}
298301
299302void  llm_graph_input_attn_cross::set_input (const  llama_ubatch * ubatch) {
@@ -333,9 +336,10 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
333336}
334337
335338void  llm_graph_input_mem_hybrid::set_input (const  llama_ubatch * ubatch) {
336-     if  (self_kq_mask) {
337-         mctx->get_attn ()->set_input_kq_mask (self_kq_mask, ubatch, cparams.causal_attn );
338-     }
339+     mctx->get_attn ()->set_input_k_idxs (self_k_idxs, ubatch);
340+     mctx->get_attn ()->set_input_v_idxs (self_v_idxs, ubatch);
341+ 
342+     mctx->get_attn ()->set_input_kq_mask (self_kq_mask, ubatch, cparams.causal_attn );
339343
340344    const  int64_t  n_rs = mctx->get_recr ()->get_n_rs ();
341345
@@ -350,7 +354,8 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
350354    }
351355}
352356
353- void  llm_graph_input_one::set_input (const  llama_ubatch *) {
357+ void  llm_graph_input_one::set_input (const  llama_ubatch * ubatch) {
358+     GGML_UNUSED (ubatch);
354359    GGML_ASSERT (one && ggml_nelements (one) == 1 );
355360    float  f_one = 1 .0f ;
356361    ggml_backend_tensor_set (one, &f_one, 0 , sizeof (float ));
@@ -997,6 +1002,9 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
9971002
9981003        const  auto  n_kv = inp->mctx ->get_attn ()->get_n_kv ();
9991004
1005+         inp->self_k_idxs  = mctx_cur->get_attn ()->build_input_k_idxs (ctx0, ubatch);
1006+         inp->self_v_idxs  = mctx_cur->get_attn ()->build_input_v_idxs (ctx0, ubatch);
1007+ 
10001008        inp->self_kq_mask  = ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_kv, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD));
10011009        // cb(inp->self_kq_mask, "KQ_mask", -1);
10021010        ggml_set_input (inp->self_kq_mask );
@@ -1198,8 +1206,10 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
11981206
11991207        const  auto  n_kv = mctx_cur->get_n_kv ();
12001208
1209+         inp->self_k_idxs  = mctx_cur->build_input_k_idxs (ctx0, ubatch);
1210+         inp->self_v_idxs  = mctx_cur->build_input_v_idxs (ctx0, ubatch);
1211+ 
12011212        inp->self_kq_mask  = ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_kv, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD));
1202-         // cb(inp->self_kq_mask, "KQ_mask", -1);
12031213        ggml_set_input (inp->self_kq_mask );
12041214
12051215        inp->self_kq_mask_cnv  = cparams.flash_attn  ? ggml_cast (ctx0, inp->self_kq_mask , GGML_TYPE_F16) : inp->self_kq_mask ;
@@ -1230,8 +1240,11 @@ ggml_tensor * llm_graph_context::build_attn(
12301240
12311241    //  store to KV cache
12321242    {
1233-         ggml_build_forward_expand (gf, mctx_cur->cpy_k (ctx0, k_cur, il));
1234-         ggml_build_forward_expand (gf, mctx_cur->cpy_v (ctx0, v_cur, il));
1243+         const  auto  & k_idxs = inp->get_k_idxs ();
1244+         const  auto  & v_idxs = inp->get_v_idxs ();
1245+ 
1246+         ggml_build_forward_expand (gf, mctx_cur->cpy_k (ctx0, k_cur, k_idxs, il));
1247+         ggml_build_forward_expand (gf, mctx_cur->cpy_v (ctx0, v_cur, v_idxs, il));
12351248    }
12361249
12371250    const  auto  & kq_mask = inp->get_kq_mask ();
@@ -1290,11 +1303,15 @@ ggml_tensor * llm_graph_context::build_attn(
12901303
12911304    //  optionally store to KV cache
12921305    if  (k_cur) {
1293-         ggml_build_forward_expand (gf, mctx_cur->cpy_k (ctx0, k_cur, il));
1306+         const  auto  & k_idxs = is_swa ? inp->get_k_idxs_swa () : inp->get_k_idxs ();
1307+ 
1308+         ggml_build_forward_expand (gf, mctx_cur->cpy_k (ctx0, k_cur, k_idxs, il));
12941309    }
12951310
12961311    if  (v_cur) {
1297-         ggml_build_forward_expand (gf, mctx_cur->cpy_v (ctx0, v_cur, il));
1312+         const  auto  & v_idxs = is_swa ? inp->get_v_idxs_swa () : inp->get_v_idxs ();
1313+ 
1314+         ggml_build_forward_expand (gf, mctx_cur->cpy_v (ctx0, v_cur, v_idxs, il));
12981315    }
12991316
13001317    const  auto  & kq_mask = is_swa ? inp->get_kq_mask_swa () : inp->get_kq_mask ();
@@ -1398,8 +1415,11 @@ ggml_tensor * llm_graph_context::build_attn(
13981415
13991416    //  store to KV cache
14001417    {
1401-         ggml_build_forward_expand (gf, mctx_cur->cpy_k (ctx0, k_cur, il));
1402-         ggml_build_forward_expand (gf, mctx_cur->cpy_v (ctx0, v_cur, il));
1418+         const  auto  & k_idxs = inp->get_k_idxs ();
1419+         const  auto  & v_idxs = inp->get_v_idxs ();
1420+ 
1421+         ggml_build_forward_expand (gf, mctx_cur->cpy_k (ctx0, k_cur, k_idxs, il));
1422+         ggml_build_forward_expand (gf, mctx_cur->cpy_v (ctx0, v_cur, v_idxs, il));
14031423    }
14041424
14051425    const  auto  & kq_mask = inp->get_kq_mask ();
@@ -1434,8 +1454,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
14341454    {
14351455        const  auto  n_kv = mctx_cur->get_base ()->get_n_kv ();
14361456
1457+         inp->self_k_idxs  = mctx_cur->get_base ()->build_input_k_idxs (ctx0, ubatch);
1458+         inp->self_v_idxs  = mctx_cur->get_base ()->build_input_v_idxs (ctx0, ubatch);
1459+ 
14371460        inp->self_kq_mask  = ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_kv, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD));
1438-         // cb(inp->self_kq_mask, "KQ_mask", -1);
14391461        ggml_set_input (inp->self_kq_mask );
14401462
14411463        inp->self_kq_mask_cnv  = cparams.flash_attn  ? ggml_cast (ctx0, inp->self_kq_mask , GGML_TYPE_F16) : inp->self_kq_mask ;
@@ -1446,8 +1468,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
14461468
14471469        const  auto  n_kv = mctx_cur->get_swa ()->get_n_kv ();
14481470
1471+         inp->self_k_idxs_swa  = mctx_cur->get_swa ()->build_input_k_idxs (ctx0, ubatch);
1472+         inp->self_v_idxs_swa  = mctx_cur->get_swa ()->build_input_v_idxs (ctx0, ubatch);
1473+ 
14491474        inp->self_kq_mask_swa  = ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_kv, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD));
1450-         // cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
14511475        ggml_set_input (inp->self_kq_mask_swa );
14521476
14531477        inp->self_kq_mask_swa_cnv  = cparams.flash_attn  ? ggml_cast (ctx0, inp->self_kq_mask_swa , GGML_TYPE_F16) : inp->self_kq_mask_swa ;
0 commit comments