@@ -149,6 +149,10 @@ class TensorNameMap:
             "model.layers.{bid}.ln2", # yi
         ),

+        MODEL_TENSOR.FFN_GATE_INP: (
+            "layers.{bid}.feed_forward.gate", # mixtral
+        ),
+
         # Feed-forward up
         MODEL_TENSOR.FFN_UP: (
             "gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox
@@ -164,11 +168,19 @@ class TensorNameMap:
             "transformer.h.{bid}.mlp.w1", # qwen
         ),

+        MODEL_TENSOR.FFN_UP_EXP: (
+            "layers.{bid}.feed_forward.experts.{xid}.w3", # mixtral
+        ),
+
         # Feed-forward gate
         MODEL_TENSOR.FFN_GATE: (
-            "model.layers.{bid}.mlp.gate_proj", # llama-hf refact
-            "layers.{bid}.feed_forward.w1",     # llama-pth
-            "transformer.h.{bid}.mlp.w2",       # qwen
+            "model.layers.{bid}.mlp.gate_proj",           # llama-hf refact
+            "layers.{bid}.feed_forward.w1",               # llama-pth
+            "transformer.h.{bid}.mlp.w2",                 # qwen
+        ),
+
+        MODEL_TENSOR.FFN_GATE_EXP: (
+            "layers.{bid}.feed_forward.experts.{xid}.w1", # mixtral
         ),

         # Feed-forward down
@@ -185,6 +197,10 @@ class TensorNameMap:
             "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon
         ),

+        MODEL_TENSOR.FFN_DOWN_EXP: (
+            "layers.{bid}.feed_forward.experts.{xid}.w2", # mixtral
+        ),
+
         MODEL_TENSOR.ATTN_Q_NORM: (
             "language_model.encoder.layers.{bid}.self_attention.q_layernorm",
         ),
@@ -213,11 +229,14 @@ def __init__(self, arch: MODEL_ARCH, n_blocks: int):
         for tensor, keys in self.block_mappings_cfg.items():
             if tensor not in MODEL_TENSORS[arch]:
                 continue
-            tensor_name = TENSOR_NAMES[tensor].format(bid = bid)
-            self.mapping[tensor_name] = (tensor, tensor_name)
-            for key in keys:
-                key = key.format(bid = bid)
-                self.mapping[key] = (tensor, tensor_name)
+            # TODO: make this configurable
+            n_experts = 8
+            for xid in range(n_experts):
+                tensor_name = TENSOR_NAMES[tensor].format(bid = bid, xid = xid)
+                self.mapping[tensor_name] = (tensor, tensor_name)
+                for key in keys:
+                    key = key.format(bid = bid, xid = xid)
+                    self.mapping[key] = (tensor, tensor_name)