@@ -104,7 +104,7 @@ def make_tensors_list() -> List[str]:
 
 def find_n_mult(n_ff: int, n_embd: int) -> int:
     # hardcoded magic range
-    for n_mult in range(256, 1, -1):
+    for n_mult in range(8192, 1, -1):
         calc_ff = (((8*n_embd) // 3 + n_mult - 1) // n_mult)*n_mult
         if calc_ff == n_ff:
             return n_mult
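Note: with the old 256 cap this search cannot satisfy LLaMA v2 70B. A quick arithmetic sketch (n_ff = 28672 and n_embd = 8192, values assumed from the 70B config.json) shows the smallest workable n_mult is 7168, hence the wider range:

```python
# Sketch only -- n_ff and n_embd assumed from the Llama 2 70B config.
n_ff, n_embd = 28672, 8192
base = (8 * n_embd) // 3                       # 21845
# calc_ff == n_ff needs n_mult to divide 28672 and exceed 28672 - 21845 = 6827;
# the smallest such divisor is 7168, far above the old 256 ceiling.
assert ((base + 7168 - 1) // 7168) * 7168 == n_ff
```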
@@ -113,11 +113,12 @@ def find_n_mult(n_ff: int, n_embd: int) -> int:
 
 @dataclass
 class Params:
-    n_vocab: int
-    n_embd:  int
-    n_mult:  int
-    n_head:  int
-    n_layer: int
+    n_vocab:   int
+    n_embd:    int
+    n_mult:    int
+    n_head:    int
+    n_layer:   int
+    n_kv_head: Optional[int]  # This parameter is only used for Llama 2
 
     @staticmethod
     def guessed(model: 'LazyModel') -> 'Params':
@@ -139,31 +140,34 @@ def guessed(model: 'LazyModel') -> 'Params':
         n_head=n_embd // 128 # guessed
 
         return Params(
-            n_vocab = n_vocab,
-            n_embd  = n_embd,
-            n_mult  = 256,
-            n_head  = n_head,
-            n_layer = n_layer,
+            n_vocab   = n_vocab,
+            n_embd    = n_embd,
+            n_mult    = 256,
+            n_head    = n_head,
+            n_layer   = n_layer,
+            n_kv_head = None,
         )
 
     @staticmethod
     def loadHFTransformerJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
         config = json.load(open(config_path))
 
-        n_vocab = config["vocab_size"];
-        n_embd  = config["hidden_size"];
-        n_head  = config["num_attention_heads"];
-        n_layer = config["num_hidden_layers"];
-        n_ff    = config["intermediate_size"];
+        n_vocab   = config["vocab_size"];
+        n_embd    = config["hidden_size"];
+        n_head    = config["num_attention_heads"];
+        n_layer   = config["num_hidden_layers"];
+        n_ff      = config["intermediate_size"];
+        n_kv_head = config.get("num_key_value_heads")
 
         n_mult = find_n_mult(n_ff, n_embd);
 
         return Params(
-            n_vocab = n_vocab,
-            n_embd  = n_embd,
-            n_mult  = n_mult,
-            n_head  = n_head,
-            n_layer = n_layer,
+            n_vocab   = n_vocab,
+            n_embd    = n_embd,
+            n_mult    = n_mult,
+            n_head    = n_head,
+            n_layer   = n_layer,
+            n_kv_head = n_kv_head,
         )
 
     # LLaMA v2 70B params.json
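The switch to `config.get` (rather than indexing) matters here: only GQA-era configs carry `num_key_value_heads`, while older LLaMA v1 style configs omit the key entirely. A minimal sketch of the two cases, using made-up configs:

```python
import json

# Illustrative only: a GQA-era config carries the key, a pre-GQA
# (LLaMA v1 style) config does not, so .get() must be used over [].
cfg_gqa = json.loads('{"num_key_value_heads": 8}')
cfg_pre = json.loads('{}')
assert cfg_gqa.get("num_key_value_heads") == 8
assert cfg_pre.get("num_key_value_heads") is None  # n_kv_head stays None
```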
@@ -182,11 +186,12 @@ def loadOriginalParamsJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
             n_vocab = model["tok_embeddings.weight"].shape[0]
 
         return Params(
-            n_vocab = n_vocab,
-            n_embd  = n_embd,
-            n_mult  = n_mult,
-            n_head  = n_head,
-            n_layer = n_layer,
+            n_vocab   = n_vocab,
+            n_embd    = n_embd,
+            n_mult    = n_mult,
+            n_head    = n_head,
+            n_layer   = n_layer,
+            n_kv_head = None,
         )
 
     @staticmethod
191196
192197    @staticmethod  
@@ -293,10 +298,12 @@ def __repr__(self) -> str:
 Vocab = Union[BpeVocab, SentencePieceVocab]
 
 
-def permute(weights: NDArray, n_head: int) -> NDArray:
+def permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray:
+    if n_kv_head is not None and n_head != n_kv_head:
+        n_head //= n_kv_head
     return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
-                   .swapaxes(1, 2)
-                   .reshape(weights.shape))
+                .swapaxes(1, 2)
+                .reshape(weights.shape))
 
 
 class Tensor(metaclass=ABCMeta):
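As a rough shape check (dimensions assumed from the Llama 2 70B config: 64 query heads, 8 KV heads, head_dim 128, n_embd 8192), the GQA branch reduces n_head to 64 // 8 = 8, matching the 8 KV heads actually present in a k_proj weight, so the reshape still divides evenly:

```python
import numpy as np

# Toy shape check, assuming the patched permute() above is in scope
# and the Llama 2 70B dimensions given in the lead-in.
n_head, n_kv_head, head_dim, n_embd = 64, 8, 128, 8192
wk = np.zeros((n_kv_head * head_dim, n_embd), dtype=np.float32)  # k_proj: (1024, 8192)
out = permute(wk, n_head, n_kv_head)  # GQA branch uses 64 // 8 = 8 heads internally
assert out.shape == wk.shape
```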
@@ -305,7 +312,7 @@ class Tensor(metaclass=ABCMeta):
     @abstractmethod
     def astype(self, data_type: DataType) -> 'Tensor': ...
     @abstractmethod
-    def permute(self, n_head: int) -> 'Tensor': ...
+    def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> 'Tensor': ...
     @abstractmethod
     def permute_part(self, n_part: int, n_head: int) -> 'UnquantizedTensor': ...
     @abstractmethod
@@ -343,8 +350,8 @@ def part(self, n_part: int) -> 'UnquantizedTensor':
         r = self.ndarray.shape[0] // 3
         return UnquantizedTensor(self.ndarray[r * n_part : r * n_part + r, ...])
 
-    def permute(self, n_head: int) -> 'UnquantizedTensor':
-        return UnquantizedTensor(permute(self.ndarray, n_head))
+    def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> 'UnquantizedTensor':
+        return UnquantizedTensor(permute(self.ndarray, n_head, n_kv_head))
 
 
 def load_unquantized(lazy_tensor: 'LazyTensor', expected_dtype: Any = None, convert: bool = False) -> NDArray:
@@ -367,18 +374,19 @@ def load_unquantized(lazy_tensor: 'LazyTensor', expected_dtype: Any = None, conv
 
 
 class DeferredPermutedTensor(Tensor):
-    def __init__(self, base: Tensor, n_head: int) -> None:
+    def __init__(self, base: Tensor, n_head: int, n_kv_head: Optional[int] = None) -> None:
         self.base = base
         self.n_head = n_head
+        self.n_kv_head = n_kv_head  # must be stored, since astype/to_ggml read it back
         self.data_type = self.base.data_type
 
     def astype(self, data_type: DataType) -> Tensor:
-        return self.base.astype(data_type).permute(self.n_head)
+        return self.base.astype(data_type).permute(self.n_head, self.n_kv_head)
 
     def to_ggml(self) -> GGMLCompatibleTensor:
-        return self.base.to_ggml().permute(self.n_head)
+        return self.base.to_ggml().permute(self.n_head, self.n_kv_head)
 
-    def permute(self, n_head: int) -> Tensor:
+    def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> Tensor:
         raise Exception("shouldn't permute twice")
 
 
@@ -474,10 +482,10 @@ def merge_multifile_models(models_plus: List[ModelPlus]) -> ModelPlus:
     return ModelPlus(model, paths, format, vocab)
 
 
-def permute_lazy(lazy_tensor: LazyTensor, n_head: int) -> LazyTensor:
+def permute_lazy(lazy_tensor: LazyTensor, n_head: int, n_kv_head: Optional[int] = None) -> LazyTensor:
     def load() -> Tensor:
-        return lazy_tensor.load().permute(n_head)
-    return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}) ' + lazy_tensor.description)
+        return lazy_tensor.load().permute(n_head, n_kv_head)
+    return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}, {n_kv_head}) ' + lazy_tensor.description)
 
 def permute_part_lazy(lazy_tensor: LazyTensor, n_part: int, n_head: int) -> LazyTensor:
     def load() -> Tensor:
@@ -502,7 +510,7 @@ def convert_transformers_to_orig(model: LazyModel, params: Params) -> LazyModel:
     for i in itertools.count():
         if f"model.layers.{i}.self_attn.q_proj.weight" in model:
             out[f"layers.{i}.attention.wq.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head)
-            out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head)
+            out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head, params.n_kv_head)
             out[f"layers.{i}.attention.wv.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"]
         elif f"model.layers.{i}.self_attn.W_pack.weight" in model:
             out[f"layers.{i}.attention.wqkv.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head)
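Only the wk line gains the new argument: under GQA the query projection still holds all attention heads, while the key projection holds just n_kv_head of them. A back-of-the-envelope check with assumed 70B dimensions:

```python
# Assumed Llama 2 70B dimensions, for illustration only.
n_embd, n_head, n_kv_head = 8192, 64, 8
head_dim = n_embd // n_head       # 128
wq_rows = n_head * head_dim       # 8192: q_proj keeps all 64 heads
wk_rows = n_kv_head * head_dim    # 1024: k_proj holds only 8 KV heads,
                                  # hence the GQA-aware permute for wk
```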