@@ -128,13 +128,12 @@ static const std::map<e_model, size_t> & MEM_REQ_EVAL()
 // default hparams (LLaMA 7B)
 struct llama_hparams {
     uint32_t n_vocab = 32000;
-    uint32_t n_vocab_sp = 0;
+    uint32_t n_vocab_base = 32000;
     uint32_t n_ctx   = 512;   // this is provided as user input?
     uint32_t n_embd  = 4096;
     uint32_t n_mult  = 256;
     uint32_t n_head  = 32;
     uint32_t n_layer = 32;
-    uint32_t n_rot   = 64;
     enum llama_ftype ftype = LLAMA_FTYPE_MOSTLY_F16;
 
     bool operator!=(const llama_hparams & other) const {
@@ -460,7 +459,6 @@ enum llama_file_version {
     LLAMA_FILE_VERSION_GGJT_V1, // added padding
     LLAMA_FILE_VERSION_GGJT_V2, // changed quantization format
     LLAMA_FILE_VERSION_GGJT_V3, // changed Q4 and Q8 quantization format
-    LLAMA_FILE_VERSION_GGJT_V4, // improved support for added/special tokens
 };
 
 struct llama_file_loader {
@@ -476,6 +474,7 @@ struct llama_file_loader {
         read_hparams();
         read_vocab();
         read_tensor_metadata(file_idx, tensors_map);
+        set_vocab_sp();
     }
     void read_magic() {
         uint32_t magic = file.read_u32();
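Note that set_vocab_sp() is called last in the constructor: it looks token pieces up in vocab.id_to_token, so read_vocab() must have populated that table first.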
@@ -498,7 +497,6 @@ struct llama_file_loader {
                    case 1: file_version = LLAMA_FILE_VERSION_GGJT_V1; return;
                    case 2: file_version = LLAMA_FILE_VERSION_GGJT_V2; return;
                    case 3: file_version = LLAMA_FILE_VERSION_GGJT_V3; return;
-                    case 4: file_version = LLAMA_FILE_VERSION_GGJT_V4; return;
                }
        }
 
@@ -507,12 +505,12 @@ struct llama_file_loader {
     }
     void read_hparams() {
         hparams.n_vocab = file.read_u32();
-        hparams.n_vocab_sp = file_version >= LLAMA_FILE_VERSION_GGJT_V4 ? file.read_u32() : 0;
         hparams.n_embd  = file.read_u32();
         hparams.n_mult  = file.read_u32();
         hparams.n_head  = file.read_u32();
         hparams.n_layer = file.read_u32();
-        hparams.n_rot   = file.read_u32();
+        hparams.n_vocab_base = file.read_u32();
+        hparams.n_vocab_base = (hparams.n_vocab_base & 0xF0000000) == 0 ? hparams.n_vocab : (hparams.n_vocab_base & ~0xF0000000); // this bitwise operation is necessary for compatibility with older models
         hparams.ftype = (enum llama_ftype) file.read_u32();
     }
     void read_vocab() {
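The header word that used to carry n_rot now carries n_vocab_base, and the loader has to tell the two apart. A minimal standalone sketch of that rule (decode_vocab_base is a hypothetical helper, not part of the patch): legacy files wrote n_rot here, a small value whose high nibble is always clear, while new files write n_vocab_base tagged with 0xF0000000.

    #include <cstdint>

    // Hypothetical helper mirroring the decode in read_hparams(): a clear high
    // nibble means the word is a legacy n_rot value; such a file has no
    // appended tokens, so the base vocabulary equals the full vocabulary.
    static uint32_t decode_vocab_base(uint32_t word, uint32_t n_vocab) {
        if ((word & 0xF0000000u) == 0) {
            return n_vocab;             // legacy file: the word held n_rot
        }
        return word & ~0xF0000000u;     // new file: strip the tag
    }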
@@ -533,20 +531,6 @@ struct llama_file_loader {
             tok_score.tok = std::move(word);
             tok_score.score = score;
         }
-
-        vocab.special_token_to_id.reserve(hparams.n_vocab_sp);
-
-        for (uint32_t i = 0; i < hparams.n_vocab_sp; i++) {
-            llama_vocab::id token_id = file.read_u32();
-            const auto & word = vocab.id_to_token[token_id].tok;
-
-            vocab.special_token_trie.add(word);
-            vocab.special_token_to_id[word] = token_id;
-
-            if (vocab.max_special_token_length < word.size()) {
-                vocab.max_special_token_length = word.size();
-            }
-        }
     }
     void read_tensor_metadata(size_t file_idx, llama_load_tensors_map & tensors_map) {
         while (file.tell() < file.size) {
@@ -601,6 +585,24 @@ struct llama_file_loader {
             tensors_map.tensors.at(idx).shards.push_back(shard);
         }
     }
+    void set_vocab_sp() {
+        uint32_t vocab_sp = 3 + hparams.n_vocab - hparams.n_vocab_base;
+        vocab.special_token_to_id.reserve(vocab_sp);
+        for (uint32_t i = 0; i < vocab_sp; i++) {
+            llama_vocab::id token_id = i > 2 ? hparams.n_vocab_base + i - 3 : i;
+            const auto & word = vocab.id_to_token[token_id].tok;
+            if (word.empty()) {
+                continue;
+            }
+
+            vocab.special_token_trie.add(word);
+            vocab.special_token_to_id[word] = token_id;
+
+            if (vocab.max_special_token_length < word.size()) {
+                vocab.max_special_token_length = word.size();
+            }
+        }
+    }
 };
 
 struct llama_file_saver {
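set_vocab_sp() derives the special tokens from the vocabulary layout itself, replacing the serialized id list that GGJT v4 stored: the three control tokens at ids 0..2, plus every token appended after the base vocabulary. A runnable sketch of the id mapping, assuming hypothetical sizes (32000 base entries, four appended tokens):

    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t n_vocab_base = 32000;   // base vocab, includes ids 0..2
        const uint32_t n_vocab      = 32004;   // four appended special tokens
        const uint32_t vocab_sp     = 3 + n_vocab - n_vocab_base;
        for (uint32_t i = 0; i < vocab_sp; i++) {
            // same mapping as set_vocab_sp(): ids 0, 1, 2, then 32000..32003
            const uint32_t token_id = i > 2 ? n_vocab_base + i - 3 : i;
            printf("i=%u -> token_id=%u\n", i, token_id);
        }
        return 0;
    }

Entries whose piece is empty are skipped, so a plain base model (n_vocab == n_vocab_base) registers only the control tokens.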
@@ -620,12 +622,11 @@ struct llama_file_saver {
     void write_hparams(enum llama_ftype new_ftype) {
         const llama_hparams & hparams = any_file_loader->hparams;
         file.write_u32(hparams.n_vocab);
-        file.write_u32(hparams.n_vocab_sp);
         file.write_u32(hparams.n_embd);
         file.write_u32(hparams.n_mult);
         file.write_u32(hparams.n_head);
         file.write_u32(hparams.n_layer);
-        file.write_u32(hparams.n_rot);
+        file.write_u32(hparams.n_vocab_base | 0xF0000000); // this bitwise operation is necessary for compatibility with older models
         file.write_u32(new_ftype);
     }
     void write_vocab() {
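The writer always sets the tag, which is what keeps the reused slot unambiguous on the read side; the scheme works because no realistic vocabulary size approaches 2^28, so the top nibble is free. A round-trip check under that assumption:

    #include <cassert>
    #include <cstdint>

    int main() {
        const uint32_t n_vocab_base = 32000;
        const uint32_t stored = n_vocab_base | 0xF0000000u;  // as written by write_hparams()
        assert((stored & 0xF0000000u) != 0);                 // recognized as tagged, not legacy n_rot
        assert((stored & ~0xF0000000u) == n_vocab_base);     // value recovered exactly
        return 0;
    }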
@@ -639,9 +640,6 @@ struct llama_file_saver {
             file.write_raw(token_score.tok.data(), token_score.tok.size());
             file.write_raw(&token_score.score, sizeof(token_score.score));
         }
-        for (const auto & pair : any_file_loader->vocab.special_token_to_id) {
-            file.write_u32(pair.second);
-        }
     }
     void write_tensor(llama_load_tensor & tensor, enum ggml_type new_type, const void * new_data, size_t new_size) {
         switch (new_type) {
@@ -1015,8 +1013,7 @@ static const char *llama_file_version_name(llama_file_version version) {
         case LLAMA_FILE_VERSION_GGMF_V1: return "ggmf v1 (old version with no mmap support)";
         case LLAMA_FILE_VERSION_GGJT_V1: return "ggjt v1 (pre #1405)";
         case LLAMA_FILE_VERSION_GGJT_V2: return "ggjt v2 (pre #1508)";
-        case LLAMA_FILE_VERSION_GGJT_V3: return "ggjt v3 (pre #1931)";
-        case LLAMA_FILE_VERSION_GGJT_V4: return "ggjt v4 (latest)";
+        case LLAMA_FILE_VERSION_GGJT_V3: return "ggjt v3 (latest)";
     }
 
     return "unknown";
@@ -1113,7 +1110,7 @@ static void llama_model_load_internal(
         fprintf(stderr, "%s: n_mult     = %u\n",  __func__, hparams.n_mult);
         fprintf(stderr, "%s: n_head     = %u\n",  __func__, hparams.n_head);
         fprintf(stderr, "%s: n_layer    = %u\n",  __func__, hparams.n_layer);
-        fprintf(stderr, "%s: n_rot      = %u\n",  __func__, hparams.n_rot);
+        fprintf(stderr, "%s: n_rot      = %u\n",  __func__, hparams.n_embd/hparams.n_head);
         fprintf(stderr, "%s: ftype      = %u (%s)\n", __func__, hparams.ftype, llama_ftype_name(hparams.ftype));
         fprintf(stderr, "%s: n_ff       = %u\n",  __func__, n_ff);
         fprintf(stderr, "%s: n_parts    = %zu\n", __func__, file_loaders.size());
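Since n_rot is no longer stored in the header, the log line reports it as a derived quantity; for the default 7B hparams that is 4096 / 32 = 128. A hypothetical helper spelling out the derivation:

    // n_rot is recomputed from fields that are still stored in the header
    static inline uint32_t llama_n_rot(uint32_t n_embd, uint32_t n_head) {
        return n_embd / n_head;   // default 7B hparams: 4096 / 32 = 128
    }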