@@ -2264,13 +2264,6 @@ def set_vocab(self):
22642264
22652265 special_vocab .add_to_gguf (self .gguf_writer )
22662266
2267- def _hf_permute_qk (self , weights , n_head : int , n_head_kv : int ):
2268- if n_head_kv is not None and n_head != n_head_kv :
2269- n_head = n_head_kv
2270- return (weights .reshape (n_head , 2 , weights .shape [0 ] // n_head // 2 , * weights .shape [1 :])
2271- .swapaxes (1 , 2 )
2272- .reshape (weights .shape ))
2273-
22742267 def set_gguf_parameters (self ):
22752268 self .gguf_writer .add_name ("InternLM2" )
22762269 self .gguf_writer .add_context_length (self .hparams ["max_position_embeddings" ])
@@ -2290,26 +2283,22 @@ def set_gguf_parameters(self):
22902283 def modify_tensors (self , data_torch : Tensor , name : str , bid : int | None ) -> Iterable [tuple [str , Tensor ]]:
22912284 num_heads = self .hparams ["num_attention_heads" ]
22922285 num_kv_heads = self .hparams ["num_key_value_heads" ]
2293- hidden_size = self .hparams ["hidden_size" ]
2286+ n_embd = self .hparams ["hidden_size" ]
22942287 q_per_kv = num_heads // num_kv_heads
2295- head_dim = hidden_size // num_heads
2288+ head_dim = n_embd // num_heads
22962289 num_groups = num_heads // q_per_kv
22972290
2298- qkv_pattern = r"model\.layers\.(\d+)\.attention\.wqkv"
2299-
2300- if re .match (qkv_pattern , name ):
2301- bid = re .findall (qkv_pattern , name )[0 ]
2291+ if bid is not None and f"model.layers.{ bid } .attention.wqkv" in name :
23022292 qkv = data_torch
2303- # qkv = rearrange(qkv.T, " o (g n i) ->o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
2304- qkv = qkv .T .reshape ((- 1 , num_groups , q_per_kv + 2 , head_dim ))
2305- q , k , v = qkv [..., : q_per_kv , :], qkv [..., q_per_kv : q_per_kv + 1 , :], qkv [..., q_per_kv + 1 : q_per_kv + 2 , :]
2293+
2294+ qkv = qkv .reshape ((num_groups , q_per_kv + 2 , head_dim , n_embd ))
2295+ q , k , v = qkv [:, : q_per_kv ], qkv [:, - 2 ], qkv [:, - 1 ]
2296+
23062297 # The model weights of q and k require additional reshape.
2307- # q = self._hf_permute_qk(rearrange(q, " o g n i -> o (g n i)").T, num_heads, num_heads)
2308- q = self ._hf_permute_qk (q .reshape ((q .shape [0 ], - 1 )).T , num_heads , num_heads )
2309- # k = self._hf_permute_qk(rearrange(k, " o g n i -> o (g n i)").T, num_heads, num_kv_heads)
2310- k = self ._hf_permute_qk (k .reshape ((k .shape [0 ], - 1 )).T , num_heads , num_kv_heads )
2311- # v = rearrange(v, " o g n i -> o (g n i)").T
2312- v = v .reshape ((v .shape [0 ], - 1 )).T
2298+ q = LlamaModel .permute (q .reshape ((- 1 , q .shape [- 1 ])), num_heads , num_heads )
2299+ k = LlamaModel .permute (k .reshape ((- 1 , k .shape [- 1 ])), num_heads , num_kv_heads )
2300+ v = v .reshape ((- 1 , v .shape [- 1 ]))
2301+
23132302 return [
23142303 (self .format_tensor_name (gguf .MODEL_TENSOR .ATTN_Q , bid ), q ),
23152304 (self .format_tensor_name (gguf .MODEL_TENSOR .ATTN_K , bid ), k ),
@@ -3585,6 +3574,7 @@ def main() -> None:
35853574 small_first_shard = args .no_tensor_first_split )
35863575
35873576 logger .info ("Set model parameters" )
3577+ model_instance .gguf_writer .add_type (gguf .GGUFType .MODEL )
35883578 model_instance .set_gguf_parameters ()
35893579
35903580 logger .info ("Set model tokenizer" )
0 commit comments