@@ -181,14 +181,13 @@ static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_PER_CONTEXT()
 // default hparams (LLaMA 7B)
 struct llama_hparams {
     uint32_t n_vocab   = 32000;
-    uint32_t n_vocab_sp = 0;
+    uint32_t n_vocab_base = 32000;
     uint32_t n_ctx     = 512;   // this is provided as user input?
     uint32_t n_embd    = 4096;
     uint32_t n_mult    = 256;
     uint32_t n_head    = 32;
     uint32_t n_head_kv = 32;
     uint32_t n_layer   = 32;
-    uint32_t n_rot     = 64;
 
     // LLaMAv2
     // TODO: load from model data hparams
@@ -499,7 +498,6 @@ enum llama_file_version {
     LLAMA_FILE_VERSION_GGJT_V1, // added padding
     LLAMA_FILE_VERSION_GGJT_V2, // changed quantization format
     LLAMA_FILE_VERSION_GGJT_V3, // changed Q4 and Q8 quantization format
-    LLAMA_FILE_VERSION_GGJT_V4, // improved support for added/special tokens
 };
 
 struct llama_file_loader {
@@ -515,6 +513,7 @@ struct llama_file_loader {
         read_hparams();
         read_vocab();
         read_tensor_metadata(tensors_map);
+        set_vocab_sp();
     }
     void read_magic() {
         uint32_t magic = file.read_u32();
@@ -537,7 +536,6 @@ struct llama_file_loader {
             case 1: file_version = LLAMA_FILE_VERSION_GGJT_V1; return;
             case 2: file_version = LLAMA_FILE_VERSION_GGJT_V2; return;
             case 3: file_version = LLAMA_FILE_VERSION_GGJT_V3; return;
-            case 4: file_version = LLAMA_FILE_VERSION_GGJT_V4; return;
         }
     }
 
@@ -546,18 +544,18 @@ struct llama_file_loader {
     }
     void read_hparams() {
         hparams.n_vocab = file.read_u32();
-        hparams.n_vocab_sp = file_version >= LLAMA_FILE_VERSION_GGJT_V4 ? file.read_u32() : 0;
         hparams.n_embd  = file.read_u32();
         hparams.n_mult  = file.read_u32();
         hparams.n_head  = file.read_u32();
         hparams.n_layer = file.read_u32();
-        hparams.n_rot   = file.read_u32();
+        hparams.n_vocab_base = file.read_u32();
+        hparams.n_vocab_base = (hparams.n_vocab_base & 0xF0000000) == 0 ? hparams.n_vocab : (hparams.n_vocab_base & ~0xF0000000); // this bitwise operation is necessary for compatibility with older models
         hparams.ftype = (enum llama_ftype) file.read_u32();
 
         // LLaMAv2
         // TODO: read from header
         hparams.n_head_kv = hparams.n_head;
     }
     void read_vocab() {
         vocab.id_to_token.resize(hparams.n_vocab);
 
@@ -574,20 +572,6 @@ struct llama_file_loader {
             tok_score.tok = std::move(word);
             tok_score.score = score;
         }
-
-        vocab.special_token_to_id.reserve(hparams.n_vocab_sp);
-
-        for (uint32_t i = 0; i < hparams.n_vocab_sp; i++) {
-            llama_vocab::id token_id = file.read_u32();
-            const auto & word = vocab.id_to_token[token_id].tok;
-
-            vocab.special_token_trie.add(word);
-            vocab.special_token_to_id[word] = token_id;
-
-            if (vocab.max_special_token_length < word.size()) {
-                vocab.max_special_token_length = word.size();
-            }
-        }
     }
     void read_tensor_metadata(llama_load_tensors_map & tensors_map) {
         while (file.tell() < file.size) {
@@ -634,6 +618,24 @@ struct llama_file_loader {
             tensors_map.name_to_idx[name] = tensors_map.tensors.size() - 1;
         }
     }
+    void set_vocab_sp() {
+        uint32_t vocab_sp = 3 + hparams.n_vocab - hparams.n_vocab_base;
+        vocab.special_token_to_id.reserve(vocab_sp);
+        for (uint32_t i = 0; i < vocab_sp; i++) {
+            llama_vocab::id token_id = i > 2 ? hparams.n_vocab_base + i : i;
+            const auto & word = vocab.id_to_token[token_id].tok;
+            if (word.empty()) {
+                continue;
+            }
+
+            vocab.special_token_trie.add(word);
+            vocab.special_token_to_id[word] = token_id;
+
+            if (vocab.max_special_token_length < word.size()) {
+                vocab.max_special_token_length = word.size();
+            }
+        }
+    }
 };
 
 struct llama_file_saver {
@@ -653,12 +655,11 @@ struct llama_file_saver {
     void write_hparams(enum llama_ftype new_ftype) {
         const llama_hparams & hparams = any_file_loader->hparams;
         file.write_u32(hparams.n_vocab);
-        file.write_u32(hparams.n_vocab_sp);
         file.write_u32(hparams.n_embd);
         file.write_u32(hparams.n_mult);
         file.write_u32(hparams.n_head);
         file.write_u32(hparams.n_layer);
-        file.write_u32(hparams.n_rot);
+        file.write_u32(hparams.n_vocab_base | 0xF0000000); // this bitwise operation is necessary for compatibility with older models
         file.write_u32(new_ftype);
     }
     void write_vocab() {
@@ -672,9 +673,6 @@ struct llama_file_saver {
             file.write_raw(token_score.tok.data(), token_score.tok.size());
             file.write_raw(&token_score.score, sizeof(token_score.score));
         }
-        for (const auto & pair : any_file_loader->vocab.special_token_to_id) {
-            file.write_u32(pair.second);
-        }
     }
     void write_tensor(llama_load_tensor & tensor, enum ggml_type new_type, const void * new_data, size_t new_size) {
         switch (new_type) {
@@ -1001,8 +999,7 @@ static const char *llama_file_version_name(llama_file_version version) {
         case LLAMA_FILE_VERSION_GGMF_V1: return "ggmf v1 (old version with no mmap support)";
         case LLAMA_FILE_VERSION_GGJT_V1: return "ggjt v1 (pre #1405)";
         case LLAMA_FILE_VERSION_GGJT_V2: return "ggjt v2 (pre #1508)";
-        case LLAMA_FILE_VERSION_GGJT_V3: return "ggjt v3 (pre #1931)";
-        case LLAMA_FILE_VERSION_GGJT_V4: return "ggjt v4 (latest)";
+        case LLAMA_FILE_VERSION_GGJT_V3: return "ggjt v3 (latest)";
     }
 
     return "unknown";
@@ -1127,7 +1124,7 @@ static void llama_model_load_internal(
         fprintf(stderr, "%s: n_head     = %u\n", __func__, hparams.n_head);
         fprintf(stderr, "%s: n_head_kv  = %u\n", __func__, hparams.n_head_kv);
         fprintf(stderr, "%s: n_layer    = %u\n", __func__, hparams.n_layer);
-        fprintf(stderr, "%s: n_rot      = %u\n", __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim
+        fprintf(stderr, "%s: n_rot      = %u\n", __func__, hparams.n_embd/hparams.n_head); // a.k.a. n_embd_head, n_head_dim
         fprintf(stderr, "%s: n_gqa      = %u\n", __func__, hparams.n_gqa());
         fprintf(stderr, "%s: rnorm_eps  = %.1e\n", __func__, hparams.f_rms_norm_eps);
         fprintf(stderr, "%s: n_ff       = %u\n", __func__, n_ff);
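
Note on the header change above: the u32 slot that older GGJT files used for n_rot is reused here to carry n_vocab_base, tagged with the high nibble 0xF0000000 so read_hparams() can tell the two layouts apart. The standalone sketch below is not part of the diff; it only illustrates that round trip under the logic shown in write_hparams()/read_hparams(), and the helper names encode_vocab_base/decode_vocab_base are hypothetical.

#include <cstdint>
#include <cstdio>

// Mirrors write_hparams() in the diff: tag n_vocab_base so new loaders recognize it.
static uint32_t encode_vocab_base(uint32_t n_vocab_base) {
    return n_vocab_base | 0xF0000000;
}

// Mirrors read_hparams(): an untagged value means an old file (the slot held n_rot),
// so fall back to n_vocab; otherwise strip the tag to recover n_vocab_base.
static uint32_t decode_vocab_base(uint32_t stored, uint32_t n_vocab) {
    if ((stored & 0xF0000000) == 0) {
        return n_vocab;
    }
    return stored & ~0xF0000000;
}

int main() {
    const uint32_t n_vocab      = 32016; // example: base vocab plus 16 added tokens
    const uint32_t n_vocab_base = 32000;

    uint32_t new_slot = encode_vocab_base(n_vocab_base);
    uint32_t old_slot = 64;              // what an old file stores in this slot (n_rot)

    printf("new file -> n_vocab_base = %u\n", (unsigned) decode_vocab_base(new_slot, n_vocab));
    printf("old file -> n_vocab_base = %u\n", (unsigned) decode_vocab_base(old_slot, n_vocab));
    return 0;
}

For an old file the decoded value equals n_vocab, so set_vocab_sp() in the diff would only consider token ids 0-2 (presumably the built-in unk/bos/eos entries) and skip any whose text is empty, rather than reading a special-token table from the file as the removed GGJT v4 path did.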