@@ -908,6 +908,40 @@ def _set_vocab_llama_hf(self):
         special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
         special_vocab.add_to_gguf(self.gguf_writer)

+    def _set_vocab_rwkv_world(self):
+        assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()
+        vocab_size = self.hparams.get("vocab_size", 65536)
+
+        tokens: list[bytes] = ['<s>'.encode("utf-8")]
+        toktypes: list[int] = [gguf.TokenType.CONTROL]
+
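+        # each vocab line is "<id> <token as a Python str/bytes literal> <byte length>", parsed below with ast.literal_eval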
+        with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f:
+            lines = f.readlines()
+            for line in lines:
+                parts = line.split(' ')
+                assert len(parts) >= 3
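+                # the token literal itself may contain spaces, so re-join everything between the id and the length field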
+                token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1])
+                token = token.encode("utf-8") if isinstance(token, str) else token
+                assert isinstance(token, bytes)
+                assert len(token) == token_len
+                token_text: str = repr(token)[2:-1]  # "b'\xff'" -> "\xff"
+                tokens.append(token_text.encode("utf-8"))
+                toktypes.append(gguf.TokenType.NORMAL)
+        remainder = vocab_size - len(tokens)
+        assert remainder >= 0
+        for i in range(len(tokens), vocab_size):
+            tokens.append(f"[PAD{i}]".encode("utf-8"))
+            toktypes.append(gguf.TokenType.UNUSED)
+
+        self.gguf_writer.add_tokenizer_model("rwkv")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
+        special_vocab.chat_template = "rwkv-world"
+        # hack: Add '\n\n' as the EOT token to make it chat normally
+        special_vocab._set_special_token("eot", 261)
+        special_vocab.add_to_gguf(self.gguf_writer)
+
     def _set_vocab_builtin(self, model_name: Literal["gpt-neox", "llama-spm"], vocab_size: int):
         tokenizer_path = Path(sys.path[0]) / "models" / f"ggml-vocab-{model_name}.gguf"
         logger.warning(f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'")
@@ -3412,38 +3446,7 @@ class Rwkv6Model(Model):
     model_arch = gguf.MODEL_ARCH.RWKV6

     def set_vocab(self):
-        assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()
-        vocab_size = self.hparams.get("vocab_size", 65536)
-
-        tokens: list[bytes] = ['<s>'.encode("utf-8")]
-        toktypes: list[int] = [gguf.TokenType.CONTROL]
-
-        with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f:
-            lines = f.readlines()
-            for line in lines:
-                parts = line.split(' ')
-                assert len(parts) >= 3
-                token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1])
-                token = token.encode("utf-8") if isinstance(token, str) else token
-                assert isinstance(token, bytes)
-                assert len(token) == token_len
-                token_text: str = repr(token)[2:-1]  # "b'\xff'" -> "\xff"
-                tokens.append(token_text.encode("utf-8"))
-                toktypes.append(gguf.TokenType.NORMAL)
-        remainder = vocab_size - len(tokens)
-        assert remainder >= 0
-        for i in range(len(tokens), vocab_size):
-            tokens.append(f"[PAD{i}]".encode("utf-8"))
-            toktypes.append(gguf.TokenType.UNUSED)
-
-        self.gguf_writer.add_tokenizer_model("rwkv")
-        self.gguf_writer.add_token_list(tokens)
-        self.gguf_writer.add_token_types(toktypes)
-        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
-        special_vocab.chat_template = "rwkv-world"
-        # hack: Add '\n\n' as the EOT token to make it chat normally
-        special_vocab._set_special_token("eot", 261)
-        special_vocab.add_to_gguf(self.gguf_writer)
+        self._set_vocab_rwkv_world()

     def set_gguf_parameters(self):
         block_count = self.hparams["num_hidden_layers"]
@@ -3565,6 +3568,168 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
             yield (new_name, data)


+@Model.register("Rwkv7ForCausalLM", "RWKV7ForCausalLM")
+class Rwkv7Model(Model):
+    model_arch = gguf.MODEL_ARCH.RWKV7
+
+    def set_vocab(self):
+        self._set_vocab_rwkv_world()
+
+    def calc_lora_rank(self, hidden_size, exponent, multiplier):
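+        # scales hidden_size ** exponent by multiplier, then rounds to the nearest multiple of 32 (minimum 32)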
+        return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32
+
+    def set_gguf_parameters(self):
+        block_count = self.hparams["num_hidden_layers"]
+        try:
+            head_size = self.hparams["head_size"]
+            layer_norm_eps = self.hparams["layer_norm_epsilon"]
+        except KeyError:
+            head_size = self.hparams["head_dim"]
+            layer_norm_eps = self.hparams["norm_eps"]
+        hidden_size = self.hparams["hidden_size"]
+        intermediate_size = self.hparams["intermediate_size"] if self.hparams["intermediate_size"] is not None else (hidden_size * 4)
+
+        # ICLR: In-Context-Learning-Rate
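+        # upstream RWKV-7 configs and fla-hub configs name the LoRA ranks differently; fall back to computed defaults when a rank is null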
+        try:
+            lora_rank_decay = self.hparams["lora_rank_decay"] if self.hparams["lora_rank_decay"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
+            lora_rank_iclr = self.hparams["lora_rank_iclr"] if self.hparams["lora_rank_iclr"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
+            lora_rank_value_residual_mix = self.hparams["lora_rank_value_residual_mix"] if self.hparams["lora_rank_value_residual_mix"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.3)
+            lora_rank_gate = self.hparams["lora_rank_gate"] if self.hparams["lora_rank_gate"] is not None else self.calc_lora_rank(hidden_size, 0.8, 0.6)
+        except KeyError:
+            lora_rank_decay = self.hparams["decay_low_rank_dim"] if self.hparams["decay_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
+            lora_rank_iclr = self.hparams["a_low_rank_dim"] if self.hparams["a_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
+            lora_rank_value_residual_mix = self.hparams["v_low_rank_dim"] if self.hparams["v_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.3)
+            lora_rank_gate = self.hparams["gate_low_rank_dim"] if self.hparams["gate_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.8, 0.6)
+
+        # RWKV isn't context limited
+        self.gguf_writer.add_context_length(1048576)
+        self.gguf_writer.add_embedding_length(hidden_size)
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
+        self.gguf_writer.add_wkv_head_size(head_size)
+        self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
+        self.gguf_writer.add_iclr_lora_rank(lora_rank_iclr)
+        self.gguf_writer.add_value_residual_mix_lora_rank(lora_rank_value_residual_mix)
+        self.gguf_writer.add_gate_lora_rank(lora_rank_gate)
+        self.gguf_writer.add_feed_forward_length(intermediate_size)
+        self.gguf_writer.add_file_type(self.ftype)
+
+        # required by llama.cpp, unused
+        self.gguf_writer.add_head_count(0)
+
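+    # buffers the per-layer x_{r,w,k,v,a,g} token-shift tensors until the whole set has been seen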
+    lerp_weights: dict[int, dict[str, Tensor]] = {}
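+    # whether the LoRA weights still need transposing; fla-hub checkpoints (detected below) already store them in the expected orientation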
+    lora_needs_transpose: bool = True
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # unify tensor names here to make life easier
+        name = name.replace("blocks", "layers").replace("ffn", "feed_forward")
+        name = name.replace("self_attn", "attention").replace("attn", "attention")
+        name = name.replace("time_mixer.", "")
+        # lora layer names in fla-hub's impl
+        if "_lora.lora" in name:
+            self.lora_needs_transpose = False
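+        # fla-hub appears to wrap each LoRA in an nn.Sequential: entries 0 and 2 are the two projections, and the trailing bias becomes a standalone "0" tensor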
+        name = name.replace("_lora.lora.0.weight", "1.weight")
+        name = name.replace("_lora.lora.2.weight", "2.weight")
+        name = name.replace("_lora.lora.2.bias", "0.weight")
+
+        name = name.replace("feed_forward_norm", "ln2")
+        name = name.replace("g_norm", "ln_x")
+
+        if "attention.v" in name and "value" not in self.map_tensor_name(name) and bid == 0:
+            # some models have dummy v0/v1/v2 on first layer while others don't
+            # ignore them all since they are not used
+            return
+
+        wkv_has_gate = self.hparams.get("wkv_has_gate", True)
+        lerp_list = ["r", "w", "k", "v", "a", "g"] if wkv_has_gate else ["r", "w", "k", "v", "a"]
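+        # token-shift (lerp) coefficients; the gate coefficient is only present when wkv_has_gate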
+
+        if bid is not None and "attention.x_" in name:
+            if "attention.x_x" in name:
+                # already concatenated
+                new_name = f"blk.{bid}.time_mix_lerp_fused.weight"
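+                # one row per lerp coefficient: (n_lerp, 1, 1, n_embd)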
+                data = data_torch.reshape(len(lerp_list), 1, 1, -1)
+                yield (new_name, data)
+            else:
+                try:
+                    self.lerp_weights[bid][name] = data_torch
+                except KeyError:
+                    self.lerp_weights[bid] = {name: data_torch}
+                if all(f"model.layers.{bid}.attention.x_{i}" in self.lerp_weights[bid].keys() for i in lerp_list):
+                    new_name = f"blk.{bid}.time_mix_lerp_fused.weight"
+                    data = torch.stack([self.lerp_weights[bid][f"model.layers.{bid}.attention.x_{i}"] for i in lerp_list], dim=0)
+                    yield (new_name, data)
+            return
+        else:
+            data_torch = data_torch.squeeze()
+            new_name = self.map_tensor_name(name)
+
+            if not (new_name.endswith(".weight") or new_name.endswith(".bias")):
+                new_name += ".weight"
+
+            if self.lora_needs_transpose and any(
+                new_name.endswith(t) for t in [
+                    "time_mix_w1.weight", "time_mix_w2.weight",
+                    "time_mix_a1.weight", "time_mix_a2.weight",
+                    "time_mix_v1.weight", "time_mix_v2.weight",
+                    "time_mix_g1.weight", "time_mix_g2.weight",
+                ]
+            ):
+                data_torch = data_torch.transpose(0, 1)
+
+            if 'r_k' in new_name:
+                data_torch = data_torch.flatten()
+
+            if bid == 0 and "time_mix_a" in new_name:
+                # dummy v0/v1/v2 on first layer
+                # easiest way to make llama happy
+                yield (new_name.replace("time_mix_a", "time_mix_v"), data_torch)
+
+            yield (new_name, data_torch)
+
+
+@Model.register("RwkvHybridForCausalLM")
+class ARwkv7Model(Rwkv7Model):
+    model_arch = gguf.MODEL_ARCH.ARWKV7
+
+    def set_vocab(self):
+        try:
+            self._set_vocab_sentencepiece()
+        except FileNotFoundError:
+            self._set_vocab_gpt2()
+
+    def set_gguf_parameters(self):
+        block_count = self.hparams["num_hidden_layers"]
+        hidden_size = self.hparams["hidden_size"]
+        head_size = self.hparams["head_size"]
+        rms_norm_eps = self.hparams["rms_norm_eps"]
+        intermediate_size = self.hparams["intermediate_size"]
+        wkv_has_gate = self.hparams["wkv_has_gate"]
+        assert self.hparams["wkv_version"] == 7
+
+        # ICLR: In-Context-Learning-Rate
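+        # these ranks are hard-coded rather than read from the config, which apparently does not expose them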
+        lora_rank_decay = 64
+        lora_rank_iclr = 64
+        lora_rank_value_residual_mix = 32
+        lora_rank_gate = 128 if wkv_has_gate else 0
+
+        # RWKV isn't context limited
+        self.gguf_writer.add_context_length(1048576)
+        self.gguf_writer.add_embedding_length(hidden_size)
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
+        self.gguf_writer.add_wkv_head_size(head_size)
+        self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
+        self.gguf_writer.add_iclr_lora_rank(lora_rank_iclr)
+        self.gguf_writer.add_value_residual_mix_lora_rank(lora_rank_value_residual_mix)
+        self.gguf_writer.add_gate_lora_rank(lora_rank_gate)
+        self.gguf_writer.add_feed_forward_length(intermediate_size)
+        self.gguf_writer.add_file_type(self.ftype)
+        self.gguf_writer.add_token_shift_count(1)
+
+        # required by llama.cpp, unused
+        self.gguf_writer.add_head_count(0)
+
+
 @Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
 class MambaModel(Model):
     model_arch = gguf.MODEL_ARCH.MAMBA