109 changes: 85 additions & 24 deletions src/llama-memory-hybrid.cpp
@@ -18,6 +18,7 @@ llama_memory_hybrid::llama_memory_hybrid(
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
bool is_iswa,
/* recurrent */
ggml_type type_r,
ggml_type type_s,
@@ -29,24 +30,43 @@ llama_memory_hybrid::llama_memory_hybrid(
/* layer filters */
const layer_filter_cb & filter_attn,
const layer_filter_cb & filter_recr) :
is_iswa(is_iswa),
hparams(model.hparams),
mem_attn(new llama_kv_cache(
model,
type_k,
type_v,
v_trans,
offload,
unified,
kv_size,
n_seq_max,
n_pad,
n_swa,
swa_type,
filter_attn == nullptr ?
[&](int32_t il) { return !hparams.is_recurrent(il); }
: filter_attn,
nullptr
)),
mem_attn(is_iswa
? static_cast<llama_memory_i *>(new llama_kv_cache_iswa(
model,
type_k,
type_v,
v_trans,
offload,
unified,
kv_size,
n_seq_max,
n_pad,
n_swa,
swa_type,
filter_attn == nullptr ?
[&](int32_t il) { return !hparams.is_recurrent(il); }
: filter_attn,
nullptr
)) : static_cast<llama_memory_i *>(new llama_kv_cache(
model,
type_k,
type_v,
v_trans,
offload,
unified,
kv_size,
n_seq_max,
n_pad,
n_swa,
swa_type,
filter_attn == nullptr ?
[&](int32_t il) { return !hparams.is_recurrent(il); }
: filter_attn,
nullptr
))
),
mem_recr(new llama_memory_recurrent(
model,
type_r,
@@ -98,14 +118,30 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
}

// prepare the attention cache
auto heads_attn = mem_attn->prepare(ubatches);
if (heads_attn.empty()) {
LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
llama_kv_cache::slot_info_vec_t heads_attn;
llama_kv_cache::slot_info_vec_t heads_attn_iswa;
if (is_iswa) {
heads_attn = get_mem_attn_iswa()->get_base()->prepare(ubatches);
heads_attn_iswa = get_mem_attn_iswa()->get_swa()->prepare(ubatches);
if (heads_attn.empty() || heads_attn_iswa.empty()) {
LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}

return std::make_unique<llama_memory_hybrid_context>(
this, std::move(heads_attn), std::move(heads_attn_iswa), std::move(ubatches));

return std::make_unique<llama_memory_hybrid_context>(
} else {
heads_attn = get_mem_attn()->prepare(ubatches);
if (heads_attn.empty()) {
LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}

return std::make_unique<llama_memory_hybrid_context>(
this, std::move(heads_attn), std::move(ubatches));
}

} while(false);

return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
@@ -191,7 +227,13 @@ void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id,
}

llama_kv_cache * llama_memory_hybrid::get_mem_attn() const {
return mem_attn.get();
GGML_ASSERT(!is_iswa && "llama_memory_hybrid::get_mem_attn: attention memory is not of type llama_kv_cache");
return static_cast<llama_kv_cache *>(mem_attn.get());
}

llama_kv_cache_iswa * llama_memory_hybrid::get_mem_attn_iswa() const {
GGML_ASSERT(is_iswa && "llama_memory_hybrid::get_mem_attn_iswa: attention memory is not of type llama_kv_cache_iswa");
return static_cast<llama_kv_cache_iswa *>(mem_attn.get());
}

llama_memory_recurrent * llama_memory_hybrid::get_mem_recr() const {
Expand Down Expand Up @@ -226,6 +268,19 @@ llama_memory_hybrid_context::llama_memory_hybrid_context(
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
}

llama_memory_hybrid_context::llama_memory_hybrid_context(
llama_memory_hybrid * mem,
slot_info_vec_t sinfos_attn,
slot_info_vec_t sinfos_attn_iswa,
std::vector<llama_ubatch> ubatches) :
is_iswa(true),
ubatches(std::move(ubatches)),
// note: here we copy the ubatches. not sure if this is ideal
ctx_attn(new llama_kv_cache_iswa_context(mem->get_mem_attn_iswa(), std::move(sinfos_attn), std::move(sinfos_attn_iswa), this->ubatches)),
ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
}

bool llama_memory_hybrid_context::next() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

@@ -260,9 +315,15 @@ const llama_ubatch & llama_memory_hybrid_context::get_ubatch() const {
}

const llama_kv_cache_context * llama_memory_hybrid_context::get_attn() const {
GGML_ASSERT(!is_iswa && "llama_memory_hybrid_context::get_attn: attention context is not of type llama_kv_cache_context");
return static_cast<const llama_kv_cache_context *>(ctx_attn.get());
}

const llama_kv_cache_iswa_context * llama_memory_hybrid_context::get_attn_iswa() const {
GGML_ASSERT(is_iswa && "llama_memory_hybrid_context::get_attn_iswa: attention context is not of type llama_kv_cache_iswa_context");
return static_cast<const llama_kv_cache_iswa_context *>(ctx_attn.get());
}

const llama_memory_recurrent_context * llama_memory_hybrid_context::get_recr() const {
return static_cast<const llama_memory_recurrent_context *>(ctx_recr.get());
}
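
The structural change in this file: mem_attn is now held as a llama_memory_i pointer and, depending on the new is_iswa constructor flag, points at either a plain llama_kv_cache or a llama_kv_cache_iswa; the typed accessors get_mem_attn() and get_mem_attn_iswa() assert that the caller asked for the variant that was actually constructed. Below is a minimal sketch of the resulting dispatch, not code from this PR; the helper prepare_attn and its use_iswa flag are hypothetical (in the real code path the flag comes from hparams.is_swa_any(), see the llama-model.cpp hunk further down).

// Hypothetical illustration, not part of this PR: picking the correct typed
// accessor. Calling the wrong one trips the GGML_ASSERT added in this change.
#include "llama-memory-hybrid.h"

#include <vector>

static bool prepare_attn(llama_memory_hybrid * mem, bool use_iswa, const std::vector<llama_ubatch> & ubatches) {
    if (use_iswa) {
        // iSWA variant: the base (full-context) cache and the SWA cache are prepared separately
        auto * iswa = mem->get_mem_attn_iswa();

        const auto heads_base = iswa->get_base()->prepare(ubatches);
        const auto heads_swa  = iswa->get_swa()->prepare(ubatches);

        return !heads_base.empty() && !heads_swa.empty();
    }

    // non-iSWA variant: a single llama_kv_cache, as before this change
    const auto heads = mem->get_mem_attn()->prepare(ubatches);

    return !heads.empty();
}
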
15 changes: 14 additions & 1 deletion src/llama-memory-hybrid.h
@@ -3,6 +3,7 @@
#include "llama-batch.h"
#include "llama-graph.h"
#include "llama-kv-cache.h"
#include "llama-kv-cache-iswa.h"
#include "llama-memory.h"
#include "llama-memory-recurrent.h"

@@ -28,6 +29,7 @@ class llama_memory_hybrid : public llama_memory_i {
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
bool is_iswa,
/* recurrent */
ggml_type type_r,
ggml_type type_s,
@@ -80,12 +82,14 @@ class llama_memory_hybrid : public llama_memory_i {
//

llama_kv_cache * get_mem_attn() const;
llama_kv_cache_iswa * get_mem_attn_iswa() const;
llama_memory_recurrent * get_mem_recr() const;

private:
const bool is_iswa;
const llama_hparams & hparams;

const std::unique_ptr<llama_kv_cache> mem_attn;
const std::unique_ptr<llama_memory_i> mem_attn;
const std::unique_ptr<llama_memory_recurrent> mem_recr;
};

@@ -111,6 +115,12 @@ class llama_memory_hybrid_context : public llama_memory_context_i {
slot_info_vec_t sinfos_attn,
std::vector<llama_ubatch> ubatches);

llama_memory_hybrid_context(
llama_memory_hybrid * mem,
slot_info_vec_t sinfos_attn,
slot_info_vec_t sinfos_attn_iswa,
std::vector<llama_ubatch> ubatches);

~llama_memory_hybrid_context() = default;

bool next() override;
@@ -124,9 +134,12 @@ class llama_memory_hybrid_context : public llama_memory_context_i {
//

const llama_kv_cache_context * get_attn() const;
const llama_kv_cache_iswa_context * get_attn_iswa() const;
const llama_memory_recurrent_context * get_recr() const;

private:
const bool is_iswa = false;

// the index of the next ubatch to process
size_t i_next = 0;

1 change: 1 addition & 0 deletions src/llama-model.cpp
@@ -7410,6 +7410,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
/* attn_n_pad */ 1,
/* attn_n_swa */ hparams.n_swa,
/* attn_swa_type */ hparams.swa_type,
/* is_iswa */ hparams.is_swa_any(),
/* recurrent_type_k */ GGML_TYPE_F32,
/* recurrent_type_v */ GGML_TYPE_F32,
/* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
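
The context side mirrors the same split: llama_memory_hybrid_context keeps its attention sub-context behind a base-class pointer (plus the new is_iswa member) and hands it out through get_attn() or get_attn_iswa(), again guarded by GGML_ASSERT. A small consumer sketch follows; the helper hybrid_attn_status and its model_uses_iswa flag are hypothetical and only illustrate the accessor choice, they are not part of this PR.

// Hypothetical illustration, not part of this PR.
#include "llama-memory-hybrid.h"

static llama_memory_status hybrid_attn_status(const llama_memory_hybrid_context * hctx, bool model_uses_iswa) {
    if (model_uses_iswa) {
        // iSWA build: the attention sub-context is a llama_kv_cache_iswa_context
        return hctx->get_attn_iswa()->get_status();
    }

    // plain build: the attention sub-context is a llama_kv_cache_context
    return hctx->get_attn()->get_status();
}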