From cc10fab4c0b4f5a26930d126ed796046b99d49a7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 9 Oct 2025 19:33:30 +0300 Subject: [PATCH 1/6] graph : reuse hybrid graphs --- src/llama-graph.cpp | 41 ++++++++++++++++++++++++++++++++++--- src/llama-graph.h | 10 +++++++-- src/llama-memory-hybrid.cpp | 2 +- 3 files changed, 47 insertions(+), 6 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 8909bbfb95e..7e31a7ec6fe 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -461,8 +461,43 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { } void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) { - inp_attn->set_input(ubatch); - inp_rs->set_input(ubatch); + mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch); + mctx->get_attn()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch); + + mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn); + + const int64_t n_rs = mctx->get_recr()->get_n_rs(); + + if (inp_rs->s_copy) { + GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer)); + int32_t * data = (int32_t *) inp_rs->s_copy->data; + + // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n + for (uint32_t i = 0; i < n_rs; ++i) { + data[i] = mctx->get_recr()->s_copy(i); + } + } +} + +bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.mctx); + + this->mctx = mctx; + + bool res = true; + + res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens; + //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + + res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv(); + res &= inp_attn->self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD); + + res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs(); + + res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs; + res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs; + + return res; } // @@ -1912,7 +1947,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr()); auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn()); - auto inp = std::make_unique(std::move(inp_attn), std::move(inp_rs), mctx_cur); + auto inp = std::make_unique(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur); return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp)); } diff --git a/src/llama-graph.h b/src/llama-graph.h index e9d387bd7c5..a61e0503423 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -365,22 +365,28 @@ class llm_graph_input_attn_cross : public llm_graph_input_i { class llm_graph_input_mem_hybrid : public llm_graph_input_i { public: llm_graph_input_mem_hybrid( + const llama_cparams & cparams, std::unique_ptr inp_attn, - std::unique_ptr inp_rs, - const llama_memory_hybrid_context * mctx) : + std::unique_ptr inp_rs, + const llama_memory_hybrid_context * mctx) : inp_attn(std::move(inp_attn)), inp_rs(std::move(inp_rs)), + cparams(cparams), mctx(mctx) { } virtual ~llm_graph_input_mem_hybrid() = default; void set_input(const llama_ubatch * ubatch) override; + bool can_reuse(const llm_graph_params & params) override; + std::unique_ptr inp_attn; std::unique_ptr inp_rs; llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); } llm_graph_input_rs * get_recr() const { return inp_rs.get(); } + const llama_cparams cparams; + const llama_memory_hybrid_context * mctx; }; diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp index dfb8439e01b..a1b45e4a3cc 100644 --- a/src/llama-memory-hybrid.cpp +++ b/src/llama-memory-hybrid.cpp @@ -222,7 +222,7 @@ llama_memory_hybrid_context::llama_memory_hybrid_context( ubatches(std::move(ubatches)), // note: here we copy the ubatches. not sure if this is ideal ctx_attn(new llama_kv_cache_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)), - ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)), + ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)), status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { } From 4f71c75fe5a40224e61523cd7b2ccfb9472ed27f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 9 Oct 2025 19:36:17 +0300 Subject: [PATCH 2/6] graph : reuse recurrent graphs --- src/llama-graph.cpp | 15 +++++++++++++++ src/llama-graph.h | 2 ++ 2 files changed, 17 insertions(+) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 7e31a7ec6fe..9931155d792 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -254,6 +254,21 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) { } } +bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.mctx); + + this->mctx = mctx; + + bool res = true; + + res &= s_copy->ne[0] == mctx->get_n_rs(); + + res &= s_copy_main->ne[0] == params.ubatch.n_seqs; + res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs; + + return res; +} + void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) { GGML_UNUSED(ubatch); diff --git a/src/llama-graph.h b/src/llama-graph.h index a61e0503423..03c483ffd73 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -225,6 +225,8 @@ class llm_graph_input_rs : public llm_graph_input_i { void set_input(const llama_ubatch * ubatch) override; + bool can_reuse(const llm_graph_params & params) override; + ggml_tensor * s_copy; // I32 [n_rs] // views of s_copy, computed once per graph From 36a95e6f8edcb2958ece90ecd374c7361e677ab3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 10 Oct 2025 10:44:41 +0300 Subject: [PATCH 3/6] graph : fix reuse check for recurrent inputs --- src/llama-graph.cpp | 11 ++++++++++- src/llama-graph.h | 4 ++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 9931155d792..098ed1640f4 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -266,6 +266,9 @@ bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) { res &= s_copy_main->ne[0] == params.ubatch.n_seqs; res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs; + res &= head == mctx->get_head(); + res &= rs_z == mctx->get_rs_z(); + return res; } @@ -512,6 +515,9 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) { res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs; res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs; + res &= inp_rs->head == mctx->get_recr()->get_head(); + res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z(); + return res; } @@ -1891,6 +1897,9 @@ static std::unique_ptr build_rs_inp_impl( inp->s_copy_main = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0); inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]); + inp->head = mctx_cur->get_head(); + inp->rs_z = mctx_cur->get_rs_z(); + return inp; } @@ -1959,7 +1968,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store( llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { const auto * mctx_cur = static_cast(mctx); - auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr()); + auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr()); auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn()); auto inp = std::make_unique(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur); diff --git a/src/llama-graph.h b/src/llama-graph.h index 03c483ffd73..81ac329cc31 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -235,6 +235,10 @@ class llm_graph_input_rs : public llm_graph_input_i { ggml_tensor * s_copy_extra; // I32 [n_rs - n_seqs] const llama_memory_recurrent_context * mctx; + + // used in view offsets, need to match for valid graph reuse + uint32_t head; + int32_t rs_z; }; class llm_graph_input_cross_embd : public llm_graph_input_i { From 3aa4e3cd37def9b8568debb9c0f06f625579cd39 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 10 Oct 2025 10:57:35 +0300 Subject: [PATCH 4/6] memory : move the recurrent state into the memory context --- src/llama-graph.cpp | 13 ++++++++----- src/llama-graph.h | 8 ++++---- src/llama-memory-recurrent.cpp | 17 ++++++++++------- src/llama-memory-recurrent.h | 6 ++++-- 4 files changed, 26 insertions(+), 18 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 098ed1640f4..1a794985417 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -238,6 +238,12 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) { } } +llm_graph_input_rs::llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : + mctx(mctx), + head(mctx->get_head()), + rs_z(mctx->get_rs_z()) { +} + void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) { GGML_UNUSED(ubatch); @@ -266,8 +272,8 @@ bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) { res &= s_copy_main->ne[0] == params.ubatch.n_seqs; res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs; - res &= head == mctx->get_head(); - res &= rs_z == mctx->get_rs_z(); + res &= this->head == mctx->get_head(); + res &= this->rs_z == mctx->get_rs_z(); return res; } @@ -1897,9 +1903,6 @@ static std::unique_ptr build_rs_inp_impl( inp->s_copy_main = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0); inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]); - inp->head = mctx_cur->get_head(); - inp->rs_z = mctx_cur->get_rs_z(); - return inp; } diff --git a/src/llama-graph.h b/src/llama-graph.h index 81ac329cc31..5fd383def99 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -220,7 +220,7 @@ class llm_graph_input_cls : public llm_graph_input_i { class llm_graph_input_rs : public llm_graph_input_i { public: - llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {} + llm_graph_input_rs(const llama_memory_recurrent_context * mctx); virtual ~llm_graph_input_rs() = default; void set_input(const llama_ubatch * ubatch) override; @@ -236,9 +236,9 @@ class llm_graph_input_rs : public llm_graph_input_i { const llama_memory_recurrent_context * mctx; - // used in view offsets, need to match for valid graph reuse - uint32_t head; - int32_t rs_z; + // need to match for valid graph reuse + const uint32_t head; + const int32_t rs_z; }; class llm_graph_input_cross_embd : public llm_graph_input_i { diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index 812bf253049..bf6ae4db920 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -1093,12 +1093,15 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell llama_memory_recurrent_context::llama_memory_recurrent_context(llama_memory_status status) : status(status) {} llama_memory_recurrent_context::llama_memory_recurrent_context( - llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) { + llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), + n_rs(mem->size), head(0), rs_z(0), size(mem->size) { } llama_memory_recurrent_context::llama_memory_recurrent_context( llama_memory_recurrent * mem, - std::vector ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {} + std::vector ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)), + n_rs(mem->n), head(mem->head), rs_z(mem->rs_z), size(mem->size) { +} llama_memory_recurrent_context::~llama_memory_recurrent_context() = default; @@ -1139,19 +1142,19 @@ const llama_ubatch & llama_memory_recurrent_context::get_ubatch() const { } uint32_t llama_memory_recurrent_context::get_n_rs() const { - return is_full ? mem->size : mem->n; + return n_rs; } uint32_t llama_memory_recurrent_context::get_head() const { - return is_full ? 0 : mem->head; + return head; } int32_t llama_memory_recurrent_context::get_rs_z() const { - return is_full ? 0 : mem->rs_z; + return rs_z; } uint32_t llama_memory_recurrent_context::get_size() const { - return mem->size; + return size; } ggml_tensor * llama_memory_recurrent_context::get_r_l(int32_t il) const { @@ -1163,5 +1166,5 @@ ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const { } int32_t llama_memory_recurrent_context::s_copy(int i) const { - return mem->cells[i + mem->head].src0; + return mem->cells[i + head].src0; } diff --git a/src/llama-memory-recurrent.h b/src/llama-memory-recurrent.h index 47f01d73912..a2b19904fce 100644 --- a/src/llama-memory-recurrent.h +++ b/src/llama-memory-recurrent.h @@ -175,8 +175,10 @@ class llama_memory_recurrent_context : public llama_memory_context_i { // // data needed for building the compute graph for the current ubatch: - // TODO: extract all the state like `head` and `n` here // - const bool is_full = false; + const uint32_t n_rs = 0; + const uint32_t head = 0; + const int32_t rs_z = -1; + const uint32_t size = 0; }; From d24eb420b4e3e29b3e1b22d48b94557134c66b74 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 10 Oct 2025 19:41:10 +0300 Subject: [PATCH 5/6] Revert "memory : move the recurrent state into the memory context" This reverts commit 00f115fe810815d4a22a6dee0acc346131e970e1. --- src/llama-graph.cpp | 13 +++++-------- src/llama-graph.h | 8 ++++---- src/llama-memory-recurrent.cpp | 17 +++++++---------- src/llama-memory-recurrent.h | 6 ++---- 4 files changed, 18 insertions(+), 26 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 1a794985417..098ed1640f4 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -238,12 +238,6 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) { } } -llm_graph_input_rs::llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : - mctx(mctx), - head(mctx->get_head()), - rs_z(mctx->get_rs_z()) { -} - void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) { GGML_UNUSED(ubatch); @@ -272,8 +266,8 @@ bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) { res &= s_copy_main->ne[0] == params.ubatch.n_seqs; res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs; - res &= this->head == mctx->get_head(); - res &= this->rs_z == mctx->get_rs_z(); + res &= head == mctx->get_head(); + res &= rs_z == mctx->get_rs_z(); return res; } @@ -1903,6 +1897,9 @@ static std::unique_ptr build_rs_inp_impl( inp->s_copy_main = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0); inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]); + inp->head = mctx_cur->get_head(); + inp->rs_z = mctx_cur->get_rs_z(); + return inp; } diff --git a/src/llama-graph.h b/src/llama-graph.h index 5fd383def99..81ac329cc31 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -220,7 +220,7 @@ class llm_graph_input_cls : public llm_graph_input_i { class llm_graph_input_rs : public llm_graph_input_i { public: - llm_graph_input_rs(const llama_memory_recurrent_context * mctx); + llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {} virtual ~llm_graph_input_rs() = default; void set_input(const llama_ubatch * ubatch) override; @@ -236,9 +236,9 @@ class llm_graph_input_rs : public llm_graph_input_i { const llama_memory_recurrent_context * mctx; - // need to match for valid graph reuse - const uint32_t head; - const int32_t rs_z; + // used in view offsets, need to match for valid graph reuse + uint32_t head; + int32_t rs_z; }; class llm_graph_input_cross_embd : public llm_graph_input_i { diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index bf6ae4db920..812bf253049 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -1093,15 +1093,12 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell llama_memory_recurrent_context::llama_memory_recurrent_context(llama_memory_status status) : status(status) {} llama_memory_recurrent_context::llama_memory_recurrent_context( - llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), - n_rs(mem->size), head(0), rs_z(0), size(mem->size) { + llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) { } llama_memory_recurrent_context::llama_memory_recurrent_context( llama_memory_recurrent * mem, - std::vector ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)), - n_rs(mem->n), head(mem->head), rs_z(mem->rs_z), size(mem->size) { -} + std::vector ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {} llama_memory_recurrent_context::~llama_memory_recurrent_context() = default; @@ -1142,19 +1139,19 @@ const llama_ubatch & llama_memory_recurrent_context::get_ubatch() const { } uint32_t llama_memory_recurrent_context::get_n_rs() const { - return n_rs; + return is_full ? mem->size : mem->n; } uint32_t llama_memory_recurrent_context::get_head() const { - return head; + return is_full ? 0 : mem->head; } int32_t llama_memory_recurrent_context::get_rs_z() const { - return rs_z; + return is_full ? 0 : mem->rs_z; } uint32_t llama_memory_recurrent_context::get_size() const { - return size; + return mem->size; } ggml_tensor * llama_memory_recurrent_context::get_r_l(int32_t il) const { @@ -1166,5 +1163,5 @@ ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const { } int32_t llama_memory_recurrent_context::s_copy(int i) const { - return mem->cells[i + head].src0; + return mem->cells[i + mem->head].src0; } diff --git a/src/llama-memory-recurrent.h b/src/llama-memory-recurrent.h index a2b19904fce..47f01d73912 100644 --- a/src/llama-memory-recurrent.h +++ b/src/llama-memory-recurrent.h @@ -175,10 +175,8 @@ class llama_memory_recurrent_context : public llama_memory_context_i { // // data needed for building the compute graph for the current ubatch: + // TODO: extract all the state like `head` and `n` here // - const uint32_t n_rs = 0; - const uint32_t head = 0; - const int32_t rs_z = -1; - const uint32_t size = 0; + const bool is_full = false; }; From 454ab904547160f14011a1db923d21ef5e82bd7b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 15 Dec 2025 13:56:24 +0200 Subject: [PATCH 6/6] cont : fix build --- src/llama-graph.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 098ed1640f4..10733f01ad1 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -508,7 +508,7 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) { //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv(); - res &= inp_attn->self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD); + res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens; res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();