Commit 64d18c4

kv-cache : opt mask set input
ggml-ci
1 parent ab82dc2 commit 64d18c4

File tree

1 file changed (+18, −29 lines)


src/llama-kv-cache-unified.cpp

Lines changed: 18 additions & 29 deletions
@@ -1278,6 +1278,8 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
     const int64_t n_tps     = n_tokens/n_stream;
     const int64_t n_tps_pad = GGML_PAD(n_tps, GGML_KQ_MASK_PAD);
 
+    std::fill(data, data + ggml_nelements(dst), -INFINITY);
+
     // Use only the previous KV cells of the correct sequence for each token of the ubatch.
     // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
     // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
@@ -1301,44 +1303,31 @@
 
                 const llama_pos p1 = ubatch->pos[i];
 
-                for (uint32_t j = 0; j < n_kv; ++j) {
-                    float f = 0.0f;
-
-                    bool masked = false;
+                const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
 
+                for (uint32_t j = 0; j < n_kv; ++j) {
                     if (cells.is_empty(j)) {
-                        masked = true;
-                    } else {
-                        const llama_pos p0 = cells.pos_get(j);
-
-                        // mask the token if not the same sequence
-                        masked = masked || (!cells.seq_has(j, seq_id));
+                        continue;
+                    }
 
-                        // mask future tokens
-                        masked = masked || (causal_attn && p0 > p1);
+                    // mask the token if not the same sequence
+                    if (!cells.seq_has(j, seq_id)) {
+                        continue;
+                    }
 
-                        // apply SWA if any
-                        masked = masked || (is_masked_swa(p0, p1));
+                    const llama_pos p0 = cells.pos_get(j);
 
-                        if (!masked && hparams.use_alibi) {
-                            f = -std::abs(p0 - p1);
-                        }
+                    // mask future tokens
+                    if (causal_attn && p0 > p1) {
+                        continue;
                     }
 
-                    if (masked) {
-                        f = -INFINITY;
+                    // apply SWA if any
+                    if (is_masked_swa(p0, p1)) {
+                        continue;
                     }
 
-                    data[h*n_stream*n_tps_pad*n_kv + s*n_tps_pad*n_kv + ii*n_kv + j] = f;
-                }
-
-                // mask padded tokens
-                if (data) {
-                    for (uint32_t ii = n_tps; ii < n_tps_pad; ++ii) {
-                        for (uint32_t j = 0; j < n_kv; ++j) {
-                            data[h*n_stream*n_tps_pad*n_kv + s*n_tps_pad*n_kv + ii*n_kv + j] = -INFINITY;
-                        }
-                    }
+                    data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
                 }
             }
         }
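
The restructuring above is the whole optimization: the old loop computed a `masked` flag for every (token, cell) pair and stored either `0.0f` or `-INFINITY` for each slot, plus a second pass to mask the padded rows. The new version pre-fills the entire mask with `-INFINITY` once via `std::fill` (which also covers empty cells and the padded rows `n_tps..n_tps_pad`), then writes only the cells that remain visible, bailing out of each masked case with `continue`. A minimal standalone sketch of that pattern, assuming a single mask row for one query token and a hypothetical `cell_pos` array where negative values mark empty cells (no sequence or SWA checks here):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Sketch only: fill one mask row for a query token at position p1.
void set_mask_row(float * data, size_t n_kv,
                  const std::vector<int32_t> & cell_pos, // < 0 == empty cell (assumption)
                  int32_t p1, bool causal_attn, bool use_alibi) {
    // default: everything masked; empty and padded slots need no further writes
    std::fill(data, data + n_kv, -INFINITY);

    for (size_t j = 0; j < n_kv; ++j) {
        const int32_t p0 = cell_pos[j];

        if (p0 < 0) {
            continue; // empty cell stays at -INFINITY
        }
        if (causal_attn && p0 > p1) {
            continue; // future token stays at -INFINITY
        }

        // visible cell: 0.0f, or the ALiBi linear bias -|p0 - p1|
        data[j] = use_alibi ? -std::abs(float(p0 - p1)) : 0.0f;
    }
}

The win is that masked entries are written once by the bulk `std::fill`, which the compiler can vectorize, instead of one scalar store per slot, and the visible path no longer threads the `masked` flag through every check.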
