@@ -10,7 +10,7 @@ class TensorNameMap:
         # Token embeddings
         MODEL_TENSOR.TOKEN_EMBD: (
             "gpt_neox.embed_in",                         # gptneox
-            "transformer.wte",                           # gpt2 gpt-j mpt refact qwen dbrx
+            "transformer.wte",                           # gpt2 gpt-j mpt refact qwen dbrx jais
             "transformer.word_embeddings",               # falcon
             "word_embeddings",                           # bloom
             "model.embed_tokens",                        # llama-hf
@@ -49,7 +49,7 @@ class TensorNameMap:
         # Output
         MODEL_TENSOR.OUTPUT: (
             "embed_out",                 # gptneox
-            "lm_head",                   # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx
+            "lm_head",                   # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais
             "output",                    # llama-pth bloom internlm2
             "word_embeddings_for_head",  # persimmon
             "lm_head.linear",            # phi2
@@ -58,7 +58,7 @@ class TensorNameMap:
         # Output norm
         MODEL_TENSOR.OUTPUT_NORM: (
             "gpt_neox.final_layer_norm",               # gptneox
-            "transformer.ln_f",                        # gpt2 gpt-j falcon
+            "transformer.ln_f",                        # gpt2 gpt-j falcon jais
             "model.norm",                              # llama-hf baichuan internlm2
             "norm",                                    # llama-pth
             "transformer.norm_f",                      # mpt dbrx
@@ -81,7 +81,7 @@ class TensorNameMap:
         # Attention norm
         MODEL_TENSOR.ATTN_NORM: (
             "gpt_neox.layers.{bid}.input_layernorm",                # gptneox
-            "transformer.h.{bid}.ln_1",                             # gpt2 gpt-j refact qwen
+            "transformer.h.{bid}.ln_1",                             # gpt2 gpt-j refact qwen jais
             "transformer.blocks.{bid}.norm_1",                      # mpt
             "transformer.h.{bid}.input_layernorm",                  # falcon7b
             "h.{bid}.input_layernorm",                              # bloom
@@ -109,7 +109,7 @@ class TensorNameMap:
         # Attention query-key-value
         MODEL_TENSOR.ATTN_QKV: (
             "gpt_neox.layers.{bid}.attention.query_key_value",                     # gptneox
-            "transformer.h.{bid}.attn.c_attn",                                     # gpt2 qwen
+            "transformer.h.{bid}.attn.c_attn",                                     # gpt2 qwen jais
             "transformer.blocks.{bid}.attn.Wqkv",                                  # mpt
             "transformer.blocks.{bid}.norm_attn_norm.attn.Wqkv",                   # dbrx
             "transformer.h.{bid}.self_attention.query_key_value",                  # falcon
@@ -160,7 +160,7 @@ class TensorNameMap:
         # Attention output
         MODEL_TENSOR.ATTN_OUT: (
             "gpt_neox.layers.{bid}.attention.dense",                        # gptneox
-            "transformer.h.{bid}.attn.c_proj",                              # gpt2 refact qwen
+            "transformer.h.{bid}.attn.c_proj",                              # gpt2 refact qwen jais
             "transformer.blocks.{bid}.attn.out_proj",                       # mpt
             "transformer.h.{bid}.self_attention.dense",                     # falcon
             "h.{bid}.self_attention.dense",                                 # bloom
@@ -198,7 +198,7 @@ class TensorNameMap:
         # Feed-forward norm
         MODEL_TENSOR.FFN_NORM: (
             "gpt_neox.layers.{bid}.post_attention_layernorm",                # gptneox
-            "transformer.h.{bid}.ln_2",                                      # gpt2 refact qwen
+            "transformer.h.{bid}.ln_2",                                      # gpt2 refact qwen jais
             "h.{bid}.post_attention_layernorm",                              # bloom
             "transformer.blocks.{bid}.norm_2",                               # mpt
             "model.layers.{bid}.post_attention_layernorm",                   # llama-hf
@@ -225,7 +225,7 @@ class TensorNameMap:
         # Feed-forward up
         MODEL_TENSOR.FFN_UP: (
             "gpt_neox.layers.{bid}.mlp.dense_h_to_4h",                # gptneox
-            "transformer.h.{bid}.mlp.c_fc",                           # gpt2
+            "transformer.h.{bid}.mlp.c_fc",                           # gpt2 jais
             "transformer.blocks.{bid}.ffn.up_proj",                   # mpt
             "transformer.h.{bid}.mlp.dense_h_to_4h",                  # falcon
             "h.{bid}.mlp.dense_h_to_4h",                              # bloom
@@ -271,6 +271,7 @@ class TensorNameMap:
271271            "model.layers.{bid}.mlp.gate_proj" ,           # llama-hf refact 
272272            "layers.{bid}.feed_forward.w1" ,               # llama-pth 
273273            "transformer.h.{bid}.mlp.w2" ,                 # qwen 
274+             "transformer.h.{bid}.mlp.c_fc2" ,              # jais 
274275            "model.layers.layers.{bid}.mlp.gate_proj" ,    # plamo 
275276            "model.layers.{bid}.feed_forward.w1" ,         # internlm2 
276277            "encoder.layers.{bid}.mlp.fc12" ,              # nomic-bert 
@@ -294,7 +295,7 @@ class TensorNameMap:
         # Feed-forward down
         MODEL_TENSOR.FFN_DOWN: (
             "gpt_neox.layers.{bid}.mlp.dense_4h_to_h",                # gptneox
-            "transformer.h.{bid}.mlp.c_proj",                         # gpt2 refact qwen
+            "transformer.h.{bid}.mlp.c_proj",                         # gpt2 refact qwen jais
             "transformer.blocks.{bid}.ffn.down_proj",                 # mpt
             "transformer.h.{bid}.mlp.dense_h_to_4h",                  # falcon
             "h.{bid}.mlp.dense_4h_to_h",                              # bloom