@@ -1283,6 +1283,8 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
     const int64_t n_tps     = n_tokens/n_stream;
     const int64_t n_tps_pad = GGML_PAD(n_tps, GGML_KQ_MASK_PAD);
 
+    std::fill(data, data + ggml_nelements(dst), -INFINITY);
+
     // Use only the previous KV cells of the correct sequence for each token of the ubatch.
     // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
     // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
@@ -1306,44 +1308,31 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
 
                 const llama_pos p1 = ubatch->pos[i];
 
-                for (uint32_t j = 0; j < n_kv; ++j) {
-                    float f = 0.0f;
-
-                    bool masked = false;
+                const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
 
+                for (uint32_t j = 0; j < n_kv; ++j) {
                     if (cells.is_empty(j)) {
-                        masked = true;
-                    } else {
-                        const llama_pos p0 = cells.pos_get(j);
-
-                        // mask the token if not the same sequence
-                        masked = masked || (!cells.seq_has(j, seq_id));
+                        continue;
+                    }
 
-                        // mask future tokens
-                        masked = masked || (causal_attn && p0 > p1);
+                    // mask the token if not the same sequence
+                    if (!cells.seq_has(j, seq_id)) {
+                        continue;
+                    }
 
-                        // apply SWA if any
-                        masked = masked || (is_masked_swa(p0, p1));
+                    const llama_pos p0 = cells.pos_get(j);
 
-                        if (!masked && hparams.use_alibi) {
-                            f = -std::abs(p0 - p1);
-                        }
+                    // mask future tokens
+                    if (causal_attn && p0 > p1) {
+                        continue;
                     }
 
-                    if (masked) {
-                        f = -INFINITY;
+                    // apply SWA if any
+                    if (is_masked_swa(p0, p1)) {
+                        continue;
                     }
 
-                    data[h*n_stream*n_tps_pad*n_kv + s*n_tps_pad*n_kv + ii*n_kv + j] = f;
-                }
-
-                // mask padded tokens
-                if (data) {
-                    for (uint32_t ii = n_tps; ii < n_tps_pad; ++ii) {
-                        for (uint32_t j = 0; j < n_kv; ++j) {
-                            data[h*n_stream*n_tps_pad*n_kv + s*n_tps_pad*n_kv + ii*n_kv + j] = -INFINITY;
-                        }
-                    }
+                    data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
                 }
             }
         }
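
For reference, here is a minimal standalone sketch of the fill-then-skip pattern this change adopts: the whole mask buffer is pre-filled with -INFINITY (which also covers any padded rows, so the separate padding loop is no longer needed), and the inner loop only writes the entries that pass every check, using a row base index computed once per token. The `Cell` struct, the toy cache contents, and the hard-coded sequence id below are illustrative stand-ins, not the llama.cpp API; the real code uses `cells`, `hparams.use_alibi`, and `is_masked_swa` inside `llama_kv_cache_unified::set_input_kq_mask`.

```cpp
// Standalone sketch of the fill-then-skip masking pattern (not the llama.cpp API).
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

struct Cell {
    bool    empty;
    int32_t pos;
    int32_t seq_id;
};

int main() {
    const uint32_t n_kv        = 8;    // KV cells visible to the batch
    const uint32_t n_tokens    = 3;    // tokens in the ubatch
    const bool     causal_attn = true;
    const bool     use_alibi   = false;

    // toy cache: cells 0..6 hold positions 0..6 of sequence 0
    // (the 3 batch tokens were already written at positions 4..6); cell 7 is unused
    std::vector<Cell> cells(n_kv, Cell{true, -1, -1});
    for (uint32_t j = 0; j < 7; ++j) {
        cells[j] = {false, (int32_t) j, 0};
    }

    // positions of the batch tokens, all in sequence 0
    std::vector<int32_t> pos = {4, 5, 6};

    // 1) pre-fill the whole mask with -INFINITY: everything is masked by default
    std::vector<float> mask(n_tokens*n_kv, -INFINITY);

    // 2) unmask only the cells that pass every check, skipping the rest
    for (uint32_t i = 0; i < n_tokens; ++i) {
        const uint64_t idst = (uint64_t) i*n_kv; // row base, computed once per token

        for (uint32_t j = 0; j < n_kv; ++j) {
            if (cells[j].empty)                       { continue; } // unused cell
            if (cells[j].seq_id != 0)                 { continue; } // different sequence
            if (causal_attn && cells[j].pos > pos[i]) { continue; } // future token
            // an is_masked_swa(p0, p1)-style window check would go here

            mask[idst + j] = use_alibi ? -std::abs((float)(cells[j].pos - pos[i])) : 0.0f;
        }
    }

    // print the resulting mask: 0 = attend, -inf = masked
    for (uint32_t i = 0; i < n_tokens; ++i) {
        for (uint32_t j = 0; j < n_kv; ++j) {
            std::printf("%5s ", mask[i*n_kv + j] == -INFINITY ? "-inf" : "0");
        }
        std::printf("\n");
    }
}
```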