@@ -582,43 +582,33 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
             continue;
         }

-        // keep track of what the minimum sequence positions would be if we accept the ubatch
-        llama_seq_id seq_pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
-        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
-            seq_pos_min[s] = cells.seq_pos_min(s);
-        }
-
         bool found = true;
         for (uint32_t i = 0; i < n_tokens; i++) {
-            const llama_pos    pos    = ubatch.pos[i];
-            const llama_seq_id seq_id = ubatch.seq_id[i][0];
+            //const llama_pos    pos    = ubatch.pos[i];
+            //const llama_seq_id seq_id = ubatch.seq_id[i][0];

             // can we use this cell? either:
             // - the cell is empty
             // - the cell is occupied only by one sequence:
-            //   - mask causally, if the sequence is the same as the one we are inserting
+            //   - (disabled) mask causally, if the sequence is the same as the one we are inserting
             //   - mask SWA, using current max pos for that sequence in the cache
             //     always insert in the cell with minimum pos
             bool can_use = cells.is_empty(head_cur + i);

             if (!can_use && cells.seq_count(head_cur + i) == 1) {
                 const llama_pos pos_cell = cells.pos_get(head_cur + i);

-                // causal mask
-                if (cells.seq_has(head_cur + i, seq_id)) {
-                    can_use = pos_cell >= pos;
-                }
+                // (disabled) causal mask
+                // note: it's better to purge any "future" tokens beforehand
+                //if (cells.seq_has(head_cur + i, seq_id)) {
+                //    can_use = pos_cell >= pos;
+                //}

                 if (!can_use) {
                     const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);

                     // SWA mask
-                    // note: we insert only in the cell with minimum pos in order to preserve the invariant that
-                    //       all positions between [pos_min, pos_max] for each sequence will be present in the cache
-                    // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
-                    if (pos_cell == seq_pos_min[seq_id_cell] &&
-                        is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
-                        seq_pos_min[seq_id_cell]++;
+                    if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
                         can_use = true;
                     }
                 }
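Note on the hunk above: a cell that still holds a token of exactly one sequence may be reused once that token falls outside the sliding window, because no future position of that sequence will ever attend to it again. Below is a minimal standalone sketch of why is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1) is a safe reuse test. It assumes a standard sliding window of width n_swa; the real is_masked_swa() in llama.cpp also handles other window types, and the names swa_reuse_check/can_reuse_cell are illustrative, not part of the codebase.

// Minimal sketch of the cell-reuse test, NOT the llama.cpp implementation.
// Assumption: standard sliding window of width n_swa.
#include <cstdint>

using llama_pos = int32_t;

struct swa_reuse_check {
    uint32_t n_swa; // sliding window size; 0 means no SWA, so nothing is ever masked

    // would a cached token at position p0 be masked when attending from position p1?
    bool is_masked_swa(llama_pos p0, llama_pos p1) const {
        return n_swa > 0 && p1 - p0 >= (llama_pos) n_swa;
    }

    // mirror of the condition in find_slot(): the next token of the cell's sequence will
    // be at seq_pos_max + 1, so a cell already outside that window can be overwritten
    bool can_reuse_cell(llama_pos pos_cell, llama_pos seq_pos_max) const {
        return is_masked_swa(pos_cell, seq_pos_max + 1);
    }
};

For example, with n_swa = 4 and seq_pos_max = 10, a cell holding position 6 satisfies 11 - 6 >= 4 and can be reused, while a cell holding position 8 cannot.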
@@ -646,8 +636,22 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
 }

 void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
+    // keep track of the max sequence position that we would overwrite with this ubatch
+    // for non-SWA cache, this would be always empty
+    llama_seq_id seq_pos_max_rm[LLAMA_MAX_PARALLEL_SEQUENCES];
+    for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        seq_pos_max_rm[s] = -1;
+    }
+
     for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
         if (!cells.is_empty(head_cur + i)) {
+            assert(cells.seq_count(head_cur + i) == 1);
+
+            const llama_seq_id seq_id = cells.seq_get(head_cur + i);
+            const llama_pos    pos    = cells.pos_get(head_cur + i);
+
+            seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
+
             cells.rm(head_cur + i);
         }

@@ -658,6 +662,22 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
         }
     }

+    // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
+    //       will be present in the cache. so we have to purge any position which is less than those we would overwrite
+    // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
+    for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        if (seq_pos_max_rm[s] == -1) {
+            continue;
+        }
+
+        if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
+            LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
+                    __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
+
+            seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
+        }
+    }
+
     // move the head at the end of the slot
     head = head_cur + ubatch.n_tokens;
 }
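The purge step added above preserves the invariant that the cached positions of each sequence form a contiguous range [pos_min, pos_max]: if the ubatch overwrote a cell that held position P of sequence s, every remaining position of s at or below P must also be removed, otherwise a hole would appear inside the range. A rough standalone sketch of that bookkeeping follows; the helper names cache_view and purge_range are hypothetical stand-ins for cells.seq_pos_min() and seq_rm(), and MAX_SEQ = 4 is an arbitrary placeholder for LLAMA_MAX_PARALLEL_SEQUENCES.

// Standalone sketch of the invariant-preserving purge, with hypothetical helpers.
// Only the control flow mirrors the diff above; this is not the llama.cpp code.
#include <cstdint>
#include <cstdio>

constexpr int MAX_SEQ = 4; // arbitrary stand-in for LLAMA_MAX_PARALLEL_SEQUENCES

struct cache_view {
    int32_t seq_pos_min[MAX_SEQ]; // smallest position still cached per sequence
};

// purge positions [p0, p1) of sequence s (stand-in for seq_rm)
static void purge_range(int s, int32_t p0, int32_t p1) {
    std::printf("purging positions [%d, %d) of sequence %d\n", p0, p1, s);
}

static void enforce_contiguous_range(const cache_view & cv, const int32_t (&seq_pos_max_rm)[MAX_SEQ]) {
    for (int s = 0; s < MAX_SEQ; ++s) {
        if (seq_pos_max_rm[s] == -1) {
            continue; // no cell of sequence s was overwritten by the ubatch
        }
        if (cv.seq_pos_min[s] <= seq_pos_max_rm[s]) {
            // e.g. positions 10..20 cached and position 15 overwritten: 10..15 must
            // also go, otherwise a hole would appear inside [pos_min, pos_max]
            purge_range(s, cv.seq_pos_min[s], seq_pos_max_rm[s] + 1);
        }
    }
}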