@@ -68,6 +68,12 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 
     cells.resize(kv_size);
 
+    if (supports_set_rows) {
+        // TODO: this requirement can be relaxed, but it would be much easier to implement when we have an actual
+        //       model that needs this
+        GGML_ASSERT(hparams.is_n_embd_k_gqa_homogeneous() && hparams.is_n_embd_v_gqa_homogeneous());
+    }
+
     for (uint32_t il = 0; il < n_layer_cache; il++) {
         if (filter && !filter(il)) {
             LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
@@ -98,8 +104,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         ggml_tensor * k;
         ggml_tensor * v;
 
-        k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size);
-        v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size);
+        k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, 1);
+        v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, 1);
 
         ggml_format_name(k, "cache_k_l%d", il);
         ggml_format_name(v, "cache_v_l%d", il);
@@ -780,33 +786,40 @@ ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint
 
     auto * k = layers[ikv].k;
 
-    return ggml_view_3d(ctx, k,
-            hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv,
+    const uint64_t kv_size = get_size();
+
+    return ggml_view_4d(ctx, k,
+            hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, 1,
             ggml_row_size(k->type, hparams.n_embd_head_k),
             ggml_row_size(k->type, hparams.n_embd_k_gqa(il)),
-            0);
+            ggml_row_size(k->type, hparams.n_embd_k_gqa(il)*kv_size),
+            ggml_row_size(k->type, hparams.n_embd_k_gqa(il)*kv_size)*0);
 }
 
 ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * v = layers[ikv].v;
 
+    const uint64_t kv_size = get_size();
+
     if (!v_trans) {
         // note: v->nb[1] <= v->nb[2]
-        return ggml_view_3d(ctx, v,
-                hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv,
-                ggml_row_size(v->type, hparams.n_embd_head_v),            // v->nb[1]
-                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)),         // v->nb[2]
-                0);
+        return ggml_view_4d(ctx, v,
+                hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, 1,
+                ggml_row_size(v->type, hparams.n_embd_head_v),            // v->nb[1]
+                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)),         // v->nb[2]
+                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)*kv_size), // v->nb[3]
+                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)*kv_size)*0);
     }
 
     // note: v->nb[1] > v->nb[2]
-    return ggml_view_3d(ctx, v,
-            n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v,
-            ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v),   // v->nb[1]
-            ggml_row_size(v->type, v->ne[1]),                         // v->nb[2]
-            0);
+    return ggml_view_4d(ctx, v,
+            n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, 1,
+            ggml_row_size(v->type, kv_size*hparams.n_embd_head_v),    // v->nb[1]
+            ggml_row_size(v->type, kv_size),                          // v->nb[2]
+            ggml_row_size(v->type, kv_size*hparams.n_embd_v_gqa(il)), // v->nb[3]
+            ggml_row_size(v->type, kv_size*hparams.n_embd_v_gqa(il))*0);
 }
 
 ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
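 
Note (not part of the diff): a minimal standalone sketch of the stride arithmetic behind the new 4D views, with toy sizes and a hypothetical main(); it assumes ggml is available to build against and only mirrors the view construction in get_k() above, printing the resulting shape and strides. The offset written as nb[3]*0 in the patch is simply byte offset 0, i.e. the first (and currently only) slice along the new fourth dimension.

    // sketch only: 4D view over a [n_embd_k_gqa, kv_size, 1] cache tensor, as in get_k(); metadata only (no_alloc)
    #include <cstdio>
    #include <cstdint>
    #include "ggml.h"

    int main() {
        ggml_init_params params = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ nullptr, /*.no_alloc =*/ true };
        ggml_context * ctx = ggml_init(params);

        const int64_t n_embd_head_k = 64, n_head_kv = 4, kv_size = 256, n_kv = 32; // toy values
        const int64_t n_embd_k_gqa  = n_embd_head_k*n_head_kv;

        // per-layer K cache, now allocated with a trailing dimension of 1
        ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, n_embd_k_gqa, kv_size, 1);

        // same view construction as get_k(): nb[3] spans one full "stream" of kv_size cells
        ggml_tensor * kv = ggml_view_4d(ctx, k,
                n_embd_head_k, n_head_kv, n_kv, 1,
                ggml_row_size(k->type, n_embd_head_k),
                ggml_row_size(k->type, n_embd_k_gqa),
                ggml_row_size(k->type, n_embd_k_gqa*kv_size),
                ggml_row_size(k->type, n_embd_k_gqa*kv_size)*0);

        printf("ne = [%lld, %lld, %lld, %lld]\n",
                (long long) kv->ne[0], (long long) kv->ne[1], (long long) kv->ne[2], (long long) kv->ne[3]);
        printf("nb = [%zu, %zu, %zu, %zu]\n", kv->nb[0], kv->nb[1], kv->nb[2], kv->nb[3]);

        ggml_free(ctx);
        return 0;
    }
 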
@@ -820,6 +833,10 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
     k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
 
     if (k_idxs && supports_set_rows) {
+        if (k->ne[2] > 1) {
+            k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
+        }
+
         return ggml_set_rows(ctx, k, k_cur, k_idxs);
     }
 
@@ -845,24 +862,18 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 
     if (v_idxs && supports_set_rows) {
         if (!v_trans) {
+            if (v->ne[2] > 1) {
+                v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
+            }
+
             return ggml_set_rows(ctx, v, v_cur, v_idxs);
         }
 
         // the row becomes a single element
-        ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
-
-        // note: the V cache is transposed when not using flash attention
-        v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
+        ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);
 
-        // note: we can be more explicit here at the cost of extra cont
-        // however, above we take advantage that a row of single element is always continuous regardless of the row stride
-        // v_cur = ggml_transpose(ctx, v_cur);
-        // v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
+        v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]);
 
-        // we broadcast the KV indices n_embd_v_gqa times
-        // v      [1, n_kv,      n_embd_v_gqa]
-        // v_cur  [1, n_tokens,  n_embd_v_gqa]
-        // v_idxs [n_tokens, 1, 1]
         return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
     }
 
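 
Note (not part of the diff): in the transposed branch of cpy_v() both the cache and the incoming values collapse into rows of a single element, so ggml_set_rows() can scatter every scalar individually. A toy sketch of the resulting row counts, with made-up sizes and assuming ggml is available to build against:

    // sketch only: the single-element-row reshapes from the transposed branch of cpy_v(); metadata only (no_alloc)
    #include <cstdio>
    #include <cstdint>
    #include "ggml.h"

    int main() {
        ggml_init_params params = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ nullptr, /*.no_alloc =*/ true };
        ggml_context * ctx = ggml_init(params);

        const int64_t n_embd_v_gqa = 128, kv_size = 256, n_tokens = 4; // toy values

        // per-layer V cache as allocated in the constructor (filled in transposed order when v_trans)
        ggml_tensor * v     = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, n_embd_v_gqa, kv_size, 1);
        // incoming values for the current ubatch: [n_embd_v_gqa, n_tokens]
        ggml_tensor * v_cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd_v_gqa, n_tokens);

        // as in cpy_v(): every scalar becomes its own row, so a flat row index can address any cell
        ggml_tensor * v_view = ggml_reshape_2d(ctx, v,     1, v->ne[0]*v->ne[1]*v->ne[2]);
        v_cur                = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]);

        // a set_rows(v_view, v_cur, v_idxs) then needs one index per source row, i.e.
        // n_tokens*n_embd_v_gqa of them -- which is what build_input_v_idxs() allocates when v_trans
        printf("v_view rows: %lld, v_cur rows: %lld\n",
                (long long) v_view->ne[1], (long long) v_cur->ne[1]);

        ggml_free(ctx);
        return 0;
    }
 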
@@ -899,7 +910,13 @@ ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, con
 ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
     const uint32_t n_tokens = ubatch.n_tokens;
 
-    ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+    ggml_tensor * v_idxs;
+
+    if (!v_trans) {
+        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+    } else {
+        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa());
+    }
 
     ggml_set_input(v_idxs);
 
@@ -916,7 +933,7 @@ void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_uba
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     int64_t * data = (int64_t *) dst->data;
 
-    for (int64_t i = 0; i < n_tokens; ++i) {
+    for (uint32_t i = 0; i < n_tokens; ++i) {
         data[i] = sinfo.idxs.at(i);
     }
 }
@@ -931,8 +948,21 @@ void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_uba
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     int64_t * data = (int64_t *) dst->data;
 
-    for (int64_t i = 0; i < n_tokens; ++i) {
-        data[i] = sinfo.idxs.at(i);
+    if (!v_trans) {
+        for (uint32_t i = 0; i < n_tokens; ++i) {
+            data[i] = sinfo.idxs.at(i);
+        }
+    } else {
+        // note: the V cache is transposed when not using flash attention
+        const int64_t kv_size = get_size();
+
+        const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+
+        for (uint32_t i = 0; i < n_tokens; ++i) {
+            for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                data[i*n_embd_v_gqa + j] = j*kv_size + sinfo.idxs.at(i);
+            }
+        }
     }
 }
 
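 
Note (not part of the diff): a small self-contained check of the index formula in set_input_v_idxs() for the transposed case, in plain C++ with toy sizes and made-up slot values. Channel j of the token placed in slot s ends up at flat position j*kv_size + s, i.e. slot-contiguous per channel, which is exactly the layout the transposed get_v() view above addresses.

    // sketch only: emulates the transposed-V scatter on the CPU; no ggml dependency
    #include <cstdio>
    #include <cstdint>
    #include <vector>

    int main() {
        const int64_t kv_size = 8, n_embd_v_gqa = 3, n_tokens = 2; // toy values

        // transposed V cache: channel j of slot s lives at flat index j*kv_size + s
        std::vector<float> v(kv_size*n_embd_v_gqa, 0.0f);

        // incoming values: v_cur[t*n_embd_v_gqa + j] is channel j of token t (channel-fastest order)
        std::vector<float> v_cur = { 1, 2, 3,   4, 5, 6 };

        // hypothetical destination slots for the two tokens (sinfo.idxs in the real code)
        const int64_t slots[n_tokens] = { 5, 6 };

        // row indices exactly as built by set_input_v_idxs() for the transposed case
        std::vector<int64_t> v_idxs(n_tokens*n_embd_v_gqa);
        for (int64_t i = 0; i < n_tokens; ++i) {
            for (int64_t j = 0; j < n_embd_v_gqa; ++j) {
                v_idxs[i*n_embd_v_gqa + j] = j*kv_size + slots[i];
            }
        }

        // the "set_rows": each single-element source row goes to the indexed single-element destination row
        for (size_t r = 0; r < v_idxs.size(); ++r) {
            v[v_idxs[r]] = v_cur[r];
        }

        // print one line per channel: each token's value lands in that channel's row at its slot
        for (int64_t j = 0; j < n_embd_v_gqa; ++j) {
            for (int64_t s = 0; s < kv_size; ++s) {
                printf("%4.1f ", v[j*kv_size + s]);
            }
            printf("\n");
        }
        return 0;
    }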