@@ -1283,8 +1283,8 @@ static bool llama_kv_cache_init(
 // find an empty slot of size "n_tokens" in the cache
 // updates the cache head
 static bool llama_kv_cache_find_slot(
-           struct llama_kv_cache & cache,
-        const struct llama_batch & batch) {
+           struct llama_kv_cache & cache,
+        const struct llama_batch & batch) {
     const uint32_t n_ctx    = cache.size;
     const uint32_t n_tokens = batch.n_tokens;
@@ -1352,10 +1352,13 @@ static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0,
 }
 
 static void llama_kv_cache_seq_rm(
-        struct llama_kv_cache & cache,
-                 llama_seq_id   seq_id,
-                    llama_pos   p0,
-                    llama_pos   p1) {
+        struct llama_kv_cache & cache,
+                 llama_seq_id   seq_id,
+                    llama_pos   p0,
+                    llama_pos   p1) {
+    if (p0 < 0) p0 = 0;
+    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
+
     for (uint32_t i = 0; i < cache.size; ++i) {
         if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
             cache.cells[i].seq_id.erase(seq_id);
@@ -1367,11 +1370,14 @@ static void llama_kv_cache_seq_rm(
 }
 
 static void llama_kv_cache_seq_cp(
-        struct llama_kv_cache & cache,
-                 llama_seq_id   seq_id_src,
-                 llama_seq_id   seq_id_dst,
-                    llama_pos   p0,
-                    llama_pos   p1) {
+        struct llama_kv_cache & cache,
+                 llama_seq_id   seq_id_src,
+                 llama_seq_id   seq_id_dst,
+                    llama_pos   p0,
+                    llama_pos   p1) {
+    if (p0 < 0) p0 = 0;
+    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
+
     for (uint32_t i = 0; i < cache.size; ++i) {
         if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
             cache.cells[i].seq_id.insert(seq_id_dst);
@@ -1389,11 +1395,14 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id
 }
 
 static void llama_kv_cache_seq_shift(
-        struct llama_kv_cache & cache,
-                 llama_seq_id   seq_id,
-                    llama_pos   p0,
-                    llama_pos   p1,
-                    llama_pos   delta) {
+        struct llama_kv_cache & cache,
+                 llama_seq_id   seq_id,
+                    llama_pos   p0,
+                    llama_pos   p1,
+                    llama_pos   delta) {
+    if (p0 < 0) p0 = 0;
+    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
+
     for (uint32_t i = 0; i < cache.size; ++i) {
         if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
             cache.cells[i].pos += delta;
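
With the clamping added above, a negative position now stands for an open bound: p0 < 0 reads as "from the start" and p1 < 0 as "to the end of the sequence". A minimal usage sketch, illustrative only and not taken from the commit itself:

    // drop every cell that belongs to seq_id, regardless of position
    llama_kv_cache_seq_rm(cache, seq_id, -1, -1);                    // p0 -> 0, p1 -> max()

    // give seq_id_dst a full copy of seq_id_src's history
    llama_kv_cache_seq_cp(cache, seq_id_src, seq_id_dst, -1, -1);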
@@ -7209,16 +7218,6 @@ struct llama_data_file_context : llama_data_context {
  *
 */
 static void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) {
-    // TODO: does not support multi-sequence states
-    {
-        const auto & kv_self = ctx->kv_self;
-        for (uint32_t i = 0; i < kv_self.head; ++i) {
-            GGML_ASSERT(kv_self.cells[i].pos == (int32_t) i);
-            GGML_ASSERT(kv_self.cells[i].seq_id.size() == 1);
-            GGML_ASSERT(kv_self.cells[i].has_seq_id(0));
-        }
-    }
-
     // copy rng
     {
         std::stringstream rng_ss;
@@ -7271,36 +7270,38 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
         const auto & hparams = ctx->model.hparams;
         const auto & cparams = ctx->cparams;
 
-        const int    n_layer = hparams.n_layer;
-        const int    n_embd  = hparams.n_embd_gqa();
-        const int    n_ctx   = cparams.n_ctx;
+        const auto   n_layer = hparams.n_layer;
+        const auto   n_embd  = hparams.n_embd_gqa();
+        const auto   n_ctx   = cparams.n_ctx;
 
-        const size_t kv_size = kv_self.buf.size;
-        const int    kv_ntok = kv_self.head;
+        const size_t   kv_buf_size = kv_self.buf.size;
+        const uint32_t kv_head     = kv_self.head;
+        const uint32_t kv_size     = kv_self.size;
 
-        data_ctx->write(&kv_size, sizeof(kv_size));
-        data_ctx->write(&kv_ntok, sizeof(kv_ntok));
+        data_ctx->write(&kv_buf_size, sizeof(kv_buf_size));
+        data_ctx->write(&kv_head,     sizeof(kv_head));
+        data_ctx->write(&kv_size,     sizeof(kv_size));
 
-        if (kv_size) {
+        if (kv_buf_size) {
             const size_t elt_size = ggml_element_size(kv_self.k);
 
             ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true });
             ggml_cgraph gf{};
 
-            ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
+            ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer);
             std::vector<uint8_t> kout3d_data(ggml_nbytes(kout3d), 0);
             kout3d->data = kout3d_data.data();
 
-            ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer);
+            ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_head, n_embd, n_layer);
             std::vector<uint8_t> vout3d_data(ggml_nbytes(vout3d), 0);
             vout3d->data = vout3d_data.data();
 
             ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
-                n_embd, kv_ntok, n_layer,
+                n_embd, kv_head, n_layer,
                 elt_size*n_embd, elt_size*n_embd*n_ctx, 0);
 
             ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v,
-                kv_ntok, n_embd, n_layer,
+                kv_head, n_embd, n_layer,
                 elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);
 
             ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d));
@@ -7314,6 +7315,20 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
             data_ctx->write(kout3d_data.data(), kout3d_data.size());
             data_ctx->write(vout3d_data.data(), vout3d_data.size());
         }
+
+        for (uint32_t i = 0; i < kv_size; ++i) {
+            const auto & cell = kv_self.cells[i];
+
+            const llama_pos pos         = cell.pos;
+            const size_t    seq_id_size = cell.seq_id.size();
+
+            data_ctx->write(&pos,         sizeof(pos));
+            data_ctx->write(&seq_id_size, sizeof(seq_id_size));
+
+            for (auto seq_id : cell.seq_id) {
+                data_ctx->write(&seq_id, sizeof(seq_id));
+            }
+        }
     }
 }
@@ -7385,34 +7400,36 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
         const int n_embd  = hparams.n_embd_gqa();
         const int n_ctx   = cparams.n_ctx;
 
-        size_t kv_size;
-        int    kv_ntok;
+        size_t   kv_buf_size;
+        uint32_t kv_head;
+        uint32_t kv_size;
 
-        memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size);
-        memcpy(&kv_ntok, inp, sizeof(kv_ntok)); inp += sizeof(kv_ntok);
+        memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size);
+        memcpy(&kv_head,     inp, sizeof(kv_head));     inp += sizeof(kv_head);
+        memcpy(&kv_size,     inp, sizeof(kv_size));     inp += sizeof(kv_size);
 
-        if (kv_size) {
-            GGML_ASSERT(kv_self.buf.size == kv_size);
+        if (kv_buf_size) {
+            GGML_ASSERT(kv_self.buf.size == kv_buf_size);
 
             const size_t elt_size = ggml_element_size(kv_self.k);
 
             ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true });
             ggml_cgraph gf{};
 
-            ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
+            ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer);
             kin3d->data = (void *) inp;
             inp += ggml_nbytes(kin3d);
 
-            ggml_tensor * vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer);
+            ggml_tensor * vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_head, n_embd, n_layer);
             vin3d->data = (void *) inp;
             inp += ggml_nbytes(vin3d);
 
             ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
-                n_embd, kv_ntok, n_layer,
+                n_embd, kv_head, n_layer,
                 elt_size*n_embd, elt_size*n_embd*n_ctx, 0);
 
             ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v,
-                kv_ntok, n_embd, n_layer,
+                kv_head, n_embd, n_layer,
                 elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);
 
             ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d));
@@ -7422,8 +7439,27 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
             ggml_free(cpy_ctx);
         }
 
-        ctx->kv_self.head = kv_ntok;
+        ctx->kv_self.head = kv_head;
         ctx->kv_self.size = kv_size;
+
+        ctx->kv_self.cells.resize(kv_size);
+
+        for (uint32_t i = 0; i < kv_size; ++i) {
+            llama_pos pos;
+            size_t    seq_id_size;
+
+            memcpy(&pos,         inp, sizeof(pos));         inp += sizeof(pos);
+            memcpy(&seq_id_size, inp, sizeof(seq_id_size)); inp += sizeof(seq_id_size);
+
+            ctx->kv_self.cells[i].pos = pos;
+
+            llama_seq_id seq_id;
+
+            for (size_t j = 0; j < seq_id_size; ++j) {
+                memcpy(&seq_id, inp, sizeof(seq_id)); inp += sizeof(seq_id);
+                ctx->kv_self.cells[i].seq_id.insert(seq_id);
+            }
+        }
     }
 
     const size_t nread = inp - src;
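
For orientation, the KV-cache part of the saved state written by llama_copy_state_data_internal now has the layout sketched below, reconstructed from the write calls in the hunks above (a sketch, not a formal spec; each scalar is written with its host sizeof, and the K/V blocks are present only when kv_buf_size is non-zero):

    // kv_buf_size  : size_t        kv_self.buf.size, checked against the loading context on restore
    // kv_head      : uint32_t      number of used cells (rows copied per layer)
    // kv_size      : uint32_t      total number of cells in the cache
    // K data       : n_embd x kv_head x n_layer elements of kv_self.k->type
    // V data       : kv_head x n_embd x n_layer elements of kv_self.v->type
    // then, repeated kv_size times (one record per cell):
    //   pos          : llama_pos
    //   seq_id_size  : size_t      number of sequence ids attached to the cell
    //   seq_id       : llama_seq_id, written seq_id_size times

llama_set_state_data reads the same fields back in the same order, so the two code paths stay in sync.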