@@ -371,31 +371,11 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i {
 // along with the input tensors, the object also provides commonly used output tensors, such as logits, embeddings, etc.
 //   these are used by the llama_context to extract the relevant data, based on the compute parameters
 
-// TODO: this interface seems redundant - remove it
-class llm_graph_result_i {
-public:
-    virtual ~llm_graph_result_i() = default;
-
-    virtual ggml_tensor * get_tokens()      const = 0;
-    virtual ggml_tensor * get_logits()      const = 0;
-    virtual ggml_tensor * get_embd()        const = 0;
-    virtual ggml_tensor * get_embd_pooled() const = 0;
-
-    virtual ggml_cgraph  * get_gf()  = 0;
-    virtual ggml_context * get_ctx() = 0;
-
-    virtual void reset() = 0;
-
-    virtual void set_inputs(const llama_ubatch * ubatch) = 0;
-
-    virtual bool can_reuse(const llm_graph_params & params) = 0;
-};
-
-using llm_graph_result_ptr = std::unique_ptr<llm_graph_result_i>;
-
 // callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
 using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
 
+class llm_graph_result;
+
 struct llm_graph_params {
     llm_arch arch = LLM_ARCH_UNKNOWN;
 
@@ -418,8 +398,7 @@ struct llm_graph_params {
 
     llm_graph_cb cb;
 
-    // TODO: temporary
-    llm_graph_result_i * res;
+    llm_graph_result * res;
 
     // return true if the "other" params would result in a graph with the same topology as with the current params
     //   having the same topology allows us to reuse the graph in some cases
@@ -462,27 +441,27 @@ struct llm_graph_params {
     }
 };
 
-class llm_graph_result : public llm_graph_result_i {
+class llm_graph_result {
 public:
     llm_graph_result(int64_t max_nodes) : max_nodes(max_nodes) {
         reset();
     }
 
     virtual ~llm_graph_result() = default;
 
-    ggml_tensor * get_tokens()      const override { return t_tokens; }
-    ggml_tensor * get_logits()      const override { return t_logits; }
-    ggml_tensor * get_embd()        const override { return t_embd; }
-    ggml_tensor * get_embd_pooled() const override { return t_embd_pooled; }
+    ggml_tensor * get_tokens()      const { return t_tokens; }
+    ggml_tensor * get_logits()      const { return t_logits; }
+    ggml_tensor * get_embd()        const { return t_embd; }
+    ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }
 
-    ggml_cgraph  * get_gf()  override { return gf; }
-    ggml_context * get_ctx() override { return ctx_compute.get(); }
+    ggml_cgraph  * get_gf()  { return gf; }
+    ggml_context * get_ctx() { return ctx_compute.get(); }
 
     void set_max_nodes(int64_t max_nodes) {
         this->max_nodes = max_nodes;
     }
 
-    void reset() override {
+    void reset() {
         t_tokens      = nullptr;
         t_logits      = nullptr;
         t_embd        = nullptr;
@@ -503,7 +482,7 @@ class llm_graph_result : public llm_graph_result_i {
         gf = ggml_new_graph_custom(ctx_compute.get(), max_nodes, false);
     }
 
-    void set_inputs(const llama_ubatch * ubatch) override {
+    void set_inputs(const llama_ubatch * ubatch) {
         for (auto & input : inputs) {
             input->set_input(ubatch);
         }
@@ -514,7 +493,7 @@ class llm_graph_result : public llm_graph_result_i {
     //   would be identical to the existing graph. in that case, we simply have to update the memory
     //   contexts of the input tensors of the graph and we can reuse it for another computation
     // return true if the graph was updated and can be reused
-    bool can_reuse(const llm_graph_params & params) override {
+    bool can_reuse(const llm_graph_params & params) {
         if (!this->params.allow_reuse(params)) {
             return false;
         }
@@ -533,6 +512,10 @@ class llm_graph_result : public llm_graph_result_i {
         return inputs.back().get();
     }
 
+    void set_params(const llm_graph_params & params) {
+        this->params = params;
+    }
+
     // important graph nodes
     ggml_tensor * t_tokens      = nullptr;
     ggml_tensor * t_logits      = nullptr;
@@ -550,12 +533,15 @@ class llm_graph_result : public llm_graph_result_i {
 
     int64_t max_nodes;
 
+private:
     // keep a copy of the previous graph parameters
     // we will use this to determine whether the graph can be reused by comparing them with the new parameters
     // note: these are updated after constructing the new graph
     llm_graph_params params;
 };
 
+using llm_graph_result_ptr = std::unique_ptr<llm_graph_result>;
+
 //
 // llm_graph_context
 //
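Taken together, the hunks above make `llm_graph_result` the single concrete result type: the getters lose their `override`s, `reset()`, `set_inputs()` and `can_reuse()` become plain member functions, and the new `set_params()` lets the graph builder record which parameters produced the current graph so that `can_reuse()` can compare them against the next request. The following fragment is a minimal sketch of how a caller such as `llama_context` might drive this API; it uses only the members visible in this diff, while `build_graph()`, `params_new` and `ubatch` are hypothetical stand-ins, not code from the patch.

```cpp
// sketch only: the intended build/reuse flow, assuming a hypothetical build_graph() helper
llm_graph_result_ptr res = std::make_unique<llm_graph_result>(/*max_nodes=*/8192);

// first computation: build a fresh graph and remember the parameters that produced it
build_graph(*res, params_new);   // hypothetical: fills gf and the t_* output tensors
res->set_params(params_new);

// later computation: if the topology is unchanged, skip the rebuild entirely
if (res->can_reuse(params_new)) {
    res->set_inputs(&ubatch);    // only refresh the input tensors for the new ubatch
} else {
    res->reset();                // otherwise start over and rebuild as above
    build_graph(*res, params_new);
    res->set_params(params_new);
}
```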
@@ -613,6 +599,7 @@ struct llm_graph_context {
     llm_graph_result * res;
 
     ggml_context * ctx0 = nullptr;
+    ggml_cgraph  * gf   = nullptr;
 
     llm_graph_context(const llm_graph_params & params);
     virtual ~llm_graph_context() = default;
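The new `gf` member is what makes the parameter removals in the remaining hunks possible: once `llm_graph_context` holds the compute graph it is building into, the `build_*` helpers can expand that member directly instead of having a `ggml_cgraph *` threaded through every signature. A hypothetical helper to illustrate the pattern; the function name and body are illustrative only and not part of the diff.

```cpp
// sketch: a member helper that expands the owned graph instead of taking gf as a parameter
ggml_tensor * llm_graph_context::build_example_node(ggml_tensor * cur) const {
    // gf is the ggml_cgraph * member added above, presumably initialized from res->get_gf()
    ggml_build_forward_expand(gf, cur);
    return cur;
}
```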
@@ -698,7 +685,6 @@ struct llm_graph_context {
     //
 
     ggml_tensor * build_attn_mha(
-             ggml_cgraph * gf,
             ggml_tensor * q,       // [n_embd_head_q, n_head_q, n_tokens]
             ggml_tensor * k,       // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v,       // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
@@ -711,7 +697,6 @@ struct llm_graph_context {
 
     ggml_tensor * build_attn(
             llm_graph_input_attn_no_cache * inp,
-            ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -726,7 +711,6 @@ struct llm_graph_context {
 
     ggml_tensor * build_attn(
             llm_graph_input_attn_kv_unified * inp,
-            ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -742,7 +726,6 @@ struct llm_graph_context {
     // note: if k_cur or v_cur are not provided, they will not be stored in the memory
     ggml_tensor * build_attn(
             llm_graph_input_attn_kv_unified_iswa * inp,
-            ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -757,7 +740,6 @@ struct llm_graph_context {
 
     ggml_tensor * build_attn(
             llm_graph_input_attn_cross * inp,
-            ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -779,7 +761,6 @@ struct llm_graph_context {
     //          implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
     //          `llama_memory_recurrent`
     ggml_tensor * build_rs(
-            ggml_cgraph * gf,
             ggml_tensor * s,
             ggml_tensor * state_copy,
                 int32_t   state_size,
@@ -794,17 +775,15 @@ struct llm_graph_context {
 
     ggml_tensor * build_rs(
             llm_graph_input_rs * inp,
-            ggml_cgraph * gf,
             ggml_tensor * s,
                 int32_t   state_size,
                 int32_t   n_seqs,
             const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
 
     ggml_tensor * build_rwkv_token_shift_load(
         llm_graph_input_rs * inp,
-               ggml_cgraph * gf,
         const llama_ubatch & ubatch,
-                     int   il) const;
+                        int   il) const;
 
     ggml_tensor * build_rwkv_token_shift_store(
             ggml_tensor * token_shift,
@@ -821,7 +800,6 @@ struct llm_graph_context {
     //
 
     void build_pooling(
-            ggml_cgraph * gf,
             ggml_tensor * cls,
             ggml_tensor * cls_b,
             ggml_tensor * cls_out,
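For the model graph builders that call these helpers, the visible effect of the signature changes above is simply that the `gf` argument disappears from every `build_attn`/`build_rs`/`build_pooling` call. A hypothetical call site is sketched below; the layer tensor names are illustrative and the trailing arguments are elided because they are unchanged and not shown in these hunks.

```cpp
// before this change: the compute graph had to be passed alongside the input descriptor
// cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, Qcur, /* ... */ il);

// after: gf is taken from the llm_graph_context member instead, so the call shrinks by one argument
// cur = build_attn(inp_attn,     model.layers[il].wo, model.layers[il].bo, Qcur, /* ... */ il);
```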