|
1 | | -""" |
2 | | -Implementation for Mistral architecture. |
3 | | -""" |
| 1 | +"""Implementation for Mistral architecture.""" |
4 | 2 | import dataclasses |
5 | 3 |
|
6 | | -from tvm import te, tir |
| 4 | +from tvm import tir |
7 | 5 | from tvm.relax.frontend import nn |
8 | 6 | from tvm.relax.frontend.nn import Tensor, op |
9 | | -from tvm.topi.cuda.scan import inclusive_scan |
10 | 7 |
|
11 | 8 | from mlc_chat import op as op_ext |
12 | 9 | from mlc_chat.model.mistral.mistral_model import ( |
@@ -39,112 +36,81 @@ class MixtralMoE(nn.Module): |
39 | 36 |
|
40 | 37 | def __init__(self, config: MixtralConfig): |
41 | 38 | super().__init__() |
42 | | - self.gate = nn.Linear( |
43 | | - in_features=config.hidden_size, out_features=config.num_local_experts, bias=False |
44 | | - ) |
45 | 39 | self.num_experts_per_tok = config.num_experts_per_tok |
46 | 40 | self.num_local_experts = config.num_local_experts |
47 | 41 | self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards |
| 42 | + self.gate = nn.Linear( |
| 43 | + in_features=config.hidden_size, |
| 44 | + out_features=config.num_local_experts, |
| 45 | + bias=False, |
| 46 | + ) |
48 | 47 | self.e1_e3 = MixtralExperts( |
49 | 48 | self.num_local_experts, |
50 | | - self.num_experts_per_tok, |
51 | 49 | in_features=config.hidden_size, |
52 | 50 | out_features=2 * self.intermediate_size, |
53 | 51 | ) |
54 | 52 | self.e2 = MixtralExperts( |
55 | 53 | self.num_local_experts, |
56 | | - self.num_experts_per_tok, |
57 | 54 | in_features=self.intermediate_size, |
58 | 55 | out_features=config.hidden_size, |
59 | 56 | ) |
60 | 57 | self.dtype = "float32" |
61 | 58 |
|
62 | | - # TODO: replace with cumsum nn op when it's ready |
63 | | - def cumsum(self, data: Tensor, dim: int) -> Tensor: |
64 | | - return op.tensor_expr_op(inclusive_scan, "cumsum", args=[data, dim, "int32"]) |
65 | | - |
66 | | - def sum(self, x): |
67 | | - # dlight cannot handle too small reduction axis extent |
68 | | - # so we manually transform it into spatial op. |
69 | | - if self.num_experts_per_tok == 2: |
70 | | - |
71 | | - def te_add(x): |
72 | | - new_shape = (x.shape[0], x.shape[2]) |
73 | | - return te.compute( |
74 | | - new_shape, |
75 | | - lambda i, j: x[i, 0, j] + x[i, 1, j], |
76 | | - name="add", |
77 | | - ) |
78 | | - |
79 | | - return op.tensor_expr_op(te_add, "topk_mask", args=[x]) |
80 | | - return op.sum(x, axis=1) |
81 | | - |
82 | 59 | def forward(self, x: Tensor): |
83 | | - assert x.ndim == 3 |
84 | | - input_shape = x.shape |
85 | | - x = op.reshape(x, (input_shape[0] * input_shape[1], input_shape[2])) |
86 | | - num_tokens = input_shape[0] * input_shape[1] |
87 | | - |
88 | | - # MoE data preparation |
| 60 | + def _expert_forward(x: Tensor, indptr: Tensor): |
| 61 | + # x: [num_tokens * experts_per_tok, hidden_size] (batched) or [1, hidden_size] (single) |
| 62 | + x1_x3 = self.e1_e3(x, indptr) |
| 63 | + # x1, x3: [num_tokens * experts_per_tok, intermediate_size] (batched) or [experts_per_tok, intermediate_size] (single) |
| 64 | + x1, x3 = op.split(x1_x3, indices_or_sections=2, axis=-1) |
| 65 | + # x: |
| 66 | + # - batched: [num_tokens * experts_per_tok, hidden_size] |
| 67 | + # - single: [experts_per_tok, hidden_size] |
| 68 | + x = self.e2(op.silu(x1) * x3, indptr) |
| 69 | + return x |
| 70 | + |
| 71 | + experts_per_tok = self.num_experts_per_tok # activated experts per token |
| 72 | + local_experts = self.num_local_experts # total number of experts |
| 73 | + batch_size, seq_len, hidden_size = x.shape |
| 74 | + num_tokens = batch_size * seq_len |
| 75 | + x = x.reshape(num_tokens, hidden_size) |
| 76 | + # gate: [num_tokens, local_experts] |
89 | 77 | gate: Tensor = self.gate(x) |
90 | | - expert_weights, expert_indices = op_ext.topk( |
91 | | - gate, self.num_experts_per_tok, self.num_local_experts, self.dtype, "int32" |
92 | | - ) |
| 78 | + # expert_weights: [num_tokens, experts_per_tok] |
| 79 | + # expert_indices: [num_tokens, experts_per_tok] |
| 80 | + expert_weights, expert_indices = op_ext.moe.topk(gate, experts_per_tok) |
93 | 81 | expert_weights = op.softmax(expert_weights.astype("float32"), axis=-1).astype(self.dtype) |
94 | 82 | if num_tokens == 1: |
95 | | - # single batch decode |
96 | | - expert_indices = op.reshape(expert_indices, (self.num_experts_per_tok,)) |
97 | | - concat_x1_x3 = self.e1_e3(x, expert_indices, single_batch_decode=True) |
98 | | - x1, x3 = op.split(concat_x1_x3, indices_or_sections=2, axis=-1) |
99 | | - linear_out = self.e2(op.silu(x1) * x3, expert_indices, single_batch_decode=True) |
100 | | - unflattened = op.reshape( |
101 | | - linear_out, (num_tokens, self.num_experts_per_tok, linear_out.shape[-1]) |
102 | | - ) |
| 83 | + # x: [num_tokens * experts_per_tok, hidden_size] |
| 84 | + x = _expert_forward(x, expert_indices) |
103 | 85 | else: |
104 | | - expert_mask = op_ext.topk_mask( |
105 | | - expert_indices, self.num_experts_per_tok, self.num_local_experts |
106 | | - ) |
107 | | - mask_T_flattened = op.reshape( |
108 | | - op.permute_dims(expert_mask), (expert_mask.shape[0] * expert_mask.shape[1],) |
109 | | - ) |
110 | | - cumsum_colwise_flattened = self.cumsum(mask_T_flattened, dim=0) |
111 | | - flattened_indices = op_ext.get_indices( |
112 | | - cumsum_colwise_flattened, expert_indices, self.num_experts_per_tok |
113 | | - ) |
114 | | - indptr = op_ext.get_indptr(cumsum_colwise_flattened, self.num_local_experts) |
115 | | - token_indices = op.divide( |
116 | | - flattened_indices, Tensor.from_const(self.num_experts_per_tok) |
117 | | - ) |
118 | | - gathered_x = op.take(x, token_indices, axis=0) |
119 | | - |
120 | | - # expert forward begin |
121 | | - concat_x1_x3 = self.e1_e3(gathered_x, indptr) |
122 | | - x1, x3 = op.split(concat_x1_x3, indices_or_sections=2, axis=-1) |
123 | | - linear_out = self.e2(op.silu(x1) * x3, indptr) |
124 | | - # expert forward end |
125 | | - |
126 | | - # MoE result post-processing |
127 | | - unpermuted = op_ext.scatter_output(flattened_indices, linear_out, self.dtype) |
128 | | - unflattened = op.reshape( |
129 | | - unpermuted, (num_tokens, self.num_experts_per_tok, unpermuted.shape[1]) |
130 | | - ) |
131 | | - expert_weights = op.reshape(expert_weights, (num_tokens, self.num_experts_per_tok, 1)) |
132 | | - weighted_sum = self.sum(unflattened * expert_weights) |
133 | | - weighted_sum = op.reshape( |
134 | | - weighted_sum, (input_shape[0], input_shape[1], weighted_sum.shape[-1]) |
135 | | - ) |
136 | | - return weighted_sum |
| 86 | + # cumsum: [num_tokens * local_experts] |
| 87 | + cumsum = op_ext.moe.moe_cumsum(expert_indices, local_experts) |
| 88 | + # indices: [num_tokens * experts_per_tok] |
| 89 | + indices = op_ext.moe.get_indices(cumsum, expert_indices) |
| 90 | + # indptr: [num_local_experts + 1] |
| 91 | + indptr = op_ext.moe.get_indptr(cumsum, local_experts) |
| 92 | + # x: [num_tokens * experts_per_tok, hidden_size] |
| 93 | + x = op.take(x, indices / experts_per_tok, axis=0) |
| 94 | + x = _expert_forward(x, indptr) |
| 95 | + x = op_ext.moe.scatter_output(x, indices) |
| 96 | + # x: [num_tokens, experts_per_tok, hidden_size] |
| 97 | + x = x.reshape(num_tokens, experts_per_tok, hidden_size) |
| 98 | + x = x * expert_weights.reshape(num_tokens, experts_per_tok, 1) |
| 99 | + # x: [num_tokens, hidden_size] |
| 100 | + x = op_ext.moe.moe_sum(x, dim=1) |
| 101 | + x = x.reshape(batch_size, seq_len, hidden_size) |
| 102 | + return x |
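
For context on the batched branch above: the `op_ext.moe` helpers group the (token, expert) slots by expert so that `e1_e3` and `e2` can run as grouped GEMMs, with `indptr` giving CSR-style per-expert offsets and `indices` recording where each grouped slot came from. A minimal NumPy sketch of that bookkeeping is below; `route_tokens` is a hypothetical name and this is an illustration only, not the TIR kernels behind `moe_cumsum`, `get_indices`, and `get_indptr`.

```python
import numpy as np

def route_tokens(expert_indices: np.ndarray, num_local_experts: int):
    """Group (token, expert) slots by expert, CSR-style.

    expert_indices: [num_tokens, experts_per_tok] chosen expert ids per token.
    Returns:
      indices: flattened slot ids (token * experts_per_tok + k), sorted by expert id.
      indptr:  [num_local_experts + 1] offsets; expert e owns indices[indptr[e]:indptr[e + 1]].
    """
    flat = expert_indices.reshape(-1)                        # [num_tokens * experts_per_tok]
    counts = np.bincount(flat, minlength=num_local_experts)  # tokens routed to each expert
    indptr = np.concatenate(([0], np.cumsum(counts)))        # [num_local_experts + 1]
    indices = np.argsort(flat, kind="stable")                # stable: keep token order per expert
    return indices, indptr

# Example: 3 tokens, top-2 routing over 4 experts.
indices, indptr = route_tokens(np.array([[0, 2], [2, 3], [0, 1]]), num_local_experts=4)
# indices -> [0, 4, 5, 1, 2, 3]; indptr -> [0, 2, 3, 5, 6]
# indices // experts_per_tok recovers the source token of each grouped slot, which is
# what the `op.take(x, indices / experts_per_tok, axis=0)` gather in the diff does.
```

After the grouped GEMMs, `scatter_output` is the inverse permutation: it writes each grouped row back to its original flattened slot so the per-token top-k rows can be weighted and summed.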
137 | 103 |
|
138 | 104 |
|
139 | 105 | class MixtralDecoderLayer(nn.Module): |
140 | 106 | """Mixtral decoder layer""" |
141 | 107 |
|
142 | 108 | def __init__(self, config: MixtralConfig, rotary_embedding: RotaryEmbedding): |
143 | | - rms_norm_eps = config.rms_norm_eps |
| 109 | + eps = config.rms_norm_eps |
144 | 110 | self.self_attn = MistralAttention(config, rotary_embedding) |
145 | 111 | self.moe = MixtralMoE(config) |
146 | | - self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False) |
147 | | - self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False) |
| 112 | + self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, eps, bias=False) |
| 113 | + self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, -1, eps, bias=False) |
148 | 114 |
|
149 | 115 | def _set_tp(): |
150 | 116 | def _set(layer, hint): |
|