@@ -86,7 +86,7 @@ struct llama_model {
 };
 
 // load the model's weights from a file
-bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx) {
+bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx, ggml_type memory_type = GGML_TYPE_F32) {
     fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
 
     std::vector<char> f_buf(1024*1024);
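The new `memory_type` parameter defaults to `GGML_TYPE_F32`, so every existing call of `llama_model_load` stays source-compatible and keeps the old full-precision behavior; only callers that opt in pass the extra argument. A minimal caller sketch, assuming the declarations from this file and `utils.h` (the model path is a placeholder):

```cpp
// Usage sketch only: llama_model, gpt_vocab, and llama_model_load are the
// declarations from this file/utils.h; the model path is a placeholder.
llama_model model;
gpt_vocab   vocab;

// Request an F16 KV cache; omitting the last argument falls back to F32.
if (!llama_model_load("models/7B/ggml-model-q4_0.bin", model, vocab,
                      /*n_ctx=*/512, GGML_TYPE_F16)) {
    fprintf(stderr, "load failed\n");
}
```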
@@ -207,8 +207,8 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
         ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w2
         ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w3
 
-        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F16); // memory_k
-        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F16); // memory_v
+        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(memory_type); // memory_k
+        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(memory_type); // memory_v
 
         ctx_size += (5 + 10*n_layer)*256; // object overhead
 
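`memory_k` and `memory_v` each hold `n_ctx*n_layer*n_embd` elements, and `ggml_type_sizef` returns the per-element size in bytes, so switching the cache type from F32 to F16 halves this part of `ctx_size`. A quick back-of-the-envelope check, assuming the 7B hyperparameters (`n_layer = 32`, `n_embd = 4096`) and a 512-token context:

```cpp
#include <cstdint>
#include <cstdio>

// Rough KV-cache accounting; uses the plain 4-/2-byte widths of F32/F16
// in place of real ggml_type_sizef() calls.
int main() {
    const int64_t n_ctx = 512, n_layer = 32, n_embd = 4096;
    const int64_t n_elements = n_ctx*n_layer*n_embd;     // per cache tensor
    const double  f32_mb = 2*n_elements*4.0/(1024*1024); // memory_k + memory_v
    const double  f16_mb = 2*n_elements*2.0/(1024*1024);
    printf("KV cache: %.0f MB as F32, %.0f MB as F16\n", f32_mb, f16_mb);
    // prints: KV cache: 512 MB as F32, 256 MB as F16
}
```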
@@ -293,8 +293,8 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
         const int n_mem      = n_layer*n_ctx;
         const int n_elements = n_embd*n_mem;
 
-        model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
-        model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+        model.memory_k = ggml_new_tensor_1d(ctx, memory_type, n_elements);
+        model.memory_v = ggml_new_tensor_1d(ctx, memory_type, n_elements);
 
         const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
 
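The `memory_size` computed from `ggml_nbytes` needs no separate change, since `ggml_nbytes` already follows the tensor's element type; for 1-D tensors like these it is essentially the element count times the per-element size. A hedged sketch of that relationship (assumes the `ggml.h` from this repo and enough context memory):

```cpp
#include <cassert>
#include "ggml.h"

// Sketch only: shows that an F16 tensor reports exactly half the bytes of an
// F32 tensor with the same element count, which is what shrinks memory_size.
int main() {
    struct ggml_init_params ip = { /*mem_size=*/16*1024*1024, /*mem_buffer=*/NULL };
    struct ggml_context * ctx = ggml_init(ip);

    const int n_elements = 1024;
    struct ggml_tensor * k32 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
    struct ggml_tensor * k16 = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);

    assert(ggml_nbytes(k16)*2 == ggml_nbytes(k32));

    ggml_free(ctx);
}
```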
@@ -814,8 +814,9 @@ int main(int argc, char ** argv) {
 
     // load the model
     {
+        const ggml_type memory_type = params.memory_f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
         const int64_t t_start_us = ggml_time_us();
-        if (!llama_model_load(params.model, model, vocab, params.n_ctx)) {
+        if (!llama_model_load(params.model, model, vocab, params.n_ctx, memory_type)) {
             fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
             return 1;
         }
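The `params.memory_f16` field implies matching additions to `gpt_params` and the argument parser that this excerpt does not show. A hypothetical sketch of that plumbing, following the existing `gpt_params_parse` conventions (the `--memory_f16` flag name and the helper are assumptions):

```cpp
#include <string>

// Hypothetical plumbing sketch: only params.memory_f16 is confirmed by the
// hunk above; the flag string and parser shape are assumptions.
struct gpt_params {
    // ... existing fields (model, n_ctx, ...) ...
    bool memory_f16 = false; // store the KV cache tensors as F16 instead of F32
};

static void parse_memory_flag(const std::string & arg, gpt_params & params) {
    if (arg == "--memory_f16") {
        params.memory_f16 = true; // halves the KV cache footprint
    }
}
```

With the default left at `false`, behavior is unchanged unless the user explicitly asks for the half-size cache.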