@@ -68,11 +68,10 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 
     cells.resize(kv_size);
 
-    if (supports_set_rows) {
-        // TODO: this requirement can be relaxed, but it would be much easier to implement when we have an actual
-        //       model that needs this
-        // ref: https://github.com/ggml-org/llama.cpp/pull/14517
-        GGML_ASSERT(hparams.is_n_embd_v_gqa_homogeneous());
+    // [TAG_V_CACHE_VARIABLE]
+    if (v_trans && hparams.is_n_embd_v_gqa_variable()) {
+        LLAMA_LOG_WARN("%s: the V embeddings have different sizes across layers and FA is not enabled - padding V cache to %d\n",
+                __func__, hparams.n_embd_v_gqa_max());
     }
 
     for (uint32_t il = 0; il < n_layer_cache; il++) {
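Note: the two hparams helpers used in the new check are defined elsewhere in the PR and are not visible in this hunk. A standalone sketch of what they presumably compute from the per-layer V widths (hypothetical names and container, not the actual llama.cpp implementation):

    #include <algorithm>
    #include <cstdint>
    #include <functional>
    #include <vector>

    // true when at least two cache layers use a different V embedding width
    static bool is_n_embd_v_gqa_variable(const std::vector<uint32_t> & widths) {
        return std::adjacent_find(widths.begin(), widths.end(),
                                  std::not_equal_to<uint32_t>()) != widths.end();
    }

    // width the transposed V cache is padded to (widths must be non-empty)
    static uint32_t n_embd_v_gqa_max(const std::vector<uint32_t> & widths) {
        return *std::max_element(widths.begin(), widths.end());
    }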
@@ -81,8 +80,9 @@ llama_kv_cache_unified::llama_kv_cache_unified(
             continue;
         }
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+        // [TAG_V_CACHE_VARIABLE]
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+        const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max();
 
         const char * dev_name = "CPU";
 
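For the allocation above, a minimal sketch of the per-layer row width rule, assuming a hypothetical per-layer width table instead of the real hparams accessors: with flash attention (v_trans == false) every layer keeps its exact width, otherwise all layers share the padded maximum.

    #include <cstdint>
    #include <vector>

    // width (in elements) of one cache cell's V row for layer il
    static uint32_t v_row_width(const std::vector<uint32_t> & widths, // per-layer n_embd_v_gqa
                                uint32_t il, bool v_trans, uint32_t width_max) {
        return !v_trans ? widths[il] : width_max;
    }

For example, with per-layer widths {1024, 512} and v_trans enabled, both layers allocate 1024 columns per cell; the extra 512 columns of the second layer hold padding.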
@@ -808,19 +808,19 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
         // note: v->nb[1] <= v->nb[2]
         return ggml_view_4d(ctx, v,
                 hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, 1,
-                ggml_row_size(v->type, hparams.n_embd_head_v),            // v->nb[1]
-                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)),         // v->nb[2]
-                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)*kv_size), // v->nb[3]
-                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)*kv_size)*0);
+                ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
+                ggml_row_size(v->type, v->ne[0]),              // v->nb[2]
+                ggml_row_size(v->type, v->ne[0]*kv_size),      // v->nb[3]
+                ggml_row_size(v->type, v->ne[0]*kv_size)*0);
     }
 
     // note: v->nb[1] > v->nb[2]
     return ggml_view_4d(ctx, v,
             n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, 1,
-            ggml_row_size(v->type, kv_size*hparams.n_embd_head_v),    // v->nb[1]
-            ggml_row_size(v->type, kv_size),                          // v->nb[2]
-            ggml_row_size(v->type, kv_size*hparams.n_embd_v_gqa(il)), // v->nb[3]
-            ggml_row_size(v->type, kv_size*hparams.n_embd_v_gqa(il))*0);
+            ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1]
+            ggml_row_size(v->type, kv_size),                       // v->nb[2]
+            ggml_row_size(v->type, kv_size*v->ne[0]),              // v->nb[3]
+            ggml_row_size(v->type, kv_size*v->ne[0])*0);
 }
 
 ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
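The strides above now come from v->ne[0] (the possibly padded row width) rather than the per-layer hparams value. A plain sketch of the transposed layout those strides describe, with the element size standing in for ggml_row_size (assumes a contiguous, non-quantized V cache):

    #include <cstddef>
    #include <cstdint>

    struct v_cache_dims {
        int64_t kv_size;      // number of cache cells
        int64_t width_padded; // v->ne[0] after this change
        size_t  elt_size;     // bytes per element
    };

    // byte offset of V channel j of the token stored in cell c (single stream);
    // consecutive channels are kv_size elements apart, which is what
    // nb[2] == ggml_row_size(v->type, kv_size) expresses in the view above
    static size_t v_trans_offset(const v_cache_dims & d, int64_t j, int64_t c) {
        return (size_t)(j*d.kv_size + c)*d.elt_size;
    }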
@@ -856,8 +856,8 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 
     auto * v = layers[ikv].v;
 
-    const int64_t n_embd_v_gqa = v->ne[0];
-    const int64_t n_tokens     = v_cur->ne[2];
+    const int64_t n_embd_v_gqa = v_cur->ne[0]*v_cur->ne[1];
+    const int64_t n_tokens     = v_cur->ne[2];
 
     v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
 
@@ -870,6 +870,11 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
         return ggml_set_rows(ctx, v, v_cur, v_idxs);
     }
 
+    // [TAG_V_CACHE_VARIABLE]
+    if (n_embd_v_gqa < v->ne[0]) {
+        v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0);
+    }
+
     // the row becomes a single element
     ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);
 
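A CPU-side sketch of what the ggml_pad call does to the data before ggml_set_rows scatters it: each token's V row is extended with zeros from the layer's true width up to the shared cache width (assumes float values; v_cur has already been reshaped to [n_embd_v_gqa, n_tokens] at this point).

    #include <cstdint>
    #include <vector>

    // grow each token's V row with zeros up to the padded cache width
    static std::vector<float> pad_v_rows(const std::vector<float> & v_cur, // n_tokens rows of `width`
                                         int64_t width,                    // true per-layer n_embd_v_gqa
                                         int64_t width_padded,             // v->ne[0]
                                         int64_t n_tokens) {
        std::vector<float> out(n_tokens*width_padded, 0.0f);
        for (int64_t t = 0; t < n_tokens; ++t) {
            for (int64_t j = 0; j < width; ++j) {
                out[t*width_padded + j] = v_cur[t*width + j];
            }
        }
        return out;
    }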
@@ -916,7 +921,7 @@ ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, con
     if (!v_trans) {
         v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
     } else {
-        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa());
+        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa_max());
    }
 
     ggml_set_input(v_idxs);
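In other words, with the transposed cache every (token, padded V channel) pair gets its own destination row index. A trivial sketch of the count, mirroring the branch above:

    #include <cstdint>

    // number of destination-row indices required for ggml_set_rows
    static int64_t n_v_idxs(bool v_trans, int64_t n_tokens, int64_t width_padded) {
        return !v_trans ? n_tokens : n_tokens*width_padded;
    }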
@@ -957,7 +962,7 @@ void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_uba
     // note: the V cache is transposed when not using flash attention
     const int64_t kv_size = get_size();
 
-    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa_max();
 
     for (uint32_t i = 0; i < n_tokens; ++i) {
         for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
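The loop body is not shown in this hunk, so the exact index formula is not visible here; assuming the layout implied by the view strides above (channel j of the token in cell c lives at flat row j*kv_size + c), the indices produced for one token would look roughly like this sketch:

    #include <cstdint>
    #include <vector>

    // destination rows for one token placed in cache cell `cell`
    static std::vector<int64_t> v_idxs_for_token(int64_t cell, int64_t kv_size,
                                                 int64_t width_padded /* n_embd_v_gqa_max */) {
        std::vector<int64_t> idxs(width_padded);
        for (int64_t j = 0; j < width_padded; ++j) {
            idxs[j] = j*kv_size + cell; // one row per (padded) V channel
        }
        return idxs;
    }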