@@ -908,6 +908,40 @@ def _set_vocab_llama_hf(self):
         special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
         special_vocab.add_to_gguf(self.gguf_writer)

+    def _set_vocab_rwkv_world(self):
+        assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()
+        vocab_size = self.hparams.get("vocab_size", 65536)
+
+        tokens: list[bytes] = ['<s>'.encode("utf-8")]
+        toktypes: list[int] = [gguf.TokenType.CONTROL]
+
+        with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f:
+            lines = f.readlines()
+            for line in lines:
+                parts = line.split(' ')
+                assert len(parts) >= 3
+                token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1])
+                token = token.encode("utf-8") if isinstance(token, str) else token
+                assert isinstance(token, bytes)
+                assert len(token) == token_len
+                token_text: str = repr(token)[2:-1]  # "b'\xff'" -> "\xff"
+                tokens.append(token_text.encode("utf-8"))
+                toktypes.append(gguf.TokenType.NORMAL)
+        remainder = vocab_size - len(tokens)
+        assert remainder >= 0
+        for i in range(len(tokens), vocab_size):
+            tokens.append(f"[PAD{i}]".encode("utf-8"))
+            toktypes.append(gguf.TokenType.UNUSED)
+
+        self.gguf_writer.add_tokenizer_model("rwkv")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
+        special_vocab.chat_template = "rwkv-world"
+        # hack: Add '\n\n' as the EOT token to make it chat normally
+        special_vocab._set_special_token("eot", 261)
+        special_vocab.add_to_gguf(self.gguf_writer)
+
     def _set_vocab_builtin(self, model_name: Literal["gpt-neox", "llama-spm"], vocab_size: int):
         tokenizer_path = Path(sys.path[0]) / "models" / f"ggml-vocab-{model_name}.gguf"
         logger.warning(f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'")
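For reference, _set_vocab_rwkv_world expects one vocab entry per line of rwkv_vocab_v20230424.txt in the form "<id> <python-literal> <byte-length>" and recovers the raw token with ast.literal_eval. A minimal standalone sketch of that parsing; the sample lines below are made up for illustration, not taken from the real file:

import ast

# hypothetical sample lines in the rwkv_vocab_v20230424.txt format: <id> <literal> <byte length>
sample_lines = [
    "1 '\\x00' 1",
    "257 b'\\xe4\\xb8\\xad' 3",
    "300 'hello world' 11",
]
for line in sample_lines:
    parts = line.split(' ')
    # the literal itself may contain spaces, so rejoin everything between the id and the length
    token = ast.literal_eval(' '.join(parts[1:-1]))
    token_len = int(parts[-1])
    token = token.encode("utf-8") if isinstance(token, str) else token
    assert len(token) == token_len
    # repr(b'\xff') == "b'\\xff'"; strip the b'...' wrapper to keep the escaped text form
    print(repr(token)[2:-1])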
@@ -3412,38 +3446,7 @@ class Rwkv6Model(Model):
     model_arch = gguf.MODEL_ARCH.RWKV6

     def set_vocab(self):
-        assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()
-        vocab_size = self.hparams.get("vocab_size", 65536)
-
-        tokens: list[bytes] = ['<s>'.encode("utf-8")]
-        toktypes: list[int] = [gguf.TokenType.CONTROL]
-
-        with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f:
-            lines = f.readlines()
-            for line in lines:
-                parts = line.split(' ')
-                assert len(parts) >= 3
-                token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1])
-                token = token.encode("utf-8") if isinstance(token, str) else token
-                assert isinstance(token, bytes)
-                assert len(token) == token_len
-                token_text: str = repr(token)[2:-1]  # "b'\xff'" -> "\xff"
-                tokens.append(token_text.encode("utf-8"))
-                toktypes.append(gguf.TokenType.NORMAL)
-        remainder = vocab_size - len(tokens)
-        assert remainder >= 0
-        for i in range(len(tokens), vocab_size):
-            tokens.append(f"[PAD{i}]".encode("utf-8"))
-            toktypes.append(gguf.TokenType.UNUSED)
-
-        self.gguf_writer.add_tokenizer_model("rwkv")
-        self.gguf_writer.add_token_list(tokens)
-        self.gguf_writer.add_token_types(toktypes)
-        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
-        special_vocab.chat_template = "rwkv-world"
-        # hack: Add '\n\n' as the EOT token to make it chat normally
-        special_vocab._set_special_token("eot", 261)
-        special_vocab.add_to_gguf(self.gguf_writer)
+        self._set_vocab_rwkv_world()

     def set_gguf_parameters(self):
         block_count = self.hparams["num_hidden_layers"]
@@ -3565,6 +3568,168 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
             yield (new_name, data)


+@Model.register("Rwkv7ForCausalLM", "RWKV7ForCausalLM")
+class Rwkv7Model(Model):
+    model_arch = gguf.MODEL_ARCH.RWKV7
+
+    def set_vocab(self):
+        self._set_vocab_rwkv_world()
+
+    def calc_lora_rank(self, hidden_size, exponent, multiplier):
+        return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32
+
+    def set_gguf_parameters(self):
+        block_count = self.hparams["num_hidden_layers"]
+        try:
+            head_size = self.hparams["head_size"]
+            layer_norm_eps = self.hparams["layer_norm_epsilon"]
+        except KeyError:
+            head_size = self.hparams["head_dim"]
+            layer_norm_eps = self.hparams["norm_eps"]
+        hidden_size = self.hparams["hidden_size"]
+        intermediate_size = self.hparams["intermediate_size"] if self.hparams["intermediate_size"] is not None else (hidden_size * 4)
+
+        # ICLR: In-Context-Learning-Rate
+        try:
+            lora_rank_decay = self.hparams["lora_rank_decay"] if self.hparams["lora_rank_decay"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
+            lora_rank_iclr = self.hparams["lora_rank_iclr"] if self.hparams["lora_rank_iclr"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
+            lora_rank_value_residual_mix = self.hparams["lora_rank_value_residual_mix"] if self.hparams["lora_rank_value_residual_mix"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.3)
+            lora_rank_gate = self.hparams["lora_rank_gate"] if self.hparams["lora_rank_gate"] is not None else self.calc_lora_rank(hidden_size, 0.8, 0.6)
+        except KeyError:
+            lora_rank_decay = self.hparams["decay_low_rank_dim"] if self.hparams["decay_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
+            lora_rank_iclr = self.hparams["a_low_rank_dim"] if self.hparams["a_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
+            lora_rank_value_residual_mix = self.hparams["v_low_rank_dim"] if self.hparams["v_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.3)
+            lora_rank_gate = self.hparams["gate_low_rank_dim"] if self.hparams["gate_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.8, 0.6)
+
+        # RWKV isn't context limited
+        self.gguf_writer.add_context_length(1048576)
+        self.gguf_writer.add_embedding_length(hidden_size)
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
+        self.gguf_writer.add_wkv_head_size(head_size)
+        self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
+        self.gguf_writer.add_iclr_lora_rank(lora_rank_iclr)
+        self.gguf_writer.add_value_residual_mix_lora_rank(lora_rank_value_residual_mix)
+        self.gguf_writer.add_gate_lora_rank(lora_rank_gate)
+        self.gguf_writer.add_feed_forward_length(intermediate_size)
+        self.gguf_writer.add_file_type(self.ftype)
+
+        # required by llama.cpp, unused
+        self.gguf_writer.add_head_count(0)
+
+    lerp_weights: dict[int, dict[str, Tensor]] = {}
+    lora_needs_transpose: bool = True
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # unify tensor names here to make life easier
+        name = name.replace("blocks", "layers").replace("ffn", "feed_forward")
+        name = name.replace("self_attn", "attention").replace("attn", "attention")
+        name = name.replace("time_mixer.", "")
+        # lora layer names in fla-hub's impl
+        if "_lora.lora" in name:
+            self.lora_needs_transpose = False
+            name = name.replace("_lora.lora.0.weight", "1.weight")
+            name = name.replace("_lora.lora.2.weight", "2.weight")
+            name = name.replace("_lora.lora.2.bias", "0.weight")
+
+        name = name.replace("feed_forward_norm", "ln2")
+        name = name.replace("g_norm", "ln_x")
+
+        if "attention.v" in name and "value" not in self.map_tensor_name(name) and bid == 0:
+            # some models have dummy v0/v1/v2 on first layer while others don't
+            # ignore them all since they are not used
+            return
+
+        wkv_has_gate = self.hparams.get("wkv_has_gate", True)
+        lerp_list = ["r", "w", "k", "v", "a", "g"] if wkv_has_gate else ["r", "w", "k", "v", "a"]
+
+        if bid is not None and "attention.x_" in name:
+            if "attention.x_x" in name:
+                # already concatenated
+                new_name = f"blk.{bid}.time_mix_lerp_fused.weight"
+                data = data_torch.reshape(len(lerp_list), 1, 1, -1)
+                yield (new_name, data)
+            else:
+                try:
+                    self.lerp_weights[bid][name] = data_torch
+                except KeyError:
+                    self.lerp_weights[bid] = {name: data_torch}
+                if all(f"model.layers.{bid}.attention.x_{i}" in self.lerp_weights[bid].keys() for i in lerp_list):
+                    new_name = f"blk.{bid}.time_mix_lerp_fused.weight"
+                    data = torch.stack([self.lerp_weights[bid][f"model.layers.{bid}.attention.x_{i}"] for i in lerp_list], dim=0)
+                    yield (new_name, data)
+            return
+        else:
+            data_torch = data_torch.squeeze()
+            new_name = self.map_tensor_name(name)
+
+            if not (new_name.endswith(".weight") or new_name.endswith(".bias")):
+                new_name += ".weight"
+
+            if self.lora_needs_transpose and any(
+                new_name.endswith(t) for t in [
+                    "time_mix_w1.weight", "time_mix_w2.weight",
+                    "time_mix_a1.weight", "time_mix_a2.weight",
+                    "time_mix_v1.weight", "time_mix_v2.weight",
+                    "time_mix_g1.weight", "time_mix_g2.weight",
+                ]
+            ):
+                data_torch = data_torch.transpose(0, 1)
+
+            if 'r_k' in new_name:
+                data_torch = data_torch.flatten()
+
+            if bid == 0 and "time_mix_a" in new_name:
+                # dummy v0/v1/v2 on first layer
+                # easiest way to make llama happy
+                yield (new_name.replace("time_mix_a", "time_mix_v"), data_torch)
+
+            yield (new_name, data_torch)
+
+
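Aside: the time_mix_lerp_fused handling above only stacks the per-channel lerp tensors into one fused tensor. A toy sketch of that shape bookkeeping, assuming (purely for illustration) that each per-channel tensor arrives with shape (1, 1, hidden_size):

import torch

hidden_size = 8  # toy value for the sketch
lerp_list = ["r", "w", "k", "v", "a", "g"]
# assumed per-channel shape (1, 1, hidden_size)
parts = {c: torch.randn(1, 1, hidden_size) for c in lerp_list}
fused = torch.stack([parts[c] for c in lerp_list], dim=0)
print(fused.shape)  # torch.Size([6, 1, 1, 8])
# same layout as the pre-concatenated x_x branch: reshape(len(lerp_list), 1, 1, -1)
assert fused.shape == torch.cat([parts[c] for c in lerp_list]).reshape(len(lerp_list), 1, 1, -1).shape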
+@Model.register("RwkvHybridForCausalLM")
+class ARwkv7Model(Rwkv7Model):
+    model_arch = gguf.MODEL_ARCH.ARWKV7
+
+    def set_vocab(self):
+        try:
+            self._set_vocab_sentencepiece()
+        except FileNotFoundError:
+            self._set_vocab_gpt2()
+
+    def set_gguf_parameters(self):
+        block_count = self.hparams["num_hidden_layers"]
+        hidden_size = self.hparams["hidden_size"]
+        head_size = self.hparams["head_size"]
+        rms_norm_eps = self.hparams["rms_norm_eps"]
+        intermediate_size = self.hparams["intermediate_size"]
+        wkv_has_gate = self.hparams["wkv_has_gate"]
+        assert self.hparams["wkv_version"] == 7
+
+        # ICLR: In-Context-Learning-Rate
+        lora_rank_decay = 64
+        lora_rank_iclr = 64
+        lora_rank_value_residual_mix = 32
+        lora_rank_gate = 128 if wkv_has_gate else 0
+
+        # RWKV isn't context limited
+        self.gguf_writer.add_context_length(1048576)
+        self.gguf_writer.add_embedding_length(hidden_size)
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
+        self.gguf_writer.add_wkv_head_size(head_size)
+        self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
+        self.gguf_writer.add_iclr_lora_rank(lora_rank_iclr)
+        self.gguf_writer.add_value_residual_mix_lora_rank(lora_rank_value_residual_mix)
+        self.gguf_writer.add_gate_lora_rank(lora_rank_gate)
+        self.gguf_writer.add_feed_forward_length(intermediate_size)
+        self.gguf_writer.add_file_type(self.ftype)
+        self.gguf_writer.add_token_shift_count(1)
+
+        # required by llama.cpp, unused
+        self.gguf_writer.add_head_count(0)
+
+
 @Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
 class MambaModel(Model):
     model_arch = gguf.MODEL_ARCH.MAMBA
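Where a checkpoint's config omits the explicit rank fields, Rwkv7Model falls back to calc_lora_rank defined above. A worked example of that formula, with hidden_size = 2048 picked arbitrarily for illustration:

def calc_lora_rank(hidden_size, exponent, multiplier):
    # round hidden_size**exponent * multiplier to a multiple of 32, never below 32
    return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32

# hidden_size = 2048: 2048**0.5 * 1.8 / 32 ~= 2.55, rounds to 3, so 3 * 32 = 96
print(calc_lora_rank(2048, 0.5, 1.8))  # 96  (decay / iclr fallback)
print(calc_lora_rank(2048, 0.5, 1.3))  # 64  (value residual mix fallback)
print(calc_lora_rank(2048, 0.8, 0.6))  # 256 (gate fallback)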