diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp
index a1b45e4a3cc..c80b67a0857 100644
--- a/src/llama-memory-hybrid.cpp
+++ b/src/llama-memory-hybrid.cpp
@@ -18,6 +18,7 @@ llama_memory_hybrid::llama_memory_hybrid(
         uint32_t n_pad,
         uint32_t n_swa,
         llama_swa_type swa_type,
+        bool is_iswa,
     /* recurrent */
         ggml_type type_r,
         ggml_type type_s,
@@ -29,24 +30,43 @@ llama_memory_hybrid::llama_memory_hybrid(
     /* layer filters */
         const layer_filter_cb & filter_attn,
         const layer_filter_cb & filter_recr) :
+    is_iswa(is_iswa),
     hparams(model.hparams),
-    mem_attn(new llama_kv_cache(
-        model,
-        type_k,
-        type_v,
-        v_trans,
-        offload,
-        unified,
-        kv_size,
-        n_seq_max,
-        n_pad,
-        n_swa,
-        swa_type,
-        filter_attn == nullptr ?
-            [&](int32_t il) { return !hparams.is_recurrent(il); }
-            : filter_attn,
-        nullptr
-    )),
+    mem_attn(is_iswa
+        ? static_cast<llama_memory_i *>(new llama_kv_cache_iswa(
+            model,
+            type_k,
+            type_v,
+            v_trans,
+            offload,
+            unified,
+            kv_size,
+            n_seq_max,
+            n_pad,
+            n_swa,
+            swa_type,
+            filter_attn == nullptr ?
+                [&](int32_t il) { return !hparams.is_recurrent(il); }
+                : filter_attn,
+            nullptr
+        )) : static_cast<llama_memory_i *>(new llama_kv_cache(
+            model,
+            type_k,
+            type_v,
+            v_trans,
+            offload,
+            unified,
+            kv_size,
+            n_seq_max,
+            n_pad,
+            n_swa,
+            swa_type,
+            filter_attn == nullptr ?
+                [&](int32_t il) { return !hparams.is_recurrent(il); }
+                : filter_attn,
+            nullptr
+        ))
+    ),
     mem_recr(new llama_memory_recurrent(
         model,
         type_r,
@@ -98,14 +118,30 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
         }

         // prepare the attention cache
-        auto heads_attn = mem_attn->prepare(ubatches);
-        if (heads_attn.empty()) {
-            LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
-            return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-        }
+        llama_kv_cache::slot_info_vec_t heads_attn;
+        llama_kv_cache::slot_info_vec_t heads_attn_iswa;
+        if (is_iswa) {
+            heads_attn      = get_mem_attn_iswa()->get_base()->prepare(ubatches);
+            heads_attn_iswa = get_mem_attn_iswa()->get_swa()->prepare(ubatches);
+            if (heads_attn.empty() || heads_attn_iswa.empty()) {
+                LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
+                return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+            }
+
+            return std::make_unique<llama_memory_hybrid_context>(
+                this, std::move(heads_attn), std::move(heads_attn_iswa), std::move(ubatches));
-        return std::make_unique<llama_memory_hybrid_context>(
+        } else {
+            heads_attn = get_mem_attn()->prepare(ubatches);
+            if (heads_attn.empty()) {
+                LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
+                return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+            }
+
+            return std::make_unique<llama_memory_hybrid_context>(
                 this, std::move(heads_attn), std::move(ubatches));
+        }
     } while(false);

     return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
@@ -191,7 +227,13 @@ void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id,
 }

 llama_kv_cache * llama_memory_hybrid::get_mem_attn() const {
-    return mem_attn.get();
+    GGML_ASSERT(!is_iswa && "llama_memory_hybrid::get_mem_attn: attention memory is not of type llama_kv_cache");
+    return static_cast<llama_kv_cache *>(mem_attn.get());
+}
+
+llama_kv_cache_iswa * llama_memory_hybrid::get_mem_attn_iswa() const {
+    GGML_ASSERT(is_iswa && "llama_memory_hybrid::get_mem_attn_iswa: attention memory is not of type llama_kv_cache_iswa");
+    return static_cast<llama_kv_cache_iswa *>(mem_attn.get());
 }

 llama_memory_recurrent * llama_memory_hybrid::get_mem_recr() const {
@@ -226,6 +268,19 @@ llama_memory_hybrid_context::llama_memory_hybrid_context(
     status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
 }

+llama_memory_hybrid_context::llama_memory_hybrid_context(
+        llama_memory_hybrid * mem,
+        slot_info_vec_t sinfos_attn,
+        slot_info_vec_t sinfos_attn_iswa,
+        std::vector<llama_ubatch> ubatches) :
+    is_iswa(true),
+    ubatches(std::move(ubatches)),
+    // note: here we copy the ubatches. not sure if this is ideal
+    ctx_attn(new llama_kv_cache_iswa_context(mem->get_mem_attn_iswa(), std::move(sinfos_attn), std::move(sinfos_attn_iswa), this->ubatches)),
+    ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
+    status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
+}
+
 bool llama_memory_hybrid_context::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

@@ -260,9 +315,15 @@ const llama_ubatch & llama_memory_hybrid_context::get_ubatch() const {
 }

 const llama_kv_cache_context * llama_memory_hybrid_context::get_attn() const {
+    GGML_ASSERT(!is_iswa && "llama_memory_hybrid_context::get_attn: attention context is not of type llama_kv_cache_context");
     return static_cast<const llama_kv_cache_context *>(ctx_attn.get());
 }

+const llama_kv_cache_iswa_context * llama_memory_hybrid_context::get_attn_iswa() const {
+    GGML_ASSERT(is_iswa && "llama_memory_hybrid_context::get_attn_iswa: attention context is not of type llama_kv_cache_iswa_context");
+    return static_cast<const llama_kv_cache_iswa_context *>(ctx_attn.get());
+}
+
 const llama_memory_recurrent_context * llama_memory_hybrid_context::get_recr() const {
     return static_cast<const llama_memory_recurrent_context *>(ctx_recr.get());
 }
diff --git a/src/llama-memory-hybrid.h b/src/llama-memory-hybrid.h
index 558cafdf984..c54010dd225 100644
--- a/src/llama-memory-hybrid.h
+++ b/src/llama-memory-hybrid.h
@@ -3,6 +3,7 @@
 #include "llama-batch.h"
 #include "llama-graph.h"
 #include "llama-kv-cache.h"
+#include "llama-kv-cache-iswa.h"
 #include "llama-memory.h"
 #include "llama-memory-recurrent.h"

@@ -28,6 +29,7 @@ class llama_memory_hybrid : public llama_memory_i {
         uint32_t n_pad,
         uint32_t n_swa,
         llama_swa_type swa_type,
+        bool is_iswa,
         /* recurrent */
         ggml_type type_r,
         ggml_type type_s,
@@ -80,12 +82,14 @@ class llama_memory_hybrid : public llama_memory_i {
     //

     llama_kv_cache * get_mem_attn() const;
+    llama_kv_cache_iswa * get_mem_attn_iswa() const;
     llama_memory_recurrent * get_mem_recr() const;

 private:
+    const bool is_iswa;
     const llama_hparams & hparams;

-    const std::unique_ptr<llama_kv_cache> mem_attn;
+    const std::unique_ptr<llama_memory_i> mem_attn;
     const std::unique_ptr<llama_memory_recurrent> mem_recr;
 };

@@ -111,6 +115,12 @@ class llama_memory_hybrid_context : public llama_memory_context_i {
         slot_info_vec_t sinfos_attn,
         std::vector<llama_ubatch> ubatches);

+    llama_memory_hybrid_context(
+        llama_memory_hybrid * mem,
+        slot_info_vec_t sinfos_attn,
+        slot_info_vec_t sinfos_attn_iswa,
+        std::vector<llama_ubatch> ubatches);
+
     ~llama_memory_hybrid_context() = default;

     bool next() override;

@@ -124,9 +134,12 @@ class llama_memory_hybrid_context : public llama_memory_context_i {
     //

     const llama_kv_cache_context * get_attn() const;
+    const llama_kv_cache_iswa_context * get_attn_iswa() const;
     const llama_memory_recurrent_context * get_recr() const;

 private:
+    const bool is_iswa = false;
+
     // the index of the next ubatch to process
     size_t i_next = 0;

diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 04c48b5fd3f..c13788c190f 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -7410,6 +7410,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                 /* attn_n_pad */ 1,
                 /* attn_n_swa */ hparams.n_swa,
                 /* attn_swa_type */ hparams.swa_type,
+                /* is_iswa */ hparams.is_swa_any(),
                 /* recurrent_type_k */ GGML_TYPE_F32,
                 /* recurrent_type_v */ GGML_TYPE_F32,
                 /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),