From 278ec41ec90d61c849689df6356ed2b912995147 Mon Sep 17 00:00:00 2001 From: Hongyi Jin Date: Wed, 3 Jan 2024 15:39:08 -0500 Subject: [PATCH] Introduce Mixtral MoE Model This PR introduces support for Mixtral MoE models with MLC's latest SLM quantization/compilation pipeline. It includes the following changes: **Operators.** We implemented a list of operators in TIR's TVMScript format in two files `moe_misc` and `moe_matmul`. Those TIR kernels implement "transpose indices" and "blocked-CSR-COO" as described in MegaBlocks [1]. `moe_misc.py` primarily concerns sparsity-related operators, including: - `get_indices`, `get_indptr` and `scatter_output`: CSR-style index manipulation and array shuffling that makes the input ranges each expert has to deal with contiguous. - `moe_sum`, `moe_cumsum`, `topk` which are standard operators but specialized for MoE use cases, e.g. #experts and #activated-experts are small. `moe_matmul.py` includes non-quantized and quantized GEMV and group GEMM operators used in MoE model serving. Typically, in single batch decoding, GEMV operators should suffice, but group GEMM is a necessary dependency in both prefilling and batched decoding. **Model architecture.** We reuse the attention block from Mistral, and implemented MLP MoE in `mixtral_model.py`. In Mixtral, there are three groups of experts in each MLP, where `e1` and `e3` are gate/up projections (project-in) and `e2` is down projection (project-out). **Weight quantization.** We batch all experts of the same kind into a single tensor, whose shape is `(Ne, N, K)`, where `Ne` is the total number of experts, `N` is out-features and `K` is in-features. Applying group quantization, we compress along the `K` dimension, consistent with the rest of the project. **Performance.** The current TIR is highly optimized for non-tensor core scenarios (Metal, WebGPU, non-TensorCore CUDA, AMD, etc) and tensor core performance is left for a PR in the near future. 
**Try out MLC's Mixtral Model.** The int4-quantized Mixtral model has about 24.5 GB of parameters. ```python from mlc_chat import ChatConfig, ChatModule, callback from mlc_chat.support import logging logging.enable_logging() MODEL = "HF://junrushao/Mixtral-8x7B-Instruct-v0.1-q4f16_1-MLC" NUM_GPU = 1 def main(): cm = ChatModule(MODEL, device="cuda:0", chat_config=ChatConfig( sliding_window_size=1024, tensor_parallel_shards=NUM_GPU, )) cm.generate("What is the meaning of life?", progress_callback=callback.StreamToStdout(callback_interval=2)) if __name__ == "__main__": main() ``` Quantization formats: - 3-bit (19.662 GB): ["HF://junrushao/Mixtral-8x7B-Instruct-v0.1-q3f16_1-MLC"](https://huggingface.co/junrushao/Mixtral-8x7B-Instruct-v0.1-q3f16_1-MLC) - 4-bit (24.466 GB): ["HF://junrushao/Mixtral-8x7B-Instruct-v0.1-q4f16_1-MLC"](https://huggingface.co/junrushao/Mixtral-8x7B-Instruct-v0.1-q4f16_1-MLC) The 3-bit version can be run comfortably using a 24 GB GPU (e.g. 4090, 3090Ti). **Convert Mixtral to MLC format from scratch.** The following instructions are only needed for advanced users to quantize Mixtral from scratch. ```bash SRC_DIR=/path/to/Mixtral-8x7B-v0.1 # raw model downloaded from HuggingFace MODEL_DIR=/mlc_models/mixtral-q4f16_1 # destination directory mlc_chat gen_config $SRC_DIR -o $MODEL_DIR --quantization q4f16_1 \ --conv-template LM # "LM" (lang model) means no conversation template yet mlc_chat convert_weight $SRC_DIR --quantization q4f16_1 -o $MODEL_DIR ``` [1] Gale, Trevor, Deepak Narayanan, Cliff Young, and Matei Zaharia. "MegaBlocks: Efficient Sparse Training with Mixture-of-Experts." Proceedings of MLSys 2023. 
Co-authored-by: Junru Shao --- python/mlc_chat/compiler_pass/pipeline.py | 2 +- python/mlc_chat/interface/compile.py | 20 +- python/mlc_chat/interface/convert_weight.py | 5 +- python/mlc_chat/model/gpt2/gpt2_model.py | 28 +- python/mlc_chat/model/llama/llama_model.py | 4 +- .../mlc_chat/model/mistral/mistral_model.py | 2 +- python/mlc_chat/model/mixtral/__init__.py | 0 .../mlc_chat/model/mixtral/mixtral_loader.py | 129 +++++ .../mlc_chat/model/mixtral/mixtral_model.py | 174 ++++++ .../model/mixtral/mixtral_quantization.py | 45 ++ python/mlc_chat/model/model.py | 14 + python/mlc_chat/nn/__init__.py | 3 +- python/mlc_chat/nn/expert.py | 24 + python/mlc_chat/nn/kv_cache.py | 174 ++++-- python/mlc_chat/op/__init__.py | 1 + python/mlc_chat/op/attention.py | 20 +- python/mlc_chat/op/kv_cache.py | 100 ---- python/mlc_chat/op/moe_matmul.py | 548 ++++++++++++++++++ python/mlc_chat/op/moe_misc.py | 350 +++++++++++ .../quantization/group_quantization.py | 107 ++++ tests/python/model/test_kv_cache.py | 15 +- 21 files changed, 1569 insertions(+), 196 deletions(-) create mode 100644 python/mlc_chat/model/mixtral/__init__.py create mode 100644 python/mlc_chat/model/mixtral/mixtral_loader.py create mode 100644 python/mlc_chat/model/mixtral/mixtral_model.py create mode 100644 python/mlc_chat/model/mixtral/mixtral_quantization.py create mode 100644 python/mlc_chat/nn/expert.py delete mode 100644 python/mlc_chat/op/kv_cache.py create mode 100644 python/mlc_chat/op/moe_matmul.py create mode 100644 python/mlc_chat/op/moe_misc.py diff --git a/python/mlc_chat/compiler_pass/pipeline.py b/python/mlc_chat/compiler_pass/pipeline.py index def81d26e5..9e3378e48d 100644 --- a/python/mlc_chat/compiler_pass/pipeline.py +++ b/python/mlc_chat/compiler_pass/pipeline.py @@ -58,7 +58,7 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR @register_pipeline("mlc_llm") def _mlc_llm_pipeline( # pylint: disable=too-many-arguments - cublas_gemm: bool, + cublas_gemm: bool = 
False, variable_bounds: Dict[str, int] = None, additional_tirs: Dict[str, tvm.tir.PrimFunc] = None, metadata: Dict[str, Any] = None, diff --git a/python/mlc_chat/interface/compile.py b/python/mlc_chat/interface/compile.py index a0aa3f3d07..a483a055e7 100644 --- a/python/mlc_chat/interface/compile.py +++ b/python/mlc_chat/interface/compile.py @@ -85,16 +85,18 @@ def _apply_preproc_to_params( def _compile(args: CompileArgs, model_config: ConfigBase): def _get_variable_bounds(model_config) -> Dict[str, int]: - variable_bounds = {"seq_len": model_config.prefill_chunk_size} if hasattr(model_config, "sliding_window_size"): - variable_bounds["rolling_cache_len"] = model_config.sliding_window_size - variable_bounds["kv_seq_len"] = ( - model_config.sliding_window_size + model_config.prefill_chunk_size, - ) - else: - variable_bounds["total_seq_len"] = model_config.context_window_size - variable_bounds["batch_size"] = getattr(model_config, "max_batch_size", 1) - return variable_bounds + return { + "rolling_cache_len": model_config.sliding_window_size, + "kv_seq_len": model_config.sliding_window_size + model_config.prefill_chunk_size, + "seq_len": model_config.prefill_chunk_size, + "batch_size": getattr(model_config, "max_batch_size", 1), + } + return { + "total_seq_len": model_config.context_window_size, + "seq_len": model_config.prefill_chunk_size, + "batch_size": getattr(model_config, "max_batch_size", 1), + } def _get_param_metadata(name: str, param: nn.Parameter) -> Dict[str, Any]: return { diff --git a/python/mlc_chat/interface/convert_weight.py b/python/mlc_chat/interface/convert_weight.py index 2e995ee4f0..3635d707d0 100644 --- a/python/mlc_chat/interface/convert_weight.py +++ b/python/mlc_chat/interface/convert_weight.py @@ -51,7 +51,10 @@ def _device_to_str(device: Device) -> str: def _calc_total_params(model: nn.Module) -> int: - _, named_params, _ = model.export_tvm(spec=model.get_default_spec(), allow_extern=True) + _, named_params, _ = model.export_tvm( # type: 
ignore[misc] + spec=model.get_default_spec(), # type: ignore[attr-defined] + allow_extern=True, + ) total_params = 0 for _, param in named_params: total_params += math.prod(param.shape) diff --git a/python/mlc_chat/model/gpt2/gpt2_model.py b/python/mlc_chat/model/gpt2/gpt2_model.py index 2a43cfb7f6..7db869aa13 100644 --- a/python/mlc_chat/model/gpt2/gpt2_model.py +++ b/python/mlc_chat/model/gpt2/gpt2_model.py @@ -3,13 +3,13 @@ TODO: add docstring """ import dataclasses -import math from typing import Any, Dict, Optional from tvm import te, tir from tvm.relax.frontend import nn from tvm.relax.frontend.nn import Tensor, op +from mlc_chat import op as op_ext from mlc_chat.support import logging from mlc_chat.support.config import ConfigBase from mlc_chat.support.style import bold @@ -110,29 +110,15 @@ def forward( self.k_cache.append(op.squeeze(k, axis=0)) self.v_cache.append(op.squeeze(v, axis=0)) - k = op.reshape(self.k_cache.view(t), (b, t, h, d)) - v = op.reshape(self.v_cache.view(t), (b, t, h, d)) - - q = q.permute_dims([0, 2, 1, 3]) # [b, h, s, d] - k = k.permute_dims([0, 2, 1, 3]) # [b, h, t, d] - v = v.permute_dims([0, 2, 1, 3]) # [b, h, t, d] - - attn_weights = op.matmul( - q, k.permute_dims([0, 1, 3, 2]) # [b, h, s, d] x [b, h, d, t] = [b, h, s, t] - ) / math.sqrt(d) + k = self.k_cache.view(t) + v = self.v_cache.view(t) if self.scale_attn_by_inverse_layer_idx: - attn_weights = attn_weights / float(self.layer_idx + 1) - - dtype = attn_weights.dtype - attn_weights = attn_weights.maximum(tir.min_value(dtype)).minimum(attention_mask) - if dtype == "float32": - attn_weights = op.softmax(attn_weights, axis=-1) + attn_score_scaling_factor = 1.0 / float(self.layer_idx + 1) else: - attn_weights = op.softmax(attn_weights.astype("float32"), axis=-1).astype(dtype) - # [b, h, s, t] x [b, h, t, d] => [b, h, s, d] => [b, s, h, d] - output = op.matmul(attn_weights, v) - return self.c_proj(output.permute_dims([0, 2, 1, 3]).reshape((b, s, h * d))) + attn_score_scaling_factor 
= 1.0 + output = op_ext.attention(q, k, v, attention_mask, attn_score_scaling_factor) + return self.c_proj(output) class GPT2MLP(nn.Module): diff --git a/python/mlc_chat/model/llama/llama_model.py b/python/mlc_chat/model/llama/llama_model.py index 47f37d07eb..60d4693a03 100644 --- a/python/mlc_chat/model/llama/llama_model.py +++ b/python/mlc_chat/model/llama/llama_model.py @@ -10,7 +10,7 @@ from tvm.relax.frontend.nn import Tensor, op from mlc_chat import op as op_ext -from mlc_chat.nn.kv_cache import FlashInferPagedKVCache, PagedKVCache +from mlc_chat.nn import FlashInferPagedKVCache, PagedKVCache from mlc_chat.support import logging from mlc_chat.support import tensor_parallel as tp from mlc_chat.support.config import ConfigBase @@ -342,7 +342,7 @@ def create_flashinfer_paged_kv_cache( num_kv_heads = self.num_key_value_heads // self.tensor_parallel_shards # Note: Right now we only have FlashInfer-based KV cache supported. # TIR version will be introduced soon. - return FlashInferPagedKVCache.create( + return FlashInferPagedKVCache( max_batch_size=max_batch_size, max_total_seq_len=max_total_seq_len, page_size=page_size, diff --git a/python/mlc_chat/model/mistral/mistral_model.py b/python/mlc_chat/model/mistral/mistral_model.py index 20ec99f524..92e512705e 100644 --- a/python/mlc_chat/model/mistral/mistral_model.py +++ b/python/mlc_chat/model/mistral/mistral_model.py @@ -358,7 +358,7 @@ def __init__(self, config: MistralConfig): [MistralDecoderLayer(config, rotary_embedding) for _ in range(config.num_hidden_layers)] ) self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False) - self.tensor_parallel_shards = config.tensor_parallel_shards > 1 + self.tensor_parallel_shards = config.tensor_parallel_shards def forward( # pylint: disable=too-many-arguments self, diff --git a/python/mlc_chat/model/mixtral/__init__.py b/python/mlc_chat/model/mixtral/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git 
a/python/mlc_chat/model/mixtral/mixtral_loader.py b/python/mlc_chat/model/mixtral/mixtral_loader.py new file mode 100644 index 0000000000..12e96ebad2 --- /dev/null +++ b/python/mlc_chat/model/mixtral/mixtral_loader.py @@ -0,0 +1,129 @@ +""" +This file specifies how MLC's Mixtral parameter maps from other formats, for example HuggingFace +PyTorch, HuggingFace safetensors. +""" +import functools + +import numpy as np + +from mlc_chat.loader import ExternMapping +from mlc_chat.quantization import Quantization + +from .mixtral_model import MixtralConfig, MixtralForCasualLM + + +def huggingface(model_config: MixtralConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of HuggingFace PyTorch parameters. + + Parameters + ---------- + model_config : MixtralConfig + The configuration of the Mixtral model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. 
+ """ + model = MixtralForCasualLM(model_config) + if quantization is not None: + model.to(quantization.model_dtype) + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), + allow_extern=True, + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + for i in range(model_config.num_hidden_layers): + # Add QKV in self attention + attn = f"model.layers.{i}.self_attn" + mlc_name = f"{attn}.qkv_proj.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{attn}.q_proj.weight", + f"{attn}.k_proj.weight", + f"{attn}.v_proj.weight", + ], + functools.partial( + lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + # Add gates in MLP (when MoE is enabled) + mlp = f"model.layers.{i}.block_sparse_moe" + mlc_mlp = f"model.layers.{i}.moe" + mlc_name = f"{mlc_mlp}.e1_e3.weight" + mlc_param = named_parameters[mlc_name] + + def combine_expert_gate_up(*hf_params, dtype): + stack = [] + for i in range(0, len(hf_params), 2): + stack.append(np.concatenate([hf_params[i], hf_params[i + 1]], axis=0)) + return np.stack(stack, axis=0).astype(dtype) + + mapping.add_mapping( + mlc_name, + functools.reduce( + lambda a, b: a + b, + [ + [ + f"{mlp}.experts.{expert_id}.w1.weight", + f"{mlp}.experts.{expert_id}.w3.weight", + ] + for expert_id in range(model_config.num_local_experts) + ], + ), + functools.partial( + combine_expert_gate_up, + dtype=mlc_param.dtype, + ), + ) + + mlc_name = f"{mlc_mlp}.e2.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{mlp}.experts.{expert_id}.w2.weight" + for expert_id in range(model_config.num_local_experts) + ], + functools.partial( + lambda *hf_params, dtype: np.stack(hf_params, axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + mlc_name = f"{mlc_mlp}.gate.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + 
[f"{mlp}.gate.weight"], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + # inv_freq is not used in the model + mapping.add_unused(f"{attn}.rotary_emb.inv_freq") + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=mlc_param.dtype, + ), + ) + return mapping diff --git a/python/mlc_chat/model/mixtral/mixtral_model.py b/python/mlc_chat/model/mixtral/mixtral_model.py new file mode 100644 index 0000000000..83fcdf2a6a --- /dev/null +++ b/python/mlc_chat/model/mixtral/mixtral_model.py @@ -0,0 +1,174 @@ +"""Implementation for Mistral architecture.""" +import dataclasses + +from tvm import tir +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import Tensor, op + +from mlc_chat import op as op_ext +from mlc_chat.model.mistral.mistral_model import ( + MistralAttention, + MistralConfig, + MistralForCasualLM, + MistralModel, + RotaryEmbedding, +) +from mlc_chat.nn.expert import MixtralExperts +from mlc_chat.support import logging +from mlc_chat.support import tensor_parallel as tp + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class MixtralConfig(MistralConfig): # pylint: disable=too-many-instance-attributes + """Configuration of the Mixtral model.""" + + num_local_experts: int = 0 + num_experts_per_tok: int = 0 + + +# pylint: disable=invalid-name,missing-docstring,too-many-locals,fixme + + +class MixtralMoE(nn.Module): + """Mixture of experts""" + + def __init__(self, config: MixtralConfig): + super().__init__() + self.num_experts_per_tok = config.num_experts_per_tok + self.num_local_experts = config.num_local_experts + self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards + self.gate = nn.Linear( + in_features=config.hidden_size, + out_features=config.num_local_experts, + bias=False, + ) + self.e1_e3 = 
MixtralExperts( + self.num_local_experts, + in_features=config.hidden_size, + out_features=2 * self.intermediate_size, + ) + self.e2 = MixtralExperts( + self.num_local_experts, + in_features=self.intermediate_size, + out_features=config.hidden_size, + ) + self.dtype = "float32" + + def forward(self, x: Tensor): + def _expert_forward(x: Tensor, indptr: Tensor): + x1_x3 = self.e1_e3(x, indptr) + x1, x3 = op.split(x1_x3, indices_or_sections=2, axis=-1) + x = self.e2(op.silu(x1) * x3, indptr) + return x + + experts_per_tok = self.num_experts_per_tok # activated experts per token + local_experts = self.num_local_experts # total number of experts + batch_size, seq_len, hidden_size = x.shape + num_tokens = batch_size * seq_len + x = x.reshape(num_tokens, hidden_size) + # gate: [num_tokens, local_experts] + gate: Tensor = self.gate(x) + # expert_weights: [num_tokens, experts_per_tok] + # expert_indices: [num_tokens, experts_per_tok] + expert_weights, expert_indices = op_ext.moe_misc.topk(gate, experts_per_tok) + expert_weights = op.softmax(expert_weights.astype("float32"), axis=-1).astype(self.dtype) + if num_tokens == 1: + # x: [num_tokens * experts_per_tok, hidden_size] + x = _expert_forward(x, expert_indices) + else: + # cumsum: [num_tokens * local_experts] + cumsum = op_ext.moe_misc.moe_cumsum(expert_indices, local_experts) + # indices: [num_tokens * experts_per_tok] + indices = op_ext.moe_misc.get_indices(cumsum, expert_indices) + # indptr: [num_local_experts + 1] + indptr = op_ext.moe_misc.get_indptr(cumsum, local_experts, num_tokens) + # x: [num_tokens * experts_per_tok, hidden_size] + x = op.take(x, indices / experts_per_tok, axis=0) + x = _expert_forward(x, indptr) + x = op_ext.moe_misc.scatter_output(x, indices) + # x: [num_tokens, experts_per_tok, hidden_size] + x = x.reshape( # pylint: disable=too-many-function-args + num_tokens, experts_per_tok, hidden_size + ) * expert_weights.reshape( # pylint: disable=too-many-function-args + num_tokens, experts_per_tok, 1 
+ ) + # x: [num_tokens, hidden_size] + x = op_ext.moe_misc.moe_sum(x, dim=1) + x = x.reshape(batch_size, seq_len, hidden_size) # pylint: disable=too-many-function-args + return x + + +class MixtralDecoderLayer(nn.Module): + """Mixtral decoder layer""" + + def __init__(self, config: MixtralConfig, rotary_embedding: RotaryEmbedding): + eps = config.rms_norm_eps + self.self_attn = MistralAttention(config, rotary_embedding) + self.moe = MixtralMoE(config) + self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, eps, bias=False) + self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, -1, eps, bias=False) + + def _set_tp(): + def _set(layer, hint): + layer.weight.attrs["shard_strategy"] = hint + + hd = config.head_dim + q = self.self_attn.num_q_heads * hd + k = self.self_attn.num_kv_heads * hd + v = self.self_attn.num_kv_heads * hd + i = self.moe.intermediate_size + _set(self.self_attn.qkv_proj, tp.ShardSingleDim("_shard_qkv", segs=[q, k, v], dim=0)) + _set(self.self_attn.o_proj, tp.ShardSingleDim("_shard_o", dim=1)) + _set(self.moe.e1_e3, tp.ShardSingleDim("_shard_mlp_up", segs=[i, i], dim=1)) + _set(self.moe.e2, tp.ShardSingleDim("_shard_mlp_down", dim=2)) + + self.tensor_parallel_shards = config.tensor_parallel_shards + _set_tp() + + def forward( # pylint: disable=too-many-arguments + self, + hidden_states: Tensor, + attention_mask: Tensor, + rolling_cache_len: tir.Var, + kv_seq_len: tir.Var, + cache_offset: tir.Var, + ): + """Forward pass of a decoder layer; calculate attention, and add an residual connection.""" + + def _apply_residual(out, residual): + if self.tensor_parallel_shards > 1: + return op.ccl_allreduce(out + residual / self.tensor_parallel_shards, "sum") + return out + residual + + out = self.self_attn( + self.input_layernorm(hidden_states), + attention_mask, + rolling_cache_len, + kv_seq_len, + cache_offset, + ) + hidden_states = _apply_residual(out, residual=hidden_states) + out = self.moe(self.post_attention_layernorm(hidden_states)) + 
hidden_states = _apply_residual(out, residual=hidden_states) + return hidden_states + + +class MixtralModel(MistralModel): + """Exact same as LlamaModel.""" + + def __init__(self, config: MixtralConfig): + super().__init__(config) + rotary_embedding = RotaryEmbedding(config) + self.layers = nn.ModuleList( + [MixtralDecoderLayer(config, rotary_embedding) for _ in range(config.num_hidden_layers)] + ) + + +class MixtralForCasualLM(MistralForCasualLM): + """Same as LlamaForCausalLM, except for the use of sliding window attention.""" + + def __init__(self, config: MixtralConfig): + super().__init__(config) + self.model = MixtralModel(config) diff --git a/python/mlc_chat/model/mixtral/mixtral_quantization.py b/python/mlc_chat/model/mixtral/mixtral_quantization.py new file mode 100644 index 0000000000..9435d9e234 --- /dev/null +++ b/python/mlc_chat/model/mixtral/mixtral_quantization.py @@ -0,0 +1,45 @@ +"""This file specifies how MLC's Mistral parameters are quantized using group quantization +or other formats.""" +from typing import Tuple + +from tvm.relax.frontend import nn + +from mlc_chat.loader import QuantizeMapping +from mlc_chat.quantization import AWQQuantize, GroupQuantize, NoQuantize + +from .mixtral_model import MixtralConfig, MixtralForCasualLM + + +def group_quant( + model_config: MixtralConfig, + quantization: GroupQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Mixtral-architecture model using group quantization.""" + model: nn.Module = MixtralForCasualLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def awq_quant( + model_config: MixtralConfig, + quantization: AWQQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Mixtral-architecture model using Activation-aware Weight Quantization(AWQ).""" + raise NotImplementedError("AWQ is not implemented for Mixtral models.") + + +def 
no_quant( + model_config: MixtralConfig, + quantization: NoQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Mixtral model without quantization.""" + model: nn.Module = MixtralForCasualLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + return model, quant_map diff --git a/python/mlc_chat/model/model.py b/python/mlc_chat/model/model.py index ff98fb49da..dca6eaaee6 100644 --- a/python/mlc_chat/model/model.py +++ b/python/mlc_chat/model/model.py @@ -12,6 +12,7 @@ from .gpt_neox import gpt_neox_loader, gpt_neox_model, gpt_neox_quantization from .llama import llama_loader, llama_model, llama_quantization from .mistral import mistral_loader, mistral_model, mistral_quantization +from .mixtral import mixtral_loader, mixtral_model, mixtral_quantization from .phi import phi_loader, phi_model, phi_quantization ModelConfig = Any @@ -99,6 +100,19 @@ class Model: "group-quant": gpt2_quantization.group_quant, }, ), + "mixtral": Model( + name="mixtral", + model=mixtral_model.MixtralForCasualLM, + config=mixtral_model.MixtralConfig, + source={ + "huggingface-torch": mixtral_loader.huggingface, + "huggingface-safetensor": mixtral_loader.huggingface, + }, + quantize={ + "no-quant": mixtral_quantization.no_quant, + "group-quant": mixtral_quantization.group_quant, + }, + ), "gpt_neox": Model( name="gpt_neox", model=gpt_neox_model.GPTNeoXForCausalLM, diff --git a/python/mlc_chat/nn/__init__.py b/python/mlc_chat/nn/__init__.py index b7884f402c..4c8ff32c69 100644 --- a/python/mlc_chat/nn/__init__.py +++ b/python/mlc_chat/nn/__init__.py @@ -1,2 +1,3 @@ -"""Neural network components for LLM.""" +"""Common `nn.Modules` used to define LLMs in this project.""" +from .expert import MixtralExperts from .kv_cache import FlashInferPagedKVCache, PagedKVCache diff --git a/python/mlc_chat/nn/expert.py b/python/mlc_chat/nn/expert.py new file mode 100644 index 0000000000..e17f74e2ab --- /dev/null +++ b/python/mlc_chat/nn/expert.py @@ -0,0 
+1,24 @@ +"""An nn.Module that represents MoE experts""" +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import Tensor + +from mlc_chat.op import moe_matmul + + +class MixtralExperts(nn.Module): + """Mixtral experts""" + + def __init__(self, num_local_experts, in_features, out_features): + self.num_local_experts = num_local_experts + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter((num_local_experts, out_features, in_features)) + self.dtype = "float32" + + def forward(self, x: Tensor, indptr: Tensor): # pylint: disable=invalid-name,missing-docstring + assert x.ndim == 2 + if indptr.ndim == 2: + assert indptr.shape[0] == 1 + return moe_matmul.gemv(x, self.weight, indptr) + assert indptr.ndim == 1 + return moe_matmul.group_gemm(x, self.weight, indptr) diff --git a/python/mlc_chat/nn/kv_cache.py b/python/mlc_chat/nn/kv_cache.py index a5ec853461..1629da82fa 100644 --- a/python/mlc_chat/nn/kv_cache.py +++ b/python/mlc_chat/nn/kv_cache.py @@ -2,14 +2,19 @@ from tvm import relax as rx from tvm import tir from tvm.relax.frontend.nn import Object, Tensor - -from ..op.kv_cache import kv_cache_debug_get_kv, kv_cache_transpose_append +from tvm.script import tir as T class PagedKVCache(Object): # pylint: disable=too-few-public-methods """The Paged KV Cache used in LLM batching for efficient attention computation.""" - def attention(self, layer_id: int, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + def attention( # pylint: disable=invalid-name + self, + layer_id: int, + q: Tensor, + k: Tensor, + v: Tensor, + ) -> Tensor: """Compute attention with the given q/k/v data and in-cache k/v data on the specified layer. Rotary position embeddings are applied to k/v within this function. 
@@ -26,7 +31,13 @@ def attention(self, layer_id: int, q: Tensor, k: Tensor, v: Tensor) -> Tensor: _expr=rx.BlockBuilder.current().emit( rx.call_dps_packed( "vm.builtin.paged_attention_kv_cache_attention", - [self._expr, rx.PrimValue(layer_id), q._expr, k._expr, v._expr], + [ + self._expr, + rx.PrimValue(layer_id), # type: ignore[arg-type] + q._expr, + k._expr, + v._expr, + ], out_sinfo=q._expr.struct_info, ) ) @@ -34,11 +45,11 @@ def attention(self, layer_id: int, q: Tensor, k: Tensor, v: Tensor) -> Tensor: # pylint: enable=protected-access -class FlashInferPagedKVCache(PagedKVCache): +class FlashInferPagedKVCache(PagedKVCache): # pylint: disable=too-few-public-methods """Paged KV cache using FlashInfer (CUDA) kernels.""" - @staticmethod - def create( # pylint: disable=too-many-arguments + def __init__( # pylint: disable=too-many-arguments + self, max_batch_size: tir.Var, max_total_seq_len: tir.Var, page_size: tir.Var, @@ -50,7 +61,7 @@ def create( # pylint: disable=too-many-arguments rope_theta: int, dtype: str, name: str = "paged_kv_cache", - ) -> "FlashInferPagedKVCache": + ) -> None: """Create a paged KV cache object with FlashInfer kernels. Parameters @@ -73,46 +84,119 @@ def create( # pylint: disable=too-many-arguments The base of rotary position embedding. 
""" - bb = rx.BlockBuilder.current() - return PagedKVCache( + bb = rx.BlockBuilder.current() # pylint: disable=invalid-name + args = [ + rx.ShapeExpr([max_batch_size, max_total_seq_len, page_size]), + rx.PrimValue(num_hidden_layers), + rx.PrimValue(num_attention_heads), + rx.PrimValue(num_key_value_heads), + rx.PrimValue(head_dim), + rx.PrimValue(rope_scale), + rx.PrimValue(rope_theta), + rx.op.zeros((), dtype), # type: ignore[arg-type] + bb.add_func( + _kv_cache_transpose_append(num_key_value_heads, head_dim, dtype), + "kv_cache_transpose_append", + ), + rx.extern("paged_kv_cache.attention_kernel_prefill"), + rx.extern("paged_kv_cache.attention_kernel_decode"), + rx.extern("flashinfer.attention_kernel_prefill_with_ragged_kv_cache"), + rx.extern("flashinfer.attention_kernel_prefill_with_ragged_kv_cache_begin_forward"), + rx.extern("flashinfer.attention_kernel_prefill_with_ragged_kv_cache_end_forward"), + rx.extern("paged_kv_cache.attention_kernel_prefill_begin_forward"), + rx.extern("paged_kv_cache.attention_kernel_prefill_end_forward"), + rx.extern("paged_kv_cache.attention_kernel_decode_begin_forward"), + rx.extern("paged_kv_cache.attention_kernel_decode_end_forward"), + rx.extern("flashinfer.batch_qk_apply_rotary_in_place"), + rx.extern("flashinfer.merge_state_in_place"), + bb.add_func( + _kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), + "kv_cache_debug_get_kv", + ), + ] + super().__init__( _expr=rx.Call( rx.extern("vm.builtin.paged_attention_kv_cache_create"), - args=[ - rx.ShapeExpr([max_batch_size, max_total_seq_len, page_size]), - rx.PrimValue(num_hidden_layers), - rx.PrimValue(num_attention_heads), - rx.PrimValue(num_key_value_heads), - rx.PrimValue(head_dim), - rx.PrimValue(rope_scale), - rx.PrimValue(rope_theta), - rx.op.zeros((), dtype), - bb.add_func( - kv_cache_transpose_append(num_key_value_heads, head_dim, dtype), - "kv_cache_transpose_append", - ), - rx.extern("paged_kv_cache.attention_kernel_prefill"), - 
rx.extern("paged_kv_cache.attention_kernel_decode"), - rx.extern("flashinfer.attention_kernel_prefill_with_ragged_kv_cache"), - rx.extern( - "flashinfer.attention_kernel_prefill_with_ragged_kv_cache_begin_forward" - ), - rx.extern( - "flashinfer.attention_kernel_prefill_with_ragged_kv_cache_end_forward" - ), - rx.extern("paged_kv_cache.attention_kernel_prefill_begin_forward"), - rx.extern("paged_kv_cache.attention_kernel_prefill_end_forward"), - rx.extern("paged_kv_cache.attention_kernel_decode_begin_forward"), - rx.extern("paged_kv_cache.attention_kernel_decode_end_forward"), - rx.extern("flashinfer.batch_qk_apply_rotary_in_place"), - rx.extern("flashinfer.merge_state_in_place"), - bb.add_func( - kv_cache_debug_get_kv( - num_hidden_layers, num_key_value_heads, head_dim, dtype - ), - "kv_cache_debug_get_kv", - ), - ], + args=args, sinfo_args=[rx.ObjectStructInfo()], ), _name=name, ) + + +# mypy: disable-error-code="attr-defined" +# pylint: disable=too-many-locals + + +def _kv_cache_transpose_append(num_key_value_heads, head_dim, dtype): + """Return the TIR function that appends new k/v data to PagedKVCache.""" + + # pylint: disable=line-too-long,invalid-name + # fmt: off + @T.prim_func + def tir_kv_cache_transpose_append( + var_pages: T.handle, + var_k_data: T.handle, + var_v_data: T.handle, + var_position_map: T.handle, + ): + T.func_attr({"tir.noalias": T.bool(True)}) + ntoken = T.SizeVar("ntoken", "int64") + page_size = T.SizeVar("page_size", "int64") + num_pages = T.int64() + pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, page_size, head_dim), dtype) + k_data = T.match_buffer(var_k_data, (ntoken, num_key_value_heads, head_dim), dtype) + v_data = T.match_buffer(var_v_data, (ntoken, num_key_value_heads, head_dim), dtype) + position_map = T.match_buffer(var_position_map, (ntoken,), "int32") + for global_pos, h, f in T.grid(ntoken, num_key_value_heads, head_dim): + with T.block("k_transpose_append"): + vgpos, vh, vf = T.axis.remap("SSS", 
[global_pos, h, f]) + T.reads(position_map[vgpos], k_data[vgpos, vh, vf]) + T.writes(pages[position_map[vgpos] // page_size, 0, vh, position_map[vgpos] % page_size, vf]) + position: T.int32 = position_map[vgpos] # type: ignore + pages[T.floordiv(position, page_size), 0, vh, T.floormod(position, page_size), vf] = k_data[vgpos, vh, vf] + with T.block("v_transpose_append"): + vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) + T.reads(position_map[vgpos], k_data[vgpos, vh, vf]) + T.writes(pages[position_map[vgpos] // page_size, 1, vh, position_map[vgpos] % page_size, vf]) + position: T.int32 = position_map[vgpos] # type: ignore[name-defined,no-redef] + pages[T.floordiv(position, page_size), 1, vh, T.floormod(position, page_size), vf] = v_data[vgpos, vh, vf] + # fmt: on + # pylint: enable=line-too-long,invalid-name + + return tir_kv_cache_transpose_append + + +def _kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype): + """Return the TIR function that fetches the k/v data on given positions and layer.""" + + # pylint: disable=line-too-long,invalid-name + # fmt: off + @T.prim_func + def tir_kv_cache_debug_get_kv( + var_pages: T.handle, + var_position_map: T.handle, + var_k_data: T.handle, + var_v_data: T.handle, + layer_id: T.int64, + ): + T.func_attr({"tir.noalias": T.bool(True)}) + seqlen = T.SizeVar("seqlen", "int64") + page_size = T.SizeVar("page_size", "int64") + num_pages = T.int64() + pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, page_size, head_dim), dtype) + position_map = T.match_buffer(var_position_map, (seqlen,), "int32") + k_data = T.match_buffer(var_k_data, (num_hidden_layers, seqlen, num_key_value_heads, head_dim), dtype) + v_data = T.match_buffer(var_v_data, (num_hidden_layers, seqlen, num_key_value_heads, head_dim), dtype) + for p, h, d in T.grid(seqlen, num_key_value_heads, head_dim): + with T.block("copy0"): + vp, vh, vd = T.axis.remap("SSS", [p, h, d]) + T.reads(position_map[vp], 
pages[position_map[vp] // page_size, 0:2, vh, position_map[vp] % page_size, vd]) + T.writes(k_data[layer_id, vp, vh, vd], v_data[layer_id, vp, vh, vd]) + position: T.int32 = position_map[vp] # type: ignore[name-defined] + k_data[layer_id, vp, vh, vd] = pages[T.floordiv(position, page_size), 0, vh, T.floormod(position, page_size), vd] + v_data[layer_id, vp, vh, vd] = pages[T.floordiv(position, page_size), 1, vh, T.floormod(position, page_size), vd] + # fmt: on + # pylint: enable=line-too-long,invalid-name + + return tir_kv_cache_debug_get_kv diff --git a/python/mlc_chat/op/__init__.py b/python/mlc_chat/op/__init__.py index ea69db0a02..7e1a1c6d9f 100644 --- a/python/mlc_chat/op/__init__.py +++ b/python/mlc_chat/op/__init__.py @@ -1,4 +1,5 @@ """Extern module for compiler.""" +from . import moe_matmul, moe_misc from .attention import attention from .extern import configure, enable, get_store from .gemm import faster_transformer_dequantize_gemm diff --git a/python/mlc_chat/op/attention.py b/python/mlc_chat/op/attention.py index b1e144af43..f737bdc123 100644 --- a/python/mlc_chat/op/attention.py +++ b/python/mlc_chat/op/attention.py @@ -16,11 +16,12 @@ WARN_FLASHINFER_HEAD_DIM = False -def attention( # pylint: disable=invalid-name,too-many-locals +def attention( # pylint: disable=invalid-name,too-many-locals,too-many-statements q: nn.Tensor, k: nn.Tensor, v: nn.Tensor, casual_mask: nn.Tensor, + attn_score_scaling_factor: float = 1.0, ) -> nn.Tensor: """Attention with casual mask. 
@@ -47,7 +48,7 @@ def attention( # pylint: disable=invalid-name,too-many-locals v = v.repeat(h_q // h_kv, axis=1) q -> [b, h, s, d] k, v -> [b, h, t, d] - attn = q @ k^T / sqrt(d) # [b, h, s, t] + attn = q @ k^T / sqrt(d) * attn_score_scaling_factor # [b, h, s, t] attn = softmax_with_mask(attn, casual_mask, axis=-1) o = attn @ v # [b, h, s, d] o -> [b, s, h * d] @@ -67,13 +68,15 @@ def _fallback(): if h_kv != h_q: k = k.repeat(h_q // h_kv, axis=2) v = v.repeat(h_q // h_kv, axis=2) - q = q.permute_dims([0, 2, 1, 3]) - k = k.permute_dims([0, 2, 1, 3]) - v = v.permute_dims([0, 2, 1, 3]) + q = op.permute_dims(q, [0, 2, 1, 3]) + k = op.permute_dims(k, [0, 2, 1, 3]) + v = op.permute_dims(v, [0, 2, 1, 3]) attn_weights = op.matmul( # [b, h, s, t] q, # [b, h, s, d] - k.permute_dims([0, 1, 3, 2]), # [b, h, d, t] + op.permute_dims(k, [0, 1, 3, 2]), # [b, h, d, t] ) / math.sqrt(d) + if attn_score_scaling_factor != 1.0: + attn_weights = attn_weights * attn_score_scaling_factor dtype = attn_weights.dtype attn_weights = attn_weights.maximum(tir.min_value(dtype)).minimum(casual_mask) if dtype == "float32": @@ -81,13 +84,14 @@ def _fallback(): else: attn_weights = op.softmax(attn_weights.astype("float32"), axis=-1).astype(dtype) output = op.matmul(attn_weights, v) # [b, h, s, d] <= [b, h, s, t] x [b, h, t, d] - output = output.permute_dims([0, 2, 1, 3]) # [b, s, h, d] - output = output.reshape([b, s, h_q * d]) # [b, s, h * d] + output = op.permute_dims(output, [0, 2, 1, 3]) # [b, s, h, d] + output = op.reshape(output, [b, s, h_q * d]) # [b, s, h * d] return output # FlashInfer Implementation if ( _extern.get_store().flashinfer + and attn_score_scaling_factor == 1.0 and q.dtype == "float16" and k.dtype == "float16" and v.dtype == "float16" diff --git a/python/mlc_chat/op/kv_cache.py b/python/mlc_chat/op/kv_cache.py deleted file mode 100644 index 901b7b6fe1..0000000000 --- a/python/mlc_chat/op/kv_cache.py +++ /dev/null @@ -1,100 +0,0 @@ -"""Operators for KV cache manipulations.""" -# 
pylint: disable=too-many-locals -from tvm.script import tir as T - - -def kv_cache_transpose_append(num_key_value_heads, head_dim, dtype): - """Return the TIR function that appends new k/v data to PagedKVCache.""" - - @T.prim_func - def tir_kv_cache_transpose_append( - var_pages: T.handle, - var_k_data: T.handle, - var_v_data: T.handle, - var_position_map: T.handle, - ): - T.func_attr({"tir.noalias": T.bool(True)}) - ntoken = T.SizeVar("ntoken", "int64") - page_size = T.SizeVar("page_size", "int64") - num_pages = T.int64() - - pages = T.match_buffer( - var_pages, (num_pages, 2, num_key_value_heads, page_size, head_dim), dtype - ) - k_data = T.match_buffer(var_k_data, (ntoken, num_key_value_heads, head_dim), dtype) - v_data = T.match_buffer(var_v_data, (ntoken, num_key_value_heads, head_dim), dtype) - position_map = T.match_buffer(var_position_map, (ntoken,), "int32") - - for global_pos, h, f in T.grid(ntoken, num_key_value_heads, head_dim): - with T.block("k_transpose_append"): - vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) - T.reads(position_map[vgpos], k_data[vgpos, vh, vf]) - T.writes( - pages[ - position_map[vgpos] // page_size, 0, vh, position_map[vgpos] % page_size, vf - ] - ) - position: T.int32 = position_map[vgpos] # type: ignore - pages[ - T.floordiv(position, page_size), 0, vh, T.floormod(position, page_size), vf - ] = k_data[vgpos, vh, vf] - with T.block("v_transpose_append"): - vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) - T.reads(position_map[vgpos], k_data[vgpos, vh, vf]) - T.writes( - pages[ - position_map[vgpos] // page_size, 1, vh, position_map[vgpos] % page_size, vf - ] - ) - position: T.int32 = position_map[vgpos] # type: ignore - pages[ - T.floordiv(position, page_size), 1, vh, T.floormod(position, page_size), vf - ] = v_data[vgpos, vh, vf] - - return tir_kv_cache_transpose_append - - -def kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype): - """Return the TIR function that fetches the k/v data 
on given positions and layer.""" - - @T.prim_func - def tir_kv_cache_debug_get_kv( - var_pages: T.handle, - var_position_map: T.handle, - var_k_data: T.handle, - var_v_data: T.handle, - layer_id: T.int64, - ): - T.func_attr({"tir.noalias": T.bool(True)}) - seqlen = T.SizeVar("seqlen", "int64") - page_size = T.SizeVar("page_size", "int64") - num_pages = T.int64() - - pages = T.match_buffer( - var_pages, (num_pages, 2, num_key_value_heads, page_size, head_dim), dtype - ) - position_map = T.match_buffer(var_position_map, (seqlen,), "int32") - k_data = T.match_buffer( - var_k_data, (num_hidden_layers, seqlen, num_key_value_heads, head_dim), dtype - ) - v_data = T.match_buffer( - var_v_data, (num_hidden_layers, seqlen, num_key_value_heads, head_dim), dtype - ) - - for p, h, d in T.grid(seqlen, num_key_value_heads, head_dim): - with T.block("copy0"): - vp, vh, vd = T.axis.remap("SSS", [p, h, d]) - T.reads( - position_map[vp], - pages[position_map[vp] // page_size, 0:2, vh, position_map[vp] % page_size, vd], - ) - T.writes(k_data[layer_id, vp, vh, vd], v_data[layer_id, vp, vh, vd]) - position: T.int32 = position_map[vp] - k_data[layer_id, vp, vh, vd] = pages[ - T.floordiv(position, page_size), 0, vh, T.floormod(position, page_size), vd - ] - v_data[layer_id, vp, vh, vd] = pages[ - T.floordiv(position, page_size), 1, vh, T.floormod(position, page_size), vd - ] - - return tir_kv_cache_debug_get_kv diff --git a/python/mlc_chat/op/moe_matmul.py b/python/mlc_chat/op/moe_matmul.py new file mode 100644 index 0000000000..8e63fd820b --- /dev/null +++ b/python/mlc_chat/op/moe_matmul.py @@ -0,0 +1,548 @@ +"""Mixture of Experts operators""" +from tvm import DataType, tir +from tvm.relax.frontend.nn import Tensor, op +from tvm.script import tir as T + +# mypy: disable-error-code="attr-defined,valid-type,name-defined" +# pylint: disable=too-many-locals,invalid-name,too-many-arguments,too-many-statements + + +def gemv(x: Tensor, w: Tensor, indptr: Tensor) -> Tensor: + """GEMV for 
project-in (e1-e3) or project-out (e2) in MLP. + + Parameters + ---------- + x : Tensor + For project-in, the input tensor of shape (1, in_features); and for project-out, the input + shape is (experts_per_tok, in_features), where `experts_per_tok` is the number of activated + experts per token. + + w : Tensor + The weight tensor of shape (local_experts, out_features, in_features), where `local_experts` + is the total number of experts. + + indptr : Tensor + The index pointer tensor of shape (1, experts_per_tok), where `experts_per_tok` is the + number of activated experts per token. + + Returns + ------- + out : Tensor + The output tensor of shape (experts_per_tok, out_features), where `experts_per_tok` is the + number of activated experts per token. + """ + (local_experts, out_features, in_features), dtype = w.shape, w.dtype + _, experts_per_tok = indptr.shape + x_leading_dim, _ = x.shape + + def access_x(x, e, j): + return x[0, j] if x_leading_dim == 1 else x[e, j] + + # NOTE: Currently it assumes x.dtype == w.dtype, but the constraint can be relaxed easily. 
+ assert w.shape == [local_experts, out_features, in_features] and w.dtype == dtype + assert x.shape == [x_leading_dim, in_features] and x.dtype == dtype + assert indptr.shape == [1, experts_per_tok] and indptr.dtype == "int32" + assert x_leading_dim in [1, experts_per_tok] + + @T.prim_func(private=True) + def _func( + x: T.Buffer((x_leading_dim, in_features), dtype), + w: T.Buffer((local_experts, out_features, in_features), dtype), + indptr: T.Buffer((1, experts_per_tok), "int32"), + o: T.Buffer((experts_per_tok, out_features), dtype), + ): + T.func_attr({"op_pattern": 4, "tir.noalias": True}) # kOutEWiseFusable + for e in T.thread_binding(experts_per_tok, thread="blockIdx.y"): + with T.block("gemv_o"): + e = T.axis.spatial(experts_per_tok, e) + T.reads(x[:, :], w[indptr[0, e], :, :], indptr[0, e]) + T.writes(o[e, :]) + for i1, i2 in T.grid(out_features, in_features): + with T.block("gemv"): + i, j = T.axis.remap("SR", [i1, i2]) + with T.init(): + o[e, i] = T.cast(T.float16(0), dtype) + o[e, i] += access_x(x, e, j) * w[indptr[0, e], i, j] + + return op.tensor_ir_op( + _func, + "moe_gemv", + args=[x, w, indptr], + out=Tensor.placeholder([experts_per_tok, out_features], dtype), + ) + + +def dequantize_gemv( # pylint: disable=too-many-arguments + x: Tensor, + w: Tensor, + scale: Tensor, + indptr: Tensor, + quantize_dtype: str, + group_size: int, +) -> Tensor: + """GEMV for project-in (e1-e3) or project-out (e2) in MLP but the weight is quantized. + It needs to be dequantized before the GEMV computation. + + Parameters + ---------- + x : Tensor + For project-in, the input tensor of shape (1, in_features); and for project-out, the input + shape is (experts_per_tok, in_features), where `experts_per_tok` is the number of activated + experts per token. + + w : Tensor + The quantized weight tensor of shape (local_experts, out_features, in_features // n), + where n is the number of elements per storage dtype, e.g. 
if the storage dtype is uint32, + and the quantize dtype is int4, then n is 8. + `local_experts` is the total number of experts including activated and non-active ones. + + scale : Tensor + The scale tensor of shape (local_experts, out_features, in_features // group_size), where + `local_experts` is the total number of experts including activated and non-active ones. + + indptr : Tensor + The index pointer tensor of shape (1, experts_per_tok), where `experts_per_tok` is the + number of activated experts per token. + + quantize_dtype : str + The quantize dtype of the weight tensor, which is usually int3, int4 or fp8, etc. + + group_size : int + The number of elements in each quantization group, e.g. 32 or 128. + + Returns + ------- + out : Tensor + The output tensor of shape (experts_per_tok, out_features), where `experts_per_tok` is the + number of activated experts per token. + """ + (x_leading_dim, in_features), model_dtype = x.shape, x.dtype + (local_experts, out_features, _), storage_dtype = w.shape, w.dtype + _, experts_per_tok = indptr.shape + quantize_dtype_bits = DataType(quantize_dtype).bits + num_elem_per_storage = DataType(storage_dtype).bits // quantize_dtype_bits + num_group = (in_features + group_size - 1) // group_size + num_storage = group_size // num_elem_per_storage * num_group + + def _dequantize(w, s, e, i, j): + tir_bin_mask = tir.const((2**quantize_dtype_bits) - 1, storage_dtype) + tir_max_int = tir.const((2 ** (quantize_dtype_bits - 1)) - 1, model_dtype) + w = w[e, i, j // num_elem_per_storage] + s = s[e, i, j // group_size] + shift = (j % num_elem_per_storage * quantize_dtype_bits).astype(storage_dtype) + w = tir.bitwise_and(tir.shift_right(w, shift), tir_bin_mask).astype(model_dtype) + return (w - tir_max_int) * s + + def access_x(x, e, j): + return x[0, j] if x_leading_dim == 1 else x[e, j] + + assert x.shape == [x_leading_dim, in_features] and x.dtype == model_dtype + assert w.shape == [local_experts, out_features, num_storage] and 
w.dtype == storage_dtype + assert scale.shape == [local_experts, out_features, num_group] and scale.dtype == model_dtype + assert indptr.shape == [1, experts_per_tok] and indptr.dtype == "int32" + assert x_leading_dim in [1, experts_per_tok] + + @T.prim_func(private=True) + def _func( + x: T.Buffer((x_leading_dim, in_features), model_dtype), + w: T.Buffer((local_experts, out_features, num_storage), storage_dtype), + scale: T.Buffer((local_experts, out_features, num_group), model_dtype), + indptr: T.Buffer((1, experts_per_tok), "int32"), + o: T.Buffer((experts_per_tok, out_features), model_dtype), + ): + T.func_attr({"op_pattern": 4, "tir.noalias": True}) # kOutEWiseFusable + for expert_id in T.thread_binding(experts_per_tok, thread="blockIdx.y"): + with T.block("gemv_o"): + e = T.axis.spatial(experts_per_tok, expert_id) + y = T.alloc_buffer((out_features, in_features), model_dtype) + for i1, i2 in T.grid(out_features, in_features): + with T.block("dequantize"): + i, j = T.axis.remap("SS", [i1, i2]) + y[i, j] = _dequantize(w, scale, indptr[0, e], i, j) + for i1, i2 in T.grid(out_features, in_features): + with T.block("gemv"): + i, j = T.axis.remap("SR", [i1, i2]) + with T.init(): + o[e, i] = T.cast(T.float16(0), model_dtype) + o[e, i] += access_x(x, e, j) * y[i, j] + + return op.tensor_ir_op( + _func, + "moe_dequantize_gemv", + args=[x, w, scale, indptr], + out=Tensor.placeholder([experts_per_tok, out_features], model_dtype), + ) + + +def group_gemm(x: Tensor, w: Tensor, indptr: Tensor): # pylint: disable=too-many-statements + """Group GEMM in MoE models. + + Parameters + ---------- + x : Tensor + Input tensor of shape (batch_size, in_features), where `batch_size` could be dynamic shape. + + w : Tensor + Weight tensor of shape (num_local_experts, out_features, in_features). + `w[i, :, :]` is the weight matrix for the `i`-th local expert. + + indptr : Tensor + Index pointer tensor of shape (num_local_experts + 1, ). 
+ `x[indptr[i] : indptr[i + 1]]` is the input for the `i`-th local expert. + + Returns + ------- + out : Tensor + Output tensor of shape (batch_size, out_features). + """ + # NOTE: Currently it assumes x.dtype == w.dtype, but the constraint can be relaxed easily. + (num_local_experts, out_features, in_features), dtype = w.shape, w.dtype + + assert x.shape[1:] == [in_features] and x.dtype == dtype + assert indptr.shape == [num_local_experts + 1] and indptr.dtype == "int32" + + Ne, N, K = num_local_experts, out_features, in_features + BLK_M, BLK_N, BLK_K = 8, 128, 32 + TX, TY, CTA_COUNT = 8, 32, 1024 + VEC_X, VEC_W, VEC_O, VEC_DOT = 1, 1, 1, 1 + UNROLL = 64 + STORAGE_ALIGN = False + assert BLK_K % 8 == 0 + tiles_per_row = (N + BLK_N - 1) // BLK_N + zero = tir.const(0, dtype) + + @T.prim_func(private=True) + def _func( # pylint: disable=too-many-statements + var_x: T.handle, + var_w: T.handle, + var_indptr: T.handle, + var_o: T.handle, + ): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": True}) + B = T.int32(is_size_var=True) + X = T.match_buffer(var_x, (B, K), dtype) + W = T.match_buffer(var_w, (Ne, N, K), dtype) + indptr = T.match_buffer(var_indptr, (Ne + 1,), "int32") + O = T.match_buffer(var_o, (B, N), dtype) + + for _bx in T.thread_binding(CTA_COUNT, thread="blockIdx.x"): + with T.block("CTA"): + bx = T.axis.spatial(CTA_COUNT, _bx) + T.reads(indptr[:], X[:, :], W[:, :, :]) + T.writes(O[:, :]) + # pylint: disable=redefined-builtin + sum = T.alloc_buffer((2,), "int32", scope="local") + row = T.alloc_buffer((2,), "int32", scope="local") + cur_e = T.alloc_buffer((1,), "int32", scope="local") + tile_id = T.alloc_buffer((1,), "int32", scope="local") + # pylint: enable=redefined-builtin + sum[0] = 0 + sum[1] = T.ceildiv(indptr[1] - indptr[0], BLK_M) * tiles_per_row + row[0] = 0 + row[1] = indptr[1] - indptr[0] + cur_e[0] = 0 + tile_id[0] = bx + while T.tvm_thread_invariant(cur_e[0] < Ne): # pylint: disable=no-member + # move to the current group + while sum[1] <=
tile_id[0] and cur_e[0] < Ne: + cur_e[0] += 1 + if cur_e[0] < Ne: + e: T.int32 = cur_e[0] + delta: T.int32 = indptr[e + 1] - indptr[e] + sum[0] = sum[1] + sum[1] += T.ceildiv(delta, BLK_M) * tiles_per_row + row[0] = row[1] + row[1] += delta + # sync threads to make sure all threads have the same tile position + T.tvm_storage_sync("shared") + if T.tvm_thread_invariant(cur_e[0] < Ne): # pylint: disable=no-member + # fetch current tile position + e: T.int32 = cur_e[0] # type: ignore[no-redef] + num_tiles: T.int32 = tile_id[0] - sum[0] + m_offset: T.int32 = BLK_M * T.floordiv(num_tiles, tiles_per_row) + row[0] + n_offset: T.int32 = BLK_N * T.floormod(num_tiles, tiles_per_row) + with T.block("gemm"): + T.reads( + row[1], + X[m_offset : m_offset + BLK_M, :], + W[e, n_offset : n_offset + BLK_N, :], + ) + T.writes(O[m_offset : m_offset + BLK_M, n_offset : n_offset + BLK_N]) + X_tile = T.alloc_buffer((BLK_M, K), dtype, scope="shared") + W_tile = T.alloc_buffer((BLK_N, K), dtype, scope="shared") + O_tile = T.alloc_buffer((BLK_M, BLK_N), dtype, scope="local") + for a0, a1 in T.grid(BLK_M, K): + with T.block("X_shared"): + i, j = T.axis.remap("SS", [a0, a1]) + X_tile[i, j] = T.if_then_else( + m_offset + i < row[1], + X[m_offset + i, j], + zero, + ) + for a0, a1 in T.grid(BLK_N, K): + with T.block("W_shared"): + i, j = T.axis.remap("SS", [a0, a1]) + W_tile[i, j] = T.if_then_else( + n_offset + i < N, + W[e, n_offset + i, j], + zero, + ) + for a0, a1, a2 in T.grid(BLK_M, BLK_N, K): + with T.block("compute"): + i, j, k = T.axis.remap("SSR", [a0, a1, a2]) + with T.init(): + O_tile[i, j] = zero + O_tile[i, j] += X_tile[i, k] * W_tile[j, k] + for a0, a1 in T.grid(BLK_M, BLK_N): + with T.block("store"): + i, j = T.axis.remap("SS", [a0, a1]) + if m_offset + i < row[1] and n_offset + j < N: + O[m_offset + i, n_offset + j] = O_tile[i, j] + # move to next tile + tile_id[0] += CTA_COUNT + + def _schedule(): + sch = tir.Schedule(_func) + + def _cooperative_fetch(block, vec_len): + num_loops 
= len(sch.get_loops(block)) + sch.compute_at(block, ko, preserve_unit_loops=True) + loops = sch.get_loops(block)[-num_loops:] + ty, tx, _, vec = sch.split( + sch.fuse(*loops), + factors=[TY, TX, None, vec_len], + ) + sch.vectorize(vec) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + if STORAGE_ALIGN: + sch.storage_align(block, 0, axis=1, factor=8, offset=vec_len) + return block + + main_block = sch.get_block("compute") + x, y, k = sch.get_loops(main_block) + ty, yi = sch.split(y, [TY, None]) + tx, xi, vec_c = sch.split(x, [TX, None, VEC_DOT]) + ko, ki = sch.split(k, factors=[None, BLK_K]) + sch.reorder(ty, tx, ko, ki, yi, xi, vec_c) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + sch.vectorize(vec_c) + if UNROLL > 0: + sch.annotate(tx, ann_key="pragma_auto_unroll_max_step", ann_val=UNROLL) + sch.annotate(tx, ann_key="pragma_unroll_explicit", ann_val=1) + l2g = sch.get_block("store") + sch.reverse_compute_at(l2g, tx, preserve_unit_loops=True) + _, v = sch.split(sch.get_loops(l2g)[-1], [None, VEC_O]) + sch.vectorize(v) + _cooperative_fetch(sch.get_block("X_shared"), vec_len=VEC_X) + _cooperative_fetch(sch.get_block("W_shared"), vec_len=VEC_W) + sch.decompose_reduction(main_block, ko) + return sch.mod["main"] + + return op.tensor_ir_op( + _schedule(), + "group_gemm", + args=[x, w, indptr], + out=Tensor.placeholder([x.shape[0], out_features], dtype), + ) + + +def dequantize_group_gemm( + x: Tensor, + w: Tensor, + scale: Tensor, + indptr: Tensor, + quantize_dtype: str, + group_size: int, +): + """Group GEMM in MoE models but the weight is quantized. + + Parameters + ---------- + x : Tensor + Input tensor of shape (batch_size, in_features), where `batch_size` could be dynamic shape. + + w : Tensor + Weight tensor of shape (num_local_experts, out_features, in_features // n), where n is the + number of elements per storage dtype, e.g. if the storage dtype is uint32, and the quantize + dtype is int4, then n is 8. 
+ + scale : Tensor + The scale tensor of shape (num_local_experts, out_features, in_features // group_size). + + indptr : Tensor + Index pointer tensor of shape (num_local_experts + 1, ). `x[indptr[i] : indptr[i + 1]]` is + the input for the `i`-th local expert. + + group_size : int + The number of elements in each quantization group, e.g. 32 or 128. + + quantize_dtype : str + The quantize dtype of the weight tensor, which is usually int3, int4 or fp8, etc. + + Returns + ------- + out : Tensor + Output tensor of shape (batch_size, out_features). + """ + (_, in_features), model_dtype = x.shape, x.dtype + (num_local_experts, out_features, _), storage_dtype = w.shape, w.dtype + quantize_dtype_bits = DataType(quantize_dtype).bits + num_elem_per_storage = DataType(storage_dtype).bits // quantize_dtype_bits + num_group = (in_features + group_size - 1) // group_size + num_storage = group_size // num_elem_per_storage * num_group + + def _dequantize(w, s, e, i, j): + tir_bin_mask = tir.const((1 << quantize_dtype_bits) - 1, storage_dtype) + tir_max_int = tir.const((2 ** (quantize_dtype_bits - 1)) - 1, model_dtype) + w = w[e, i, j // num_elem_per_storage] + s = s[e, i, j // group_size] + shift = (j % num_elem_per_storage * quantize_dtype_bits).astype(storage_dtype) + w = tir.bitwise_and(tir.shift_right(w, shift), tir_bin_mask).astype(model_dtype) + return (w - tir_max_int) * s + + Ne, N, K = num_local_experts, out_features, in_features + BLK_M, BLK_N, BLK_K = 8, 128, 32 + TX, TY, CTA_COUNT = 8, 32, 1024 + VEC_X, VEC_W, VEC_O, VEC_DOT = 1, 1, 1, 1 + UNROLL = 64 + STORAGE_ALIGN = False + assert BLK_K % 8 == 0 + tiles_per_row = (N + BLK_N - 1) // BLK_N + zero = tir.const(0, model_dtype) + + @T.prim_func(private=True) + def _func( + var_x: T.handle, + w: T.Buffer((Ne, N, num_storage), storage_dtype), + scale: T.Buffer((Ne, N, num_group), model_dtype), + indptr: T.Buffer((Ne + 1), "int32"), + var_o: T.handle, + ): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": True}) + B =
T.int32(is_size_var=True) + X = T.match_buffer(var_x, (B, K), model_dtype) + O = T.match_buffer(var_o, (B, N), model_dtype) + for _bx in T.thread_binding(CTA_COUNT, thread="blockIdx.x"): + with T.block("CTA"): + bx = T.axis.spatial(CTA_COUNT, _bx) + T.reads(X[:, :], w[:, :, :], scale[:, :, :], indptr[:]) + T.writes(O[:, :]) + # pylint: disable=redefined-builtin + sum = T.alloc_buffer((2,), "int32", scope="local") + row = T.alloc_buffer((2,), "int32", scope="local") + cur_e = T.alloc_buffer((1,), "int32", scope="local") + tile_id = T.alloc_buffer((1,), "int32", scope="local") + # pylint: enable=redefined-builtin + sum[0] = 0 + sum[1] = T.ceildiv(indptr[1] - indptr[0], BLK_M) * tiles_per_row + row[0] = 0 + row[1] = indptr[1] - indptr[0] + cur_e[0] = 0 + tile_id[0] = bx + while T.tvm_thread_invariant(cur_e[0] < Ne): # pylint: disable=no-member + # move to the current group + while sum[1] <= tile_id[0] and cur_e[0] < Ne: + cur_e[0] += 1 + if cur_e[0] < Ne: + e: T.int32 = cur_e[0] + delta: T.int32 = indptr[e + 1] - indptr[e] + sum[0] = sum[1] + sum[1] += T.ceildiv(delta, BLK_M) * tiles_per_row + row[0] = row[1] + row[1] += delta + # sync threads to make sure all threads have the same tile position + T.tvm_storage_sync("shared") + if T.tvm_thread_invariant(cur_e[0] < Ne): # pylint: disable=no-member + # fetch current tile position + e: T.int32 = cur_e[0] # type: ignore[no-redef] + num_tiles: T.int32 = tile_id[0] - sum[0] + m_offset: T.int32 = T.floordiv(num_tiles, tiles_per_row) * BLK_M + row[0] + n_offset: T.int32 = T.floormod(num_tiles, tiles_per_row) * BLK_N + with T.block("gemm"): + T.reads( + row[1], + X[m_offset : m_offset + BLK_M, :], + w[e, n_offset : n_offset + BLK_N, :], + scale[e, n_offset : n_offset + BLK_N, :], + ) + T.writes(O[m_offset : m_offset + BLK_M, n_offset : n_offset + BLK_N]) + X_tile = T.alloc_buffer((BLK_M, K), model_dtype, scope="shared") + W_tile = T.alloc_buffer((BLK_N, K), model_dtype, scope="shared") + O_tile = T.alloc_buffer((BLK_M, BLK_N), 
"float32", scope="local") + for a0, a1 in T.grid(BLK_M, K): + with T.block("X_shared"): + i, j = T.axis.remap("SS", [a0, a1]) + X_tile[i, j] = T.if_then_else( + m_offset + i < row[1], + X[m_offset + i, j], + zero, + ) + for a0, a1 in T.grid(BLK_N, K): + with T.block("W_shared"): + i, j = T.axis.remap("SS", [a0, a1]) + W_tile[i, j] = T.if_then_else( + n_offset + i < N, + _dequantize(w, scale, e, n_offset + i, j), + zero, + ) + for a0, a1, a2 in T.grid(BLK_M, BLK_N, K): + with T.block("compute"): + i, j, k = T.axis.remap("SSR", [a0, a1, a2]) + with T.init(): + O_tile[i, j] = zero + O_tile[i, j] += X_tile[i, k] * W_tile[j, k] + for a0, a1 in T.grid(BLK_M, BLK_N): + with T.block("store"): + i, j = T.axis.remap("SS", [a0, a1]) + if m_offset + i < row[1] and n_offset + j < N: + O[m_offset + i, n_offset + j] = O_tile[i, j] + # move to next tile + tile_id[0] += CTA_COUNT + + def _schedule(): + sch = tir.Schedule(_func) + + def _cooperative_fetch(block, vec_len): + num_loops = len(sch.get_loops(block)) + sch.compute_at(block, ko, preserve_unit_loops=True) + loops = sch.get_loops(block)[-num_loops:] + ty, tx, _, vec = sch.split( + sch.fuse(*loops), + factors=[TY, TX, None, vec_len], + ) + sch.vectorize(vec) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + if STORAGE_ALIGN: + sch.storage_align(block, 0, axis=1, factor=8, offset=vec_len) + return block + + main_block = sch.get_block("compute") + x, y, k = sch.get_loops(main_block) + ty, yi = sch.split(y, [TY, None]) + tx, xi, vec_c = sch.split(x, [TX, None, VEC_DOT]) + ko, ki = sch.split(k, factors=[None, BLK_K]) + sch.reorder(ty, tx, ko, ki, yi, xi, vec_c) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + sch.vectorize(vec_c) + if UNROLL > 0: + sch.annotate(tx, ann_key="pragma_auto_unroll_max_step", ann_val=UNROLL) + sch.annotate(tx, ann_key="pragma_unroll_explicit", ann_val=1) + l2g = sch.get_block("store") + sch.reverse_compute_at(l2g, tx, preserve_unit_loops=True) + _, v = 
sch.split(sch.get_loops(l2g)[-1], [None, VEC_O]) + sch.vectorize(v) + _cooperative_fetch(sch.get_block("X_shared"), vec_len=VEC_X) + _cooperative_fetch(sch.get_block("W_shared"), vec_len=VEC_W) + sch.decompose_reduction(main_block, ko) + return sch.mod["main"] + + return op.tensor_ir_op( + _schedule(), + "dequantize_group_gemm", + args=[x, w, scale, indptr], + out=Tensor.placeholder([x.shape[0], out_features], model_dtype), + ) diff --git a/python/mlc_chat/op/moe_misc.py b/python/mlc_chat/op/moe_misc.py new file mode 100644 index 0000000000..b13dda1c5d --- /dev/null +++ b/python/mlc_chat/op/moe_misc.py @@ -0,0 +1,350 @@ +"""Mixture of Experts operators""" +from functools import reduce +from typing import Tuple, Union + +from tvm import te, tir +from tvm.relax.frontend.nn import Tensor, op +from tvm.script import tir as T +from tvm.target import Target +from tvm.topi.cuda.scan import inclusive_scan +from tvm.topi.cuda.sort import topk as topi_topk + +# mypy: disable-error-code="attr-defined,name-defined" +# pylint: disable=line-too-long,too-many-locals,invalid-name + + +def moe_sum(x: Tensor, dim: int) -> Tensor: + """Compute the sum of the input tensor along the given axis. It is specialized for the MoE + case where `x.ndim == 3` and `x.shape[1] == num_experts_per_tok (which is 2)`. + """ + if x.ndim == 3 and x.shape[1] == 2: + return op.tensor_expr_op( + lambda x: te.compute( + (x.shape[0], x.shape[2]), + lambda i, j: x[i, 0, j] + x[i, 1, j], + name="sum_2", + ), + "sum", + args=[x], + ) + return op.sum(x, axis=dim) + + +def topk(x: Tensor, k: int) -> Tuple[Tensor, Tensor]: + """Top-k operator specialized for MoE usecases. + + Parameters + ---------- + x : Tensor + The input tensor with shape [batch_size, num_local_experts]. + + k : int + The number of top elements to be selected, which is `num_experts_per_tok` in MoE. + + Returns + ------- + values : Tensor + The top-k values with shape [batch_size, k]. 
+ + indices : Tensor + The top-k indices with shape [batch_size, k]. + """ + (batch_size, num_local_experts), dtype = x.shape, x.dtype + index_dtype = "int32" + + TX = 1024 + SCAN_LEN = 2 + + # specialized kernel for top 2 case + @T.prim_func(private=True) + def topk_func( + var_x: T.handle, + var_out: T.handle, + var_out_index: T.handle, + ) -> None: + T.func_attr({"tir.noalias": True, "tir.is_scheduled": True}) + batch_size = T.int64() + x = T.match_buffer(var_x, (batch_size, num_local_experts), dtype) + out = T.match_buffer(var_out, (batch_size, SCAN_LEN), dtype) + out_index = T.match_buffer(var_out_index, (batch_size, SCAN_LEN), index_dtype) + local_top_k = T.alloc_buffer((SCAN_LEN,), dtype=dtype, scope="local") + local_top_k_index = T.alloc_buffer((SCAN_LEN,), dtype=index_dtype, scope="local") + for io in T.thread_binding(0, T.ceildiv(batch_size, TX), "blockIdx.x"): + for ii in T.thread_binding(0, T.min(batch_size, TX), "threadIdx.x"): + with T.block("top_k"): + vi = T.axis.spatial(batch_size, io * TX + ii) + T.where(io * TX + ii < batch_size) + with T.block("init"): + local_top_k[0] = T.min_value(dtype) + local_top_k_index[0] = 0 + for k in range(num_local_experts): + with T.block("update"): + vk = T.axis.remap("S", [k]) + # N.B. 
This snippet is specialized for k = 2 + if x[vi, vk] > local_top_k[0]: + local_top_k[1] = local_top_k[0] + local_top_k_index[1] = local_top_k_index[0] + local_top_k[0] = x[vi, vk] + local_top_k_index[0] = vk + elif x[vi, vk] > local_top_k[1]: + local_top_k[1] = x[vi, vk] + local_top_k_index[1] = vk + for j in T.unroll(SCAN_LEN): + with T.block("output"): + vj = T.axis.remap("S", [j]) + out[vi, vj] = local_top_k[vj] + out_index[vi, vj] = local_top_k_index[vj] + + if k == 2: + return op.tensor_ir_op( + topk_func, + "top2", + args=[x], + out=( + Tensor.placeholder([batch_size, 2], dtype), + Tensor.placeholder([batch_size, 2], index_dtype), + ), + ) + return op.tensor_expr_op(topi_topk, "topk", args=[x, k, -1, "both", False, index_dtype]) # type: ignore[list-item] + + +def moe_cumsum(expert_indices: Tensor, num_local_experts: int) -> Tensor: + """An operator that returns the cumsum array in MoE. + + The input `expert_indices` of shape [batch_size, experts_per_tok] indicates the indices of + the activated experts for each instance in a batch. This operator first converts it to + `expert_mask`, a boolean mask with shape [batch_size, num_local_experts], and then computes + cumsum over the transpose-then-flattened array of `expert_mask`. + + A position `(e, b)` in the result `cumsum`, where `e` is the expert id and `b` is the batch id, + indicates a shuffling plan that moves the `b`-th instance so that the inputs to the `e`-th + expert are contiguous. + + Parameters + ---------- + expert_indices : Tensor + The topk indices with shape [batch_size, experts_per_tok], int32, where + `experts_per_tok` is the number of activated experts. + + num_local_experts : int + The total number of experts. + + Returns + ------- + cumsum : Tensor + The inclusive prefix sum with shape [batch_size * num_local_experts], int32.
def get_indices(cumsum: Tensor, expert_indices: Tensor) -> Tensor:
    """Compute the shuffling plan that makes the inputs of each expert contiguous.

    The result is a 1D tensor. If `indices[i] == b * experts_per_tok + j`, the `b`-th instance
    in the batch is moved to the `i`-th position of the shuffled layout, and
    `expert_indices[b, j]` is the expert that owns position `i`.

    Effectively it is equivalent to the following Python code:

    .. code-block:: python

        for b in range(batch_size):
            for j in range(experts_per_tok):
                e = expert_indices[b, j]
                indices[cumsum[e * batch_size + b] - 1] = b * experts_per_tok + j

    Parameters
    ----------
    cumsum : Tensor
        A flattened 1D int32 tensor whose original shape is [num_local_experts, batch_size];
        it is indexed as `cumsum[e * batch_size + b]` with `e` an expert id.

    expert_indices : Tensor
        The indices of the experts with shape [batch_size, experts_per_tok].

    Returns
    -------
    indices : Tensor
        The shuffling plan with shape [batch_size * experts_per_tok], int32.
    """
    TX = 1024  # threads per block
    batch_size, experts_per_tok = expert_indices.shape

    @T.prim_func(private=True)
    def _func(var_cumsum: T.handle, var_expert_indices: T.handle, var_indices: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": True})
        batch_size = T.SizeVar("batch_size", "int32")
        cumsum_len = T.SizeVar("cumsum_len", "int32")  # [num_local_experts * batch_size]
        cumsum = T.match_buffer(var_cumsum, [cumsum_len], "int32")
        expert_indices = T.match_buffer(var_expert_indices, [batch_size, experts_per_tok], "int32")
        indices = T.match_buffer(var_indices, [batch_size * experts_per_tok], "int32")
        # One thread per (b, j) pair. The launch bound and the guard must therefore be
        # `batch_size * experts_per_tok`, not `cumsum_len`: the latter is
        # `num_local_experts * batch_size`, and using it lets surplus threads compute
        # `b >= batch_size` and read `expert_indices` out of bounds.
        for bj_o in T.thread_binding(0, T.ceildiv(batch_size * experts_per_tok, TX), "blockIdx.x"):
            for bj_i in T.thread_binding(0, TX, "threadIdx.x"):
                with T.block("indices"):
                    T.reads(expert_indices[:, :], cumsum[:])
                    T.writes(indices[:])
                    if bj_o * TX + bj_i < batch_size * experts_per_tok:
                        b: T.int32 = T.floordiv(bj_o * TX + bj_i, experts_per_tok)
                        j: T.int32 = T.floormod(bj_o * TX + bj_i, experts_per_tok)
                        e: T.int32 = expert_indices[b, j]
                        indices[cumsum[e * batch_size + b] - 1] = b * experts_per_tok + j

    return op.tensor_ir_op(
        _func,
        "get_flattened_expert_indices",
        args=[cumsum, expert_indices],
        out=Tensor.placeholder([batch_size * experts_per_tok], "int32"),
    )


def get_indptr(cumsum: Tensor, num_local_experts: int, batch_size: Union[int, tir.Var]) -> Tensor:
    """Extract the `indptr` array from MoE cumsum array.

    The MoE cumsum array is a flattened tensor whose original shape is
    [num_local_experts, batch_size], and the `indptr` array is a 1D tensor of length
    `num_local_experts + 1`. The range `[indptr[i], indptr[i + 1])` indicates instances in
    the batch that corresponds to the `i`-th expert.

    Effectively, this operator is equivalent to the following numpy code:

    .. code-block:: python

        indptr = np.zeros(num_local_experts + 1, dtype=np.int32)
        indptr[0] = 0
        for i in range(1, num_local_experts + 1):
            indptr[i] = cumsum[i * batch_size - 1]
        return indptr

    Parameters
    ----------
    cumsum : Tensor
        The prefix sum of the sparse array with shape [batch_size * num_local_experts], int32.

    num_local_experts : int
        The number of experts.

    batch_size : int | tir.Var
        The batch size. Note that the batch size here refers to `batch_size * seq_len` in MoE,
        and we name it `batch_size` for simplicity here only because the two dimensions are
        fused in Mixtral.

    Returns
    -------
    indptr : Tensor
        The `indptr` array with shape [num_local_experts + 1], int32.
    """

    @T.prim_func(private=True)
    def _func(var_cumsum: T.handle, var_indptr: T.handle, batch_size: T.int32):
        T.func_attr({"tir.noalias": True})
        cumsum = T.match_buffer(var_cumsum, shape=[batch_size * num_local_experts], dtype="int32")
        indptr = T.match_buffer(var_indptr, shape=[num_local_experts + 1], dtype="int32")
        for vi in T.serial(0, num_local_experts + 1):
            with T.block("indptr"):
                i = T.axis.spatial(num_local_experts + 1, vi)
                # `if_then_else` only evaluates the taken branch, whereas `Select` is allowed
                # to evaluate both, which would read `cumsum[-1]` when `i == 0`.
                indptr[i] = T.if_then_else(i > 0, cumsum[i * batch_size - 1], T.int32(0))

    assert cumsum.ndim == 1
    return op.tensor_ir_op(
        _func,
        "get_expert_instance_indptr",
        args=[cumsum, batch_size],  # type: ignore[list-item]
        out=Tensor.placeholder([num_local_experts + 1], "int32"),
    )


def scatter_output(x: Tensor, indices: Tensor) -> Tensor:
    """Scatter the output of MoE experts back to the original positions.

    Row `i` of `x` is written to row `indices[i]` of the output, i.e.
    `out[indices[i], :] = x[i, :]`.

    Parameters
    ----------
    x : Tensor
        The output of MoE experts with shape [batch_size * num_experts_per_tok, hidden_size].

    indices : Tensor
        The destination row of each row of `x`, with shape [batch_size * num_experts_per_tok],
        int32.

    Returns
    -------
    out : Tensor
        The scattered output, with the same shape and dtype as `x`.
    """
    dtype = x.dtype

    @T.prim_func(private=True)
    def _func(var_x: T.handle, var_indices: T.handle, var_out: T.handle):
        T.func_attr({"tir.noalias": True})
        hidden_size = T.int64()
        indices_len = T.int64()
        x = T.match_buffer(var_x, [indices_len, hidden_size], dtype)
        indices = T.match_buffer(var_indices, [indices_len], "int32")
        out = T.match_buffer(var_out, [indices_len, hidden_size], dtype)
        for i in T.serial(0, indices_len):
            for j in T.serial(0, hidden_size):
                with T.block("scatter"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    out[indices[vi], vj] = x[vi, vj]

    return op.tensor_ir_op(
        _func,
        "scatter_output",
        args=[x, indices],
        out=Tensor.placeholder(x.shape, dtype),
    )
class GroupQuantizeMixtralExperts(nn.Module):  # pylint: disable=too-many-instance-attributes
    """A MixtralExperts module with group quantization.

    All experts are stored batched in two parameters: `q_weight`, the packed quantized
    weights, and `q_scale`, the per-group scales. Groups are formed along the `in_features`
    dimension.
    """

    def __init__(
        self,
        num_local_experts,
        in_features,
        out_features,
        config: GroupQuantize,
    ):  # pylint: disable=too-many-arguments
        """Allocate the quantized parameters.

        Parameters
        ----------
        num_local_experts
            The number of experts batched in this module.

        in_features
            The input feature size of each expert; this is the quantized (grouped) dimension.

        out_features
            The output feature size of each expert.

        config : GroupQuantize
            The group quantization config supplying `group_size`, `num_storage_per_group`,
            `storage_dtype`, `model_dtype` and `quantize_dtype`.
        """
        self.num_local_experts = num_local_experts
        self.in_features = in_features
        self.out_features = out_features
        self.config = config
        # Number of quantization groups along the in-feature dimension (rounded up).
        num_group = tir.ceildiv(in_features, config.group_size)
        self.q_weight = nn.Parameter(
            (num_local_experts, out_features, config.num_storage_per_group * num_group),
            config.storage_dtype,
        )
        self.q_scale = nn.Parameter(
            (num_local_experts, out_features, num_group), config.model_dtype
        )
        self.quantize_dtype = config.quantize_dtype
        self.group_size = config.group_size
        self.dtype = config.model_dtype

    @staticmethod
    def from_mixtral_experts(
        src: "MixtralExperts", config: GroupQuantize
    ) -> "GroupQuantizeMixtralExperts":
        """
        Converts a non-quantized MixtralExperts to a group quantized GroupQuantizeMixtralExperts

        Parameters
        ----------
        src : MixtralExperts
            The non-quantized MixtralExperts

        config : GroupQuantize
            The group quantization config.

        Returns
        -------
        ret : GroupQuantizeMixtralExperts
            The group quantized GroupQuantizeMixtralExperts layer.
        """
        quantized_experts = GroupQuantizeMixtralExperts(
            num_local_experts=src.num_local_experts,
            in_features=src.in_features,
            out_features=src.out_features,
            config=config,
        )
        # Carry the source weight's sharding strategy over to both quantized parameters.
        if "shard_strategy" in src.weight.attrs:
            shard = src.weight.attrs["shard_strategy"]
            _apply_sharding(shard, f"{shard.name}_q_weight", quantized_experts.q_weight)
            _apply_sharding(shard, f"{shard.name}_q_scale", quantized_experts.q_scale)
        return quantized_experts

    def forward(self, x: nn.Tensor, indptr: nn.Tensor) -> nn.Tensor:  # pylint: disable=invalid-name
        """Forward method for group quantized Mixtral experts.

        Parameters
        ----------
        x : nn.Tensor
            The input tensor; must be 2-dimensional.

        indptr : nn.Tensor
            The indptr tensor. A 2-dimensional tensor whose first dimension is 1 selects the
            single-batch GEMV kernel; a 1-dimensional tensor selects the group GEMM kernel.

        Returns
        -------
        ret : nn.Tensor
            The output tensor for the group quantized Mixtral experts layer.
        """
        from mlc_chat.op import moe_matmul  # pylint: disable=import-outside-toplevel

        assert x.ndim == 2
        if indptr.ndim == 2:  # single-batch
            assert indptr.shape[0] == 1
            kernel = moe_matmul.dequantize_gemv
        else:
            assert indptr.ndim == 1
            kernel = moe_matmul.dequantize_group_gemm
        return kernel(
            x,
            self.q_weight,
            self.q_scale,
            indptr,
            quantize_dtype=self.quantize_dtype,
            group_size=self.group_size,
        )
disable=unused-argument,missing-class-docstring,too-many-locals,too-many-statements +# pylint: disable=line-too-long,missing-docstring import tvm from tvm import tir from tvm.relax.frontend.nn import core, modules, spec @@ -9,6 +8,9 @@ from mlc_chat.nn.kv_cache import FlashInferPagedKVCache, PagedKVCache +# mypy: disable-error-code="attr-defined" +# pylint: disable=invalid-name,unused-argument,too-many-locals,too-many-statements + def test_nn_module_paged_kv_cache(): # fmt: off @@ -23,13 +25,12 @@ def tir_kv_cache_debug_get_kv(var_pages: T.handle, var_position_map: T.handle, v position_map = T.match_buffer(var_position_map, (seqlen,), "int32") k_data = T.match_buffer(var_k_data, (32, seqlen, 32, 128), "float16") v_data = T.match_buffer(var_v_data, (32, seqlen, 32, 128), "float16") - # with T.block("root"): for p, h, d in T.grid(seqlen, 32, 128): with T.block("copy0"): vp, vh, vd = T.axis.remap("SSS", [p, h, d]) T.reads(position_map[vp], pages[T.Cast("int64", position_map[vp]) // page_size, 0:2, vh, T.Cast("int64", position_map[vp]) % page_size, vd]) T.writes(k_data[layer_id, vp, vh, vd], v_data[layer_id, vp, vh, vd]) - position: T.int32 = position_map[vp] + position: T.int32 = position_map[vp] # type: ignore[name-defined] k_data[layer_id, vp, vh, vd] = pages[T.Cast("int64", position) // page_size, 0, vh, T.Cast("int64", position) % page_size, vd] v_data[layer_id, vp, vh, vd] = pages[T.Cast("int64", position) // page_size, 1, vh, T.Cast("int64", position) % page_size, vd] @@ -60,7 +61,7 @@ def tir_kv_cache_transpose_append(var_pages: T.handle, var_k_data: T.handle, var @R.function def _initialize_effect() -> R.Tuple(R.Object): with R.dataflow(): - _io: R.Object = R.null_value() + _io: R.Object = R.null_value() # type: ignore lv: R.Tuple(R.Object) = (_io,) # type: ignore gv: R.Tuple(R.Object) = lv # type: ignore R.output(gv) @@ -75,7 +76,7 @@ def create_flashinfer_paged_kv_cache(max_batch_size: R.Shape(["max_batch_size_1" cls = Module with R.dataflow(): lv2: 
R.Tensor((), dtype="float16") = R.zeros(R.shape([]), dtype="float16") # type: ignore - paged_kv_cache: R.Object = R.call_packed("vm.builtin.paged_attention_kv_cache_create", R.shape([max_batch_size_1, max_total_seq_len_1, page_size_1]), R.prim_value(32), R.prim_value(32), R.prim_value(32), R.prim_value(128), R.prim_value(1), R.prim_value(10000), lv2, cls.tir_kv_cache_transpose_append, R.ExternFunc("paged_kv_cache.attention_kernel_prefill"), R.ExternFunc("paged_kv_cache.attention_kernel_decode"), R.ExternFunc("flashinfer.attention_kernel_prefill_with_ragged_kv_cache"), R.ExternFunc("flashinfer.attention_kernel_prefill_with_ragged_kv_cache_begin_forward"), R.ExternFunc("flashinfer.attention_kernel_prefill_with_ragged_kv_cache_end_forward"), R.ExternFunc("paged_kv_cache.attention_kernel_prefill_begin_forward"), R.ExternFunc("paged_kv_cache.attention_kernel_prefill_end_forward"), R.ExternFunc("paged_kv_cache.attention_kernel_decode_begin_forward"), R.ExternFunc("paged_kv_cache.attention_kernel_decode_end_forward"), R.ExternFunc("flashinfer.batch_qk_apply_rotary_in_place"), R.ExternFunc("flashinfer.merge_state_in_place"), cls.tir_kv_cache_debug_get_kv, sinfo_args=(R.Object,)) + paged_kv_cache: R.Object = R.call_packed("vm.builtin.paged_attention_kv_cache_create", R.shape([max_batch_size_1, max_total_seq_len_1, page_size_1]), R.prim_value(32), R.prim_value(32), R.prim_value(32), R.prim_value(128), R.prim_value(1), R.prim_value(10000), lv2, cls.tir_kv_cache_transpose_append, R.ExternFunc("paged_kv_cache.attention_kernel_prefill"), R.ExternFunc("paged_kv_cache.attention_kernel_decode"), R.ExternFunc("flashinfer.attention_kernel_prefill_with_ragged_kv_cache"), R.ExternFunc("flashinfer.attention_kernel_prefill_with_ragged_kv_cache_begin_forward"), R.ExternFunc("flashinfer.attention_kernel_prefill_with_ragged_kv_cache_end_forward"), R.ExternFunc("paged_kv_cache.attention_kernel_prefill_begin_forward"), R.ExternFunc("paged_kv_cache.attention_kernel_prefill_end_forward"), 
R.ExternFunc("paged_kv_cache.attention_kernel_decode_begin_forward"), R.ExternFunc("paged_kv_cache.attention_kernel_decode_end_forward"), R.ExternFunc("flashinfer.batch_qk_apply_rotary_in_place"), R.ExternFunc("flashinfer.merge_state_in_place"), cls.tir_kv_cache_debug_get_kv, sinfo_args=(R.Object,)) # type: ignore gv2: R.Tuple(R.Object, R.Tuple(R.Object)) = paged_kv_cache, (_io,) # type: ignore R.output(gv2) return gv2 @@ -103,7 +104,7 @@ def forward( def create_flashinfer_paged_kv_cache( self, max_batch_size: tir.Var, max_total_seq_len: tir.Var, page_size: tir.Var ) -> PagedKVCache: - return FlashInferPagedKVCache.create( + return FlashInferPagedKVCache( max_batch_size=max_batch_size, max_total_seq_len=max_total_seq_len, page_size=page_size,