@@ -1278,6 +1278,8 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
     const int64_t n_tps     = n_tokens/n_stream;
     const int64_t n_tps_pad = GGML_PAD(n_tps, GGML_KQ_MASK_PAD);
 
+    std::fill(data, data + ggml_nelements(dst), -INFINITY);
+
     // Use only the previous KV cells of the correct sequence for each token of the ubatch.
     // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
     // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
@@ -1301,44 +1303,31 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
 
                 const llama_pos p1 = ubatch->pos[i];
 
-                for (uint32_t j = 0; j < n_kv; ++j) {
-                    float f = 0.0f;
-
-                    bool masked = false;
+                const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
 
+                for (uint32_t j = 0; j < n_kv; ++j) {
                     if (cells.is_empty(j)) {
-                        masked = true;
-                    } else {
-                        const llama_pos p0 = cells.pos_get(j);
-
-                        // mask the token if not the same sequence
-                        masked = masked || (!cells.seq_has(j, seq_id));
+                        continue;
+                    }
 
-                        // mask future tokens
-                        masked = masked || (causal_attn && p0 > p1);
+                    // mask the token if not the same sequence
+                    if (!cells.seq_has(j, seq_id)) {
+                        continue;
+                    }
 
-                        // apply SWA if any
-                        masked = masked || (is_masked_swa(p0, p1));
+                    const llama_pos p0 = cells.pos_get(j);
 
-                        if (!masked && hparams.use_alibi) {
-                            f = -std::abs(p0 - p1);
-                        }
+                    // mask future tokens
+                    if (causal_attn && p0 > p1) {
+                        continue;
                     }
 
-                    if (masked) {
-                        f = -INFINITY;
+                    // apply SWA if any
+                    if (is_masked_swa(p0, p1)) {
+                        continue;
                     }
 
-                    data[h*n_stream*n_tps_pad*n_kv + s*n_tps_pad*n_kv + ii*n_kv + j] = f;
-                }
-
-                // mask padded tokens
-                if (data) {
-                    for (uint32_t ii = n_tps; ii < n_tps_pad; ++ii) {
-                        for (uint32_t j = 0; j < n_kv; ++j) {
-                            data[h*n_stream*n_tps_pad*n_kv + s*n_tps_pad*n_kv + ii*n_kv + j] = -INFINITY;
-                        }
-                    }
+                    data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
                 }
             }
         }
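
To illustrate the new pattern in isolation: the mask buffer is pre-filled with -INFINITY (which also covers the padded rows, so the old padded-token loop is no longer needed), and the inner loop writes only the entries that pass every filter, using a precomputed row offset. The sketch below reproduces that control flow on toy data; the cell/token arrays and the main() driver are invented for the example and are not part of llama.cpp.

// Standalone sketch of the mask-building pattern above, on toy data.
// The cell/token arrays here are hypothetical; only the control flow
// (pre-fill with -INFINITY, then un-mask surviving entries) mirrors the diff.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <vector>

int main() {
    const uint32_t n_kv     = 6; // KV cells (toy value)
    const uint32_t n_tokens = 3; // tokens in the current ubatch (toy value)

    // toy cache state: position and sequence id per cell, -1 == empty cell
    const std::vector<int32_t> cell_pos = { 0, 1, -1, 2, 3, -1 };
    const std::vector<int32_t> cell_seq = { 0, 0, -1, 0, 1, -1 };

    // toy ubatch: position and sequence id per token
    const std::vector<int32_t> tok_pos = { 2, 3, 4 };
    const std::vector<int32_t> tok_seq = { 0, 0, 0 };

    const bool causal_attn = true;
    const bool use_alibi   = false;

    // step 1: pre-fill the whole mask with -INFINITY (the std::fill in the diff);
    // padded rows stay masked without a dedicated loop
    std::vector<float> mask(n_tokens*n_kv, -INFINITY);

    // step 2: un-mask only the cells that survive all filters
    for (uint32_t i = 0; i < n_tokens; ++i) {
        const int32_t p1     = tok_pos[i];
        const int32_t seq_id = tok_seq[i];

        const uint64_t idst = (uint64_t) i*n_kv; // row offset, like idst in the diff

        for (uint32_t j = 0; j < n_kv; ++j) {
            if (cell_pos[j] < 0) {
                continue; // empty cell
            }
            if (cell_seq[j] != seq_id) {
                continue; // different sequence
            }
            const int32_t p0 = cell_pos[j];
            if (causal_attn && p0 > p1) {
                continue; // future token
            }
            // (the real code also applies the SWA check here)

            mask[idst + j] = use_alibi ? -(float) std::abs(p0 - p1) : 0.0f;
        }
    }

    // print the mask: 0 = attend, -inf = masked
    for (uint32_t i = 0; i < n_tokens; ++i) {
        for (uint32_t j = 0; j < n_kv; ++j) {
            printf("%6s", std::isinf(mask[i*n_kv + j]) ? "-inf" : "0");
        }
        printf("\n");
    }

    return 0;
}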