@@ -37,13 +37,16 @@ extern "C" {
3737    // ====== Dataset ====== 
3838
3939    GGML_API  ggml_opt_dataset_t  ggml_opt_dataset_init (
40-             int64_t  ne_datapoint , // number of elements per datapoint 
41-             int64_t  ne_label ,     // number of elements per label 
42-             int64_t  ndata ,        // total number of datapoints/labels 
43-             int64_t  ndata_shard ); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied) 
40+             enum  ggml_type  type_data ,    // the type for the internal data tensor 
41+             enum  ggml_type  type_label ,   // the type for the internal labels tensor 
42+             int64_t         ne_datapoint , // number of elements per datapoint 
43+             int64_t         ne_label ,     // number of elements per label 
44+             int64_t         ndata ,        // total number of datapoints/labels 
45+             int64_t         ndata_shard ); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied) 
4446    GGML_API  void  ggml_opt_dataset_free (ggml_opt_dataset_t  dataset );
4547
4648    // get underlying tensors that store the data 
49+     GGML_API  int64_t               ggml_opt_dataset_ndata  (ggml_opt_dataset_t  dataset );
4750    GGML_API  struct  ggml_tensor  *  ggml_opt_dataset_data   (ggml_opt_dataset_t  dataset ); // shape = [ne_datapoint, ndata] 
4851    GGML_API  struct  ggml_tensor  *  ggml_opt_dataset_labels (ggml_opt_dataset_t  dataset ); // shape = [ne_label,     ndata] 
4952
@@ -56,13 +59,19 @@ extern "C" {
5659            struct  ggml_tensor  *  data_batch ,   // shape = [ne_datapoint, ndata_batch] 
5760            struct  ggml_tensor  *  labels_batch , // shape = [ne_label,     ndata_batch] 
5861            int64_t               ibatch );
62+     GGML_API  void  ggml_opt_dataset_get_batch_host (
63+             ggml_opt_dataset_t    dataset ,
64+             void                *  data_batch ,
65+             size_t                nb_data_batch ,
66+             void                *  labels_batch ,
67+             int64_t               ibatch );
5968
6069    // ====== Model / Context ====== 
6170
6271    enum  ggml_opt_build_type  {
63-         GGML_OPT_BUILD_TYPE_FORWARD ,
64-         GGML_OPT_BUILD_TYPE_GRAD ,
65-         GGML_OPT_BUILD_TYPE_OPT ,
72+         GGML_OPT_BUILD_TYPE_FORWARD   =   10 ,
73+         GGML_OPT_BUILD_TYPE_GRAD      =   20 ,
74+         GGML_OPT_BUILD_TYPE_OPT       =   30 ,
6675    };
6776
6877    // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss 
@@ -81,20 +90,22 @@ extern "C" {
8190    // userdata can be used to pass arbitrary data 
8291    typedef  struct  ggml_opt_optimizer_params  (* ggml_opt_get_optimizer_params )(void  *  userdata );
8392
84-     // returns the default optimizer params (constant) 
93+     // returns the default optimizer params (constant, hard-coded values ) 
8594    // userdata is not used 
8695    GGML_API  struct  ggml_opt_optimizer_params  ggml_opt_get_default_optimizer_params (void  *  userdata );
8796
97+     // casts userdata to ggml_opt_optimizer_params and returns it 
98+     GGML_API  struct  ggml_opt_optimizer_params  ggml_opt_get_constant_optimizer_params (void  *  userdata );
99+ 
88100    // parameters for initializing a new optimization context 
89101    struct  ggml_opt_params  {
90102        ggml_backend_sched_t  backend_sched ; // defines which backends are used to construct the compute graphs 
91103
92-         struct  ggml_context  *  ctx_compute ; // created in user code, holds non-static tensors 
93- 
94-         // the forward graph is defined by inputs and outputs 
95-         // those tensors and all tensors inbetween are not intended to be reusable between multiple optimization contexts 
96-         struct  ggml_tensor  *  inputs ;
97-         struct  ggml_tensor  *  outputs ;
104+         // by default the forward graph needs to be reconstructed for each eval 
105+         // if ctx_compute, inputs, and outputs are set the graphs are instead allocated statically 
106+         struct  ggml_context  *  ctx_compute ;
107+         struct  ggml_tensor   *  inputs ;
108+         struct  ggml_tensor   *  outputs ;
98109
99110        enum  ggml_opt_loss_type   loss_type ;
100111        enum  ggml_opt_build_type  build_type ;
@@ -107,12 +118,9 @@ extern "C" {
107118
108119    // get parameters for an optimization context with defaults set where possible 
109120    // parameters for which no sensible defaults exist are supplied as arguments to this function 
110-     GGML_API  ggml_opt_params  ggml_opt_default_params (
111-             ggml_backend_sched_t       backend_sched ,
112-             struct  ggml_context      *  ctx_compute ,
113-             struct  ggml_tensor       *  inputs ,
114-             struct  ggml_tensor       *  outputs ,
115-             enum  ggml_opt_loss_type    loss_type );
121+     GGML_API  struct  ggml_opt_params  ggml_opt_default_params (
122+             ggml_backend_sched_t     backend_sched ,
123+             enum  ggml_opt_loss_type  loss_type );
116124
117125    GGML_API  ggml_opt_context_t  ggml_opt_init (struct  ggml_opt_params  params );
118126    GGML_API  void  ggml_opt_free (ggml_opt_context_t  opt_ctx );
@@ -121,18 +129,20 @@ extern "C" {
121129    GGML_API  void  ggml_opt_reset (ggml_opt_context_t  opt_ctx , bool  optimizer );
122130
123131    // get underlying tensors that store data 
132+     // if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc 
124133    GGML_API  struct  ggml_tensor  *  ggml_opt_inputs (  ggml_opt_context_t  opt_ctx ); // forward graph input tensor 
125134    GGML_API  struct  ggml_tensor  *  ggml_opt_outputs ( ggml_opt_context_t  opt_ctx ); // forward graph output tensor 
126135    GGML_API  struct  ggml_tensor  *  ggml_opt_labels (  ggml_opt_context_t  opt_ctx ); // labels to compare outputs against 
127136    GGML_API  struct  ggml_tensor  *  ggml_opt_loss (    ggml_opt_context_t  opt_ctx ); // scalar tensor that contains the loss 
128137    GGML_API  struct  ggml_tensor  *  ggml_opt_pred (    ggml_opt_context_t  opt_ctx ); // predictions made by outputs 
129138    GGML_API  struct  ggml_tensor  *  ggml_opt_ncorrect (ggml_opt_context_t  opt_ctx ); // number of matching predictions between outputs and labels 
130139
140+     // get the gradient accumulator for a node from the forward graph 
131141    GGML_API  struct  ggml_tensor  *  ggml_opt_grad_acc (ggml_opt_context_t  opt_ctx , struct  ggml_tensor  *  node );
132142
133143    // ====== Optimization Result ====== 
134144
135-     GGML_API  ggml_opt_result_t  ggml_opt_result_init ();
145+     GGML_API  ggml_opt_result_t  ggml_opt_result_init (void );
136146    GGML_API  void  ggml_opt_result_free (ggml_opt_result_t  result );
137147    GGML_API  void  ggml_opt_result_reset (ggml_opt_result_t  result );
138148
@@ -144,11 +154,20 @@ extern "C" {
144154
145155    // ====== Computation ====== 
146156
147-     // do forward pass, increment result if not NULL 
148-     GGML_API  void  ggml_opt_forward (ggml_opt_context_t  opt_ctx , ggml_opt_result_t  result );
157+     // if not using static graphs, this function must be called prior to ggml_opt_alloc 
158+     GGML_API  void  ggml_opt_prepare_alloc (
159+         ggml_opt_context_t     opt_ctx ,
160+         struct  ggml_context  *  ctx_compute ,
161+         struct  ggml_cgraph   *  gf ,
162+         struct  ggml_tensor   *  inputs ,
163+         struct  ggml_tensor   *  outputs );
164+ 
165+     // allocate the next graph for evaluation, either forward or forward + backward 
166+     // must be called exactly once prior to calling ggml_opt_eval 
167+     GGML_API  void  ggml_opt_alloc (ggml_opt_context_t  opt_ctx , bool  backward );
149168
150-     // do forward pass, increment result if not NULL, do backward pass 
151-     GGML_API  void  ggml_opt_forward_backward (ggml_opt_context_t  opt_ctx , ggml_opt_result_t  result );
169+     // do forward pass, increment result if not NULL, do backward pass if allocated 
170+     GGML_API  void  ggml_opt_eval (ggml_opt_context_t  opt_ctx , ggml_opt_result_t  result );
152171
153172    // ############################################################################ 
154173    // ## The high-level functions start here. They do not depend on any private ## 
@@ -200,9 +219,9 @@ extern "C" {
200219    // fit model defined by inputs and outputs to dataset 
201220    GGML_API  void  ggml_opt_fit (
202221            ggml_backend_sched_t             backend_sched ,  // backend scheduler for constructing the compute graphs 
203-             ggml_context                    *  ctx_compute ,    // context with temporarily allocated tensors to calculate the outputs 
204-             ggml_tensor                     *  inputs ,         // input tensor with shape [ne_datapoint, ndata_batch] 
205-             ggml_tensor                     *  outputs ,        // output tensor, must have shape [ne_label, ndata_batch] if labels are used 
222+             struct   ggml_context            *  ctx_compute ,    // context with temporarily allocated tensors to calculate the outputs 
223+             struct   ggml_tensor             *  inputs ,         // input tensor with shape [ne_datapoint, ndata_batch] 
224+             struct   ggml_tensor             *  outputs ,        // output tensor, must have shape [ne_label, ndata_batch] if labels are used 
206225            ggml_opt_dataset_t               dataset ,        // dataset with data and optionally also labels 
207226            enum  ggml_opt_loss_type          loss_type ,      // loss to minimize 
208227            ggml_opt_get_optimizer_params    get_opt_pars ,   // callback to get optimizer params, userdata is pointer to epoch (of type int64_t) 
0 commit comments