
Commit 9bfadbb

jinhongyii and junrushao committed
Introduce Mixtral MoE Model
This PR introduces support for Mixtral MoE models with MLC's latest SLM quantization/compilation pipeline. It includes the following pieces of changes:

**Operators.** We implemented a list of operators in TIR's TVMScript format in two files, `moe_misc` and `moe_matmul`. Those TIR kernels implement the "transpose indices" and "blocked-CSR-COO" techniques described in MegaBlocks [1]. `moe_misc.py` primarily concerns sparsity-related operators, including:
- `get_indices`, `get_indptr` and `scatter_output`: CSR-style index manipulation and array shuffling that makes the input range each expert has to deal with contiguous;
- `moe_sum`, `moe_cumsum`, `topk`: standard operators specialized for MoE use cases, e.g. where the number of experts and the number of activated experts are both small.

`moe_matmul.py` includes the non-quantized and quantized GEMV and group GEMM operators used in MoE model serving. In single-batch decoding, GEMV operators typically suffice, but group GEMM is a necessary dependency in both prefilling and batched decoding.

**Model architecture.** We reuse the attention block from Mistral and implement the MoE MLP in `mixtral_model.py`. In Mixtral, there are three groups of experts in each MLP, where `e1` and `e3` are the gate/up projections (project-in) and `e2` is the down projection (project-out).

**Weight quantization.** We batch all experts of the same kind into a single tensor of shape `(Ne, N, K)`, where `Ne` is the total number of experts, `N` is the number of output features, and `K` is the number of input features. Applying group quantization, we compress along the `K` dimension, consistent with the rest of the project.

**Performance.** The current TIR kernels are highly optimized for non-tensor-core scenarios (Metal, WebGPU, non-TensorCore CUDA, AMD, etc.); tensor core performance is left for a PR in the near future.

**Try out MLC's Mixtral model.** The int4-quantized Mixtral model has roughly 24.5 GB of parameters.

```python
from mlc_chat import ChatConfig, ChatModule, callback
from mlc_chat.support import logging

logging.enable_logging()

MODEL = "HF://junrushao/Mixtral-8x7B-Instruct-v0.1-q4f16_1-MLC"
NUM_GPU = 1


def main():
    cm = ChatModule(
        MODEL,
        device="cuda:0",
        chat_config=ChatConfig(
            sliding_window_size=1024,
            tensor_parallel_shards=NUM_GPU,
        ),
    )
    cm.generate(
        "What is the meaning of life?",
        progress_callback=callback.StreamToStdout(callback_interval=2),
    )


if __name__ == "__main__":
    main()
```

Quantization formats:
- 3-bit (19.662 GB): ["HF://junrushao/Mixtral-8x7B-Instruct-v0.1-q3f16_1-MLC"](https://huggingface.co/junrushao/Mixtral-8x7B-Instruct-v0.1-q3f16_1-MLC)
- 4-bit (24.466 GB): ["HF://junrushao/Mixtral-8x7B-Instruct-v0.1-q4f16_1-MLC"](https://huggingface.co/junrushao/Mixtral-8x7B-Instruct-v0.1-q4f16_1-MLC)

The 3-bit version can be run comfortably on a 24 GB GPU (e.g. 4090, 3090 Ti).

**Convert Mixtral to MLC format from scratch.** The following instructions are only needed for advanced users who want to quantize Mixtral from scratch.

```bash
SRC_DIR=/path/to/Mixtral-8x7B-v0.1     # raw model downloaded from HuggingFace
MODEL_DIR=/mlc_models/mixtral-q4f16_1  # destination directory

mlc_chat gen_config $SRC_DIR -o $MODEL_DIR --quantization q4f16_1 \
    --conv-template LM  # "LM" (lang model) means no conversation template yet
mlc_chat convert_weight $SRC_DIR --quantization q4f16_1 -o $MODEL_DIR
```

[1] Gale, Trevor, Deepak Narayanan, Cliff Young, and Matei Zaharia. "MegaBlocks: Efficient Sparse Training with Mixture-of-Experts." Proceedings of MLSys 2023.

Co-authored-by: Junru Shao <[email protected]>
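To make the index manipulation described under **Operators** concrete, here is a plain NumPy sketch of the same routing scheme that the `moe_misc.py` kernels implement on GPU. This is not the TIR implementation; the function names (`moe_mlp_reference`, `_softmax`) and the dense per-expert callables are illustrative only.

```python
# A NumPy sketch of top-k MoE routing: sort (token, expert) pairs by expert so
# every expert sees a contiguous block of rows (the CSR-style get_indices /
# get_indptr idea), run one dense matmul per expert, then scatter the weighted
# outputs back to token order (scatter_output + moe_sum).
import numpy as np


def _softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def moe_mlp_reference(x, router_logits, experts, k=2):
    """x: (T, d) tokens; router_logits: (T, E); experts: list of E callables (rows -> rows)."""
    T, E = router_logits.shape
    topk_ids = np.argsort(-router_logits, axis=1)[:, :k]                    # (T, k)
    topk_w = _softmax(np.take_along_axis(router_logits, topk_ids, axis=1))  # (T, k)
    # Flatten (token, expert) pairs and sort by expert id so each expert's rows are contiguous.
    token_ids = np.repeat(np.arange(T), k)
    expert_ids = topk_ids.reshape(-1)
    weights = topk_w.reshape(-1)
    order = np.argsort(expert_ids, kind="stable")
    token_ids, expert_ids, weights = token_ids[order], expert_ids[order], weights[order]
    # CSR-style indptr: rows [indptr[e], indptr[e + 1]) all belong to expert e.
    indptr = np.searchsorted(expert_ids, np.arange(E + 1))
    gathered = x[token_ids]
    out = np.empty_like(gathered)
    for e in range(E):
        lo, hi = indptr[e], indptr[e + 1]
        if lo < hi:
            out[lo:hi] = experts[e](gathered[lo:hi])  # one dense (group) GEMM per expert
    # Scatter back to token order, weighting each expert output by its routing score.
    y = np.zeros_like(x)
    np.add.at(y, token_ids, out * weights[:, None])
    return y
```

Here each `experts[e]` stands in for Mixtral's per-expert MLP, roughly `(silu(x @ w1.T) * (x @ w3.T)) @ w2.T`, where `w1`/`w3` correspond to the batched `e1_e3` (gate/up) weights and `w2` to `e2` (down projection).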
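Similarly, the `(Ne, N, K)` batched expert-weight layout with group quantization along `K` can be illustrated with a minimal sketch. This is generic symmetric group quantization, not MLC's exact `q4f16_1` packing; the group size and helper names (`group_quantize`, `group_dequantize`) are assumptions made for illustration.

```python
# Illustrative symmetric group quantization of batched expert weights along K.
# Not MLC's exact q4f16_1 packing; group size and helpers are assumptions.
import numpy as np

GROUP_SIZE = 32  # assumed group size for illustration


def group_quantize(w, bits=4, group_size=GROUP_SIZE):
    """w: (Ne, N, K) batched expert weights; one scale per group_size elements of K."""
    ne, n, k = w.shape
    assert k % group_size == 0
    g = w.reshape(ne, n, k // group_size, group_size)
    qmax = 2 ** (bits - 1) - 1
    scale = np.maximum(np.abs(g).max(axis=-1, keepdims=True) / qmax, 1e-8)
    q = np.clip(np.round(g / scale), -qmax - 1, qmax).astype(np.int8)
    return q.reshape(ne, n, k), scale.squeeze(-1)  # int codes + (Ne, N, K/group) scales


def group_dequantize(q, scale, group_size=GROUP_SIZE):
    ne, n, k = q.shape
    g = q.reshape(ne, n, k // group_size, group_size).astype(np.float32)
    return (g * scale[..., None]).reshape(ne, n, k)


# Toy shapes; the real Mixtral-8x7B e2 (down-projection) tensor would be (8, 4096, 14336).
w_e2 = np.random.randn(8, 64, 128).astype(np.float32)
q, s = group_quantize(w_e2)
print("mean abs reconstruction error:", np.abs(group_dequantize(q, s) - w_e2).mean())
```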
1 parent 5e23900 commit 9bfadbb

File tree

21 files changed: +1,568 −196 lines


python/mlc_chat/compiler_pass/pipeline.py

Lines changed: 1 addition & 1 deletion
@@ -58,7 +58,7 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR
 
 @register_pipeline("mlc_llm")
 def _mlc_llm_pipeline(  # pylint: disable=too-many-arguments
-    cublas_gemm: bool,
+    cublas_gemm: bool = False,
     variable_bounds: Dict[str, int] = None,
     additional_tirs: Dict[str, tvm.tir.PrimFunc] = None,
     metadata: Dict[str, Any] = None,

python/mlc_chat/interface/compile.py

Lines changed: 11 additions & 9 deletions
@@ -85,16 +85,18 @@ def _apply_preproc_to_params(
 
 def _compile(args: CompileArgs, model_config: ConfigBase):
     def _get_variable_bounds(model_config) -> Dict[str, int]:
-        variable_bounds = {"seq_len": model_config.prefill_chunk_size}
         if hasattr(model_config, "sliding_window_size"):
-            variable_bounds["rolling_cache_len"] = model_config.sliding_window_size
-            variable_bounds["kv_seq_len"] = (
-                model_config.sliding_window_size + model_config.prefill_chunk_size,
-            )
-        else:
-            variable_bounds["total_seq_len"] = model_config.context_window_size
-        variable_bounds["batch_size"] = getattr(model_config, "max_batch_size", 1)
-        return variable_bounds
+            return {
+                "rolling_cache_len": model_config.sliding_window_size,
+                "kv_seq_len": model_config.sliding_window_size + model_config.prefill_chunk_size,
+                "seq_len": model_config.prefill_chunk_size,
+                "batch_size": getattr(model_config, "max_batch_size", 1),
+            }
+        return {
+            "total_seq_len": model_config.context_window_size,
+            "seq_len": model_config.prefill_chunk_size,
+            "batch_size": getattr(model_config, "max_batch_size", 1),
+        }
 
     def _get_param_metadata(name: str, param: nn.Parameter) -> Dict[str, Any]:
         return {

python/mlc_chat/interface/convert_weight.py

Lines changed: 4 additions & 1 deletion
@@ -51,7 +51,10 @@ def _device_to_str(device: Device) -> str:
 
 
 def _calc_total_params(model: nn.Module) -> int:
-    _, named_params, _ = model.export_tvm(spec=model.get_default_spec(), allow_extern=True)
+    _, named_params, _ = model.export_tvm(  # type: ignore[misc]
+        spec=model.get_default_spec(),  # type: ignore[attr-defined]
+        allow_extern=True,
+    )
     total_params = 0
     for _, param in named_params:
         total_params += math.prod(param.shape)

python/mlc_chat/model/gpt2/gpt2_model.py

Lines changed: 7 additions & 21 deletions
@@ -3,13 +3,13 @@
 TODO: add docstring
 """
 import dataclasses
-import math
 from typing import Any, Dict, Optional
 
 from tvm import te, tir
 from tvm.relax.frontend import nn
 from tvm.relax.frontend.nn import Tensor, op
 
+from mlc_chat import op as op_ext
 from mlc_chat.support import logging
 from mlc_chat.support.config import ConfigBase
 from mlc_chat.support.style import bold
@@ -110,29 +110,15 @@ def forward(
 
         self.k_cache.append(op.squeeze(k, axis=0))
         self.v_cache.append(op.squeeze(v, axis=0))
-        k = op.reshape(self.k_cache.view(t), (b, t, h, d))
-        v = op.reshape(self.v_cache.view(t), (b, t, h, d))
-
-        q = q.permute_dims([0, 2, 1, 3])  # [b, h, s, d]
-        k = k.permute_dims([0, 2, 1, 3])  # [b, h, t, d]
-        v = v.permute_dims([0, 2, 1, 3])  # [b, h, t, d]
-
-        attn_weights = op.matmul(
-            q, k.permute_dims([0, 1, 3, 2])  # [b, h, s, d] x [b, h, d, t] = [b, h, s, t]
-        ) / math.sqrt(d)
+        k = self.k_cache.view(t)
+        v = self.v_cache.view(t)
 
         if self.scale_attn_by_inverse_layer_idx:
-            attn_weights = attn_weights / float(self.layer_idx + 1)
-
-        dtype = attn_weights.dtype
-        attn_weights = attn_weights.maximum(tir.min_value(dtype)).minimum(attention_mask)
-        if dtype == "float32":
-            attn_weights = op.softmax(attn_weights, axis=-1)
+            attn_score_scaling_factor = 1.0 / float(self.layer_idx + 1)
         else:
-            attn_weights = op.softmax(attn_weights.astype("float32"), axis=-1).astype(dtype)
-        # [b, h, s, t] x [b, h, t, d] => [b, h, s, d] => [b, s, h, d]
-        output = op.matmul(attn_weights, v)
-        return self.c_proj(output.permute_dims([0, 2, 1, 3]).reshape((b, s, h * d)))
+            attn_score_scaling_factor = 1.0
+        output = op_ext.attention(q, k, v, attention_mask, attn_score_scaling_factor)
+        return self.c_proj(output)
 
 
 class GPT2MLP(nn.Module):

python/mlc_chat/model/llama/llama_model.py

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,7 @@
 from tvm.relax.frontend.nn import Tensor, op
 
 from mlc_chat import op as op_ext
-from mlc_chat.nn.kv_cache import FlashInferPagedKVCache, PagedKVCache
+from mlc_chat.nn import FlashInferPagedKVCache, PagedKVCache
 from mlc_chat.support import logging
 from mlc_chat.support import tensor_parallel as tp
 from mlc_chat.support.config import ConfigBase
@@ -342,7 +342,7 @@ def create_flashinfer_paged_kv_cache(
         num_kv_heads = self.num_key_value_heads // self.tensor_parallel_shards
         # Note: Right now we only have FlashInfer-based KV cache supported.
         # TIR version will be introduced soon.
-        return FlashInferPagedKVCache.create(
+        return FlashInferPagedKVCache(
             max_batch_size=max_batch_size,
             max_total_seq_len=max_total_seq_len,
             page_size=page_size,

python/mlc_chat/model/mistral/mistral_model.py

Lines changed: 1 addition & 1 deletion
@@ -358,7 +358,7 @@ def __init__(self, config: MistralConfig):
             [MistralDecoderLayer(config, rotary_embedding) for _ in range(config.num_hidden_layers)]
         )
         self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False)
-        self.tensor_parallel_shards = config.tensor_parallel_shards > 1
+        self.tensor_parallel_shards = config.tensor_parallel_shards
 
     def forward(  # pylint: disable=too-many-arguments
         self,

python/mlc_chat/model/mixtral/__init__.py

Whitespace-only changes.
python/mlc_chat/model/mixtral/mixtral_loader.py

Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
"""
This file specifies how MLC's Mixtral parameter maps from other formats, for example HuggingFace
PyTorch, HuggingFace safetensors.
"""
import functools

import numpy as np

from mlc_chat.loader import ExternMapping
from mlc_chat.quantization import Quantization

from .mixtral_model import MixtralConfig, MixtralForCasualLM


def huggingface(model_config: MixtralConfig, quantization: Quantization) -> ExternMapping:
    """Returns a parameter mapping that maps from the names of MLC LLM parameters to
    the names of HuggingFace PyTorch parameters.

    Parameters
    ----------
    model_config : MixtralConfig
        The configuration of the Mixtral model.

    quantization : Quantization
        The quantization configuration.

    Returns
    -------
    param_map : ExternMapping
        The parameter mapping from MLC to HuggingFace PyTorch.
    """
    model = MixtralForCasualLM(model_config)
    if quantization is not None:
        model.to(quantization.model_dtype)
    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]
        spec=model.get_default_spec(),
        allow_extern=True,
    )
    named_parameters = dict(_named_params)

    mapping = ExternMapping()

    for i in range(model_config.num_hidden_layers):
        # Add QKV in self attention
        attn = f"model.layers.{i}.self_attn"
        mlc_name = f"{attn}.qkv_proj.weight"
        mlc_param = named_parameters[mlc_name]
        mapping.add_mapping(
            mlc_name,
            [
                f"{attn}.q_proj.weight",
                f"{attn}.k_proj.weight",
                f"{attn}.v_proj.weight",
            ],
            functools.partial(
                lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),
                dtype=mlc_param.dtype,
            ),
        )

        # Add gates in MLP (when MoE is enabled)
        mlp = f"model.layers.{i}.block_sparse_moe"
        mlc_mlp = f"model.layers.{i}.moe"
        mlc_name = f"{mlc_mlp}.e1_e3.weight"
        mlc_param = named_parameters[mlc_name]

        def combine_expert_gate_up(*hf_params, dtype):
            stack = []
            for i in range(0, len(hf_params), 2):
                stack.append(np.concatenate([hf_params[i], hf_params[i + 1]], axis=0))
            return np.stack(stack, axis=0).astype(dtype)

        mapping.add_mapping(
            mlc_name,
            functools.reduce(
                lambda a, b: a + b,
                [
                    [
                        f"{mlp}.experts.{expert_id}.w1.weight",
                        f"{mlp}.experts.{expert_id}.w3.weight",
                    ]
                    for expert_id in range(model_config.num_local_experts)
                ],
            ),
            functools.partial(
                combine_expert_gate_up,
                dtype=mlc_param.dtype,
            ),
        )

        mlc_name = f"{mlc_mlp}.e2.weight"
        mlc_param = named_parameters[mlc_name]
        mapping.add_mapping(
            mlc_name,
            [
                f"{mlp}.experts.{expert_id}.w2.weight"
                for expert_id in range(model_config.num_local_experts)
            ],
            functools.partial(
                lambda *hf_params, dtype: np.stack(hf_params, axis=0).astype(dtype),
                dtype=mlc_param.dtype,
            ),
        )

        mlc_name = f"{mlc_mlp}.gate.weight"
        mlc_param = named_parameters[mlc_name]
        mapping.add_mapping(
            mlc_name,
            [f"{mlp}.gate.weight"],
            functools.partial(
                lambda x, dtype: x.astype(dtype),
                dtype=mlc_param.dtype,
            ),
        )

        # inv_freq is not used in the model
        mapping.add_unused(f"{attn}.rotary_emb.inv_freq")

    for mlc_name, mlc_param in named_parameters.items():
        if mlc_name not in mapping.param_map:
            mapping.add_mapping(
                mlc_name,
                [mlc_name],
                functools.partial(
                    lambda x, dtype: x.astype(dtype),
                    dtype=mlc_param.dtype,
                ),
            )
    return mapping
