
Commit a8caaf2

Siyuan Feng committed
Support Qwen2-MoE Architecture
1 parent b7416c0 commit a8caaf2

File tree

7 files changed: +608 -8 lines changed

python/mlc_llm/model/model.py

Lines changed: 15 additions & 0 deletions
@@ -22,6 +22,7 @@
 from .phi import phi_loader, phi_model, phi_quantization
 from .qwen import qwen_loader, qwen_model, qwen_quantization
 from .qwen2 import qwen2_loader, qwen2_model, qwen2_quantization
+from .qwen2_moe import qwen2_moe_loader, qwen2_moe_model, qwen2_moe_quantization
 from .rwkv5 import rwkv5_loader, rwkv5_model, rwkv5_quantization
 from .rwkv6 import rwkv6_loader, rwkv6_model, rwkv6_quantization
 from .stable_lm import stablelm_loader, stablelm_model, stablelm_quantization

@@ -225,6 +226,20 @@ class Model:
             "ft-quant": qwen2_quantization.ft_quant,
         },
     ),
+    "qwen2_moe": Model(
+        name="qwen2_moe",
+        model=qwen2_moe_model.Qwen2MoeForCausalLM,
+        config=qwen2_moe_model.Qwen2MoeConfig,
+        source={
+            "huggingface-torch": qwen2_moe_loader.huggingface,
+            "huggingface-safetensor": qwen2_moe_loader.huggingface,
+        },
+        quantize={
+            "no-quant": qwen2_moe_quantization.no_quant,
+            "group-quant": qwen2_moe_quantization.group_quant,
+            "ft-quant": qwen2_moe_quantization.ft_quant,
+        },
+    ),
     "stablelm": Model(
         name="stablelm",
         model=stablelm_model.StableLmForCausalLM,
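
For reference, this registry is what the conversion and compilation tooling consults when a checkpoint's config declares "model_type": "qwen2_moe". A minimal sketch of resolving the new entry, assuming the registry dict defined in this file is named MODELS (the dict itself is not shown in the hunk above):

    from mlc_llm.model.model import MODELS

    # Look up the newly registered architecture by its model_type string.
    qwen2_moe = MODELS["qwen2_moe"]
    print(qwen2_moe.name)              # "qwen2_moe"
    print(sorted(qwen2_moe.source))    # ['huggingface-safetensor', 'huggingface-torch']
    print(sorted(qwen2_moe.quantize))  # ['ft-quant', 'group-quant', 'no-quant']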

python/mlc_llm/model/model_preset.py

Lines changed: 33 additions & 0 deletions
@@ -416,6 +416,39 @@
         "use_sliding_window": False,
         "vocab_size": 151936,
     },
+    "qwen2moe": {
+        "architectures": ["Qwen2MoeForCausalLM"],
+        "attention_dropout": 0.0,
+        "bos_token_id": 151643,
+        "eos_token_id": 151645,
+        "hidden_act": "silu",
+        "hidden_size": 2048,
+        "initializer_range": 0.02,
+        "intermediate_size": 5632,
+        "max_position_embeddings": 32768,
+        "max_window_layers": 21,
+        "model_type": "qwen2_moe",
+        "num_attention_heads": 16,
+        "num_hidden_layers": 24,
+        "num_key_value_heads": 16,
+        "rms_norm_eps": 1e-06,
+        "rope_theta": 1000000.0,
+        "sliding_window": 32768,
+        "tie_word_embeddings": False,
+        "torch_dtype": "bfloat16",
+        "transformers_version": "4.39.0.dev0",
+        "use_cache": True,
+        "use_sliding_window": False,
+        "vocab_size": 151936,
+        "decoder_sparse_step": 1,
+        "moe_intermediate_size": 1408,
+        "shared_expert_intermediate_size": 5632,
+        "num_experts_per_tok": 4,
+        "num_experts": 60,
+        "norm_topk_prob": False,
+        "output_router_logits": False,
+        "router_aux_loss_coef": 0.001,
+    },
     "stablelm": {
         "architectures": ["StableLmForCausalLM"],
         "bos_token_id": 0,

python/mlc_llm/model/qwen2_moe/__init__.py

Whitespace-only changes.

python/mlc_llm/model/qwen2_moe/qwen2_moe_loader.py

Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
"""
This file specifies how MLC's Qwen2-MoE parameters map from other formats, for example
HuggingFace PyTorch, HuggingFace safetensors.
"""

import functools

import numpy as np

from mlc_llm.loader import ExternMapping
from mlc_llm.quantization import Quantization

from .qwen2_moe_model import Qwen2MoeConfig, Qwen2MoeForCausalLM


def huggingface(model_config: Qwen2MoeConfig, quantization: Quantization) -> ExternMapping:
    """Returns a parameter mapping that maps from the names of MLC LLM parameters to
    the names of HuggingFace PyTorch parameters.

    Parameters
    ----------
    model_config : Qwen2MoeConfig
        The configuration of the Qwen2-MoE model.

    quantization : Quantization
        The quantization configuration.

    Returns
    -------
    param_map : ExternMapping
        The parameter mapping from MLC to HuggingFace PyTorch.
    """
    model = Qwen2MoeForCausalLM(model_config)
    if quantization is not None:
        model.to(quantization.model_dtype)
    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]
        spec=model.get_default_spec(),
        allow_extern=True,
    )
    named_parameters = dict(_named_params)

    mapping = ExternMapping()

    for i in range(model_config.num_hidden_layers):
        # Map attention weights: fuse the separate q/k/v projections into c_attn.
        attn = f"model.layers.{i}.self_attn"
        for weight_type in ["weight", "bias"]:
            mlc_name = f"{attn}.c_attn.{weight_type}"
            mlc_param = named_parameters[mlc_name]
            mapping.add_mapping(
                mlc_name,
                [
                    f"{attn}.q_proj.{weight_type}",
                    f"{attn}.k_proj.{weight_type}",
                    f"{attn}.v_proj.{weight_type}",
                ],
                functools.partial(
                    lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),
                    dtype=mlc_param.dtype,
                ),
            )
        # Map the shared-expert MLP weights: fuse gate_proj and up_proj.
        mlp = f"model.layers.{i}.mlp"
        shared_expert = f"{mlp}.shared_expert"
        mlc_name = f"{shared_expert}.gate_up_proj.weight"
        mlc_param = named_parameters[mlc_name]
        mapping.add_mapping(
            mlc_name,
            [
                f"{shared_expert}.gate_proj.weight",
                f"{shared_expert}.up_proj.weight",
            ],
            functools.partial(
                lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype),
                dtype=mlc_param.dtype,
            ),
        )
        # Map the MoE gate and up weights: fuse each expert's gate/up pair,
        # then stack all experts into a single tensor.
        mlc_name = f"{mlp}.moe_gate_up_proj.weight"
        mlc_param = named_parameters[mlc_name]

        def combine_expert_gate_up(*hf_params, dtype):
            stack = []
            for j in range(0, len(hf_params), 2):
                stack.append(np.concatenate([hf_params[j], hf_params[j + 1]], axis=0))
            return np.stack(stack, axis=0).astype(dtype)

        mapping.add_mapping(
            mlc_name,
            functools.reduce(
                lambda a, b: a + b,
                [
                    [
                        f"{mlp}.experts.{expert_id}.gate_proj.weight",
                        f"{mlp}.experts.{expert_id}.up_proj.weight",
                    ]
                    for expert_id in range(model_config.num_experts)
                ],
            ),
            functools.partial(
                combine_expert_gate_up,
                dtype=mlc_param.dtype,
            ),
        )

        # Map the MoE down weights: stack the per-expert down_proj matrices.
        mlc_name = f"{mlp}.moe_down_proj.weight"
        mlc_param = named_parameters[mlc_name]
        mapping.add_mapping(
            mlc_name,
            [
                f"{mlp}.experts.{expert_id}.down_proj.weight"
                for expert_id in range(model_config.num_experts)
            ],
            functools.partial(
                lambda *hf_params, dtype: np.stack(hf_params, axis=0).astype(dtype),
                dtype=mlc_param.dtype,
            ),
        )

    for mlc_name, mlc_param in named_parameters.items():
        if mlc_name not in mapping.param_map:
            mapping.add_mapping(
                mlc_name,
                [mlc_name],
                functools.partial(
                    lambda x, dtype: x.astype(dtype),
                    dtype=mlc_param.dtype,
                ),
            )
    return mapping
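
To make the fused layout concrete, here is a small self-contained check of what combine_expert_gate_up produces (toy sizes, not the real model dimensions). The loader receives the per-expert HuggingFace weights as an interleaved flat list [gate_0, up_0, gate_1, up_1, ...]; each gate/up pair is concatenated along axis 0 into a (2 * intermediate, hidden) matrix, and the experts are then stacked into one (num_experts, 2 * intermediate, hidden) tensor:

    import numpy as np

    def combine_expert_gate_up(*hf_params, dtype):
        stack = []
        for j in range(0, len(hf_params), 2):
            stack.append(np.concatenate([hf_params[j], hf_params[j + 1]], axis=0))
        return np.stack(stack, axis=0).astype(dtype)

    num_experts, intermediate, hidden = 3, 4, 8  # toy sizes
    weights = []
    for _ in range(num_experts):
        weights += [np.zeros((intermediate, hidden)),  # gate_proj
                    np.ones((intermediate, hidden))]   # up_proj

    out = combine_expert_gate_up(*weights, dtype="float16")
    assert out.shape == (num_experts, 2 * intermediate, hidden)

The moe_down_proj mapping follows the same pattern but only stacks, since each expert contributes a single down_proj matrix.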
