
Commit 1ac89bc

committed
...
1 parent 2bbd9d7 commit 1ac89bc

7 files changed: +852 −852 lines changed

python/mlc_chat/model/gpt2/gpt2_model.py

Lines changed: 7 additions & 21 deletions
@@ -3,13 +3,13 @@
 TODO: add docstring
 """
 import dataclasses
-import math
 from typing import Any, Dict, Optional
 
 from tvm import te, tir
 from tvm.relax.frontend import nn
 from tvm.relax.frontend.nn import Tensor, op
 
+from mlc_chat import op as op_ext
 from mlc_chat.support import logging
 from mlc_chat.support.config import ConfigBase
 from mlc_chat.support.style import bold
@@ -110,29 +110,15 @@ def forward(
 
         self.k_cache.append(op.squeeze(k, axis=0))
         self.v_cache.append(op.squeeze(v, axis=0))
-        k = op.reshape(self.k_cache.view(t), (b, t, h, d))
-        v = op.reshape(self.v_cache.view(t), (b, t, h, d))
-
-        q = q.permute_dims([0, 2, 1, 3])  # [b, h, s, d]
-        k = k.permute_dims([0, 2, 1, 3])  # [b, h, t, d]
-        v = v.permute_dims([0, 2, 1, 3])  # [b, h, t, d]
-
-        attn_weights = op.matmul(
-            q, k.permute_dims([0, 1, 3, 2])  # [b, h, s, d] x [b, h, d, t] = [b, h, s, t]
-        ) / math.sqrt(d)
+        k = self.k_cache.view(t)
+        v = self.v_cache.view(t)
 
         if self.scale_attn_by_inverse_layer_idx:
-            attn_weights = attn_weights / float(self.layer_idx + 1)
-
-        dtype = attn_weights.dtype
-        attn_weights = attn_weights.maximum(tir.min_value(dtype)).minimum(attention_mask)
-        if dtype == "float32":
-            attn_weights = op.softmax(attn_weights, axis=-1)
+            attn_score_scaling_factor = 1.0 / float(self.layer_idx + 1)
         else:
-            attn_weights = op.softmax(attn_weights.astype("float32"), axis=-1).astype(dtype)
-        # [b, h, s, t] x [b, h, t, d] => [b, h, s, d] => [b, s, h, d]
-        output = op.matmul(attn_weights, v)
-        return self.c_proj(output.permute_dims([0, 2, 1, 3]).reshape((b, s, h * d)))
+            attn_score_scaling_factor = 1.0
+        output = op_ext.attention(q, k, v, attention_mask, attn_score_scaling_factor)
+        return self.c_proj(output)
 
 
 class GPT2MLP(nn.Module):
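
Note: the removed GPT-2-specific attention math is now delegated to op_ext.attention, with scale_attn_by_inverse_layer_idx folded into a single multiplicative factor. A minimal NumPy sketch (illustrative only, not the compiled TVM path) of why dividing the scores by (layer_idx + 1) is the same as passing attn_score_scaling_factor = 1 / (layer_idx + 1):

import numpy as np

def scores_old(q, k, d, layer_idx):
    # old GPT-2 path: scale scores by 1/sqrt(d), then divide by (layer_idx + 1)
    s = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d)   # [b, h, s, t]
    return s / float(layer_idx + 1)

def scores_new(q, k, d, layer_idx):
    # new path: identical math expressed as a scaling factor handed to the kernel
    factor = 1.0 / float(layer_idx + 1)
    return q @ k.transpose(0, 1, 3, 2) / np.sqrt(d) * factor

rng = np.random.default_rng(0)
q = rng.standard_normal((1, 2, 3, 4))   # [b, h, s, d]
k = rng.standard_normal((1, 2, 5, 4))   # [b, h, t, d]
assert np.allclose(scores_old(q, k, 4, 7), scores_new(q, k, 4, 7))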

python/mlc_chat/model/mixtral/mixtral_model.py

Lines changed: 49 additions & 83 deletions
@@ -1,12 +1,9 @@
-"""
-Implementation for Mistral architecture.
-"""
+"""Implementation for Mistral architecture."""
 import dataclasses
 
-from tvm import te, tir
+from tvm import tir
 from tvm.relax.frontend import nn
 from tvm.relax.frontend.nn import Tensor, op
-from tvm.topi.cuda.scan import inclusive_scan
 
 from mlc_chat import op as op_ext
 from mlc_chat.model.mistral.mistral_model import (
@@ -39,112 +36,81 @@ class MixtralMoE(nn.Module):
 
     def __init__(self, config: MixtralConfig):
         super().__init__()
-        self.gate = nn.Linear(
-            in_features=config.hidden_size, out_features=config.num_local_experts, bias=False
-        )
         self.num_experts_per_tok = config.num_experts_per_tok
         self.num_local_experts = config.num_local_experts
         self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards
+        self.gate = nn.Linear(
+            in_features=config.hidden_size,
+            out_features=config.num_local_experts,
+            bias=False,
+        )
         self.e1_e3 = MixtralExperts(
             self.num_local_experts,
-            self.num_experts_per_tok,
             in_features=config.hidden_size,
             out_features=2 * self.intermediate_size,
         )
         self.e2 = MixtralExperts(
             self.num_local_experts,
-            self.num_experts_per_tok,
             in_features=self.intermediate_size,
             out_features=config.hidden_size,
         )
         self.dtype = "float32"
 
-    # TODO: replace with cumsum nn op when it's ready
-    def cumsum(self, data: Tensor, dim: int) -> Tensor:
-        return op.tensor_expr_op(inclusive_scan, "cumsum", args=[data, dim, "int32"])
-
-    def sum(self, x):
-        # dlight cannot handle too small reduction axis extent
-        # so we manually transform it into spatial op.
-        if self.num_experts_per_tok == 2:
-
-            def te_add(x):
-                new_shape = (x.shape[0], x.shape[2])
-                return te.compute(
-                    new_shape,
-                    lambda i, j: x[i, 0, j] + x[i, 1, j],
-                    name="add",
-                )
-
-            return op.tensor_expr_op(te_add, "topk_mask", args=[x])
-        return op.sum(x, axis=1)
-
     def forward(self, x: Tensor):
-        assert x.ndim == 3
-        input_shape = x.shape
-        x = op.reshape(x, (input_shape[0] * input_shape[1], input_shape[2]))
-        num_tokens = input_shape[0] * input_shape[1]
-
-        # MoE data preparation
+        def _expert_forward(x: Tensor, indptr: Tensor):
+            # x: [num_tokens, hidden_size]
+            x1_x3 = self.e1_e3(x, indptr)
+            # x1, x3: [experts_per_tok, intermediate_size]
+            x1, x3 = op.split(x1_x3, indices_or_sections=2, axis=-1)
+            # x:
+            # - batched: [num_tokens, hidden_size]
+            # - single: [experts_per_tok, hidden_size]
+            x = self.e2(op.silu(x1) * x3, indptr)
+            return x
+
+        experts_per_tok = self.num_experts_per_tok  # activated experts per token
+        local_experts = self.num_local_experts  # total number of experts
+        batch_size, seq_len, hidden_size = x.shape
+        num_tokens = batch_size * seq_len
+        x = x.reshape(num_tokens, hidden_size)
+        # gate: [num_tokens, local_experts]
         gate: Tensor = self.gate(x)
-        expert_weights, expert_indices = op_ext.topk(
-            gate, self.num_experts_per_tok, self.num_local_experts, self.dtype, "int32"
-        )
+        # expert_weights: [num_tokens, experts_per_tok]
+        # expert_indices: [num_tokens, experts_per_tok]
+        expert_weights, expert_indices = op_ext.moe.topk(gate, experts_per_tok)
         expert_weights = op.softmax(expert_weights.astype("float32"), axis=-1).astype(self.dtype)
         if num_tokens == 1:
-            # single batch decode
-            expert_indices = op.reshape(expert_indices, (self.num_experts_per_tok,))
-            concat_x1_x3 = self.e1_e3(x, expert_indices, single_batch_decode=True)
-            x1, x3 = op.split(concat_x1_x3, indices_or_sections=2, axis=-1)
-            linear_out = self.e2(op.silu(x1) * x3, expert_indices, single_batch_decode=True)
-            unflattened = op.reshape(
-                linear_out, (num_tokens, self.num_experts_per_tok, linear_out.shape[-1])
-            )
+            # x: [num_tokens * experts_per_tok, hidden_size]
+            x = _expert_forward(x, expert_indices)
         else:
-            expert_mask = op_ext.topk_mask(
-                expert_indices, self.num_experts_per_tok, self.num_local_experts
-            )
-            mask_T_flattened = op.reshape(
-                op.permute_dims(expert_mask), (expert_mask.shape[0] * expert_mask.shape[1],)
-            )
-            cumsum_colwise_flattened = self.cumsum(mask_T_flattened, dim=0)
-            flattened_indices = op_ext.get_indices(
-                cumsum_colwise_flattened, expert_indices, self.num_experts_per_tok
-            )
-            indptr = op_ext.get_indptr(cumsum_colwise_flattened, self.num_local_experts)
-            token_indices = op.divide(
-                flattened_indices, Tensor.from_const(self.num_experts_per_tok)
-            )
-            gathered_x = op.take(x, token_indices, axis=0)
-
-            # expert forward begin
-            concat_x1_x3 = self.e1_e3(gathered_x, indptr)
-            x1, x3 = op.split(concat_x1_x3, indices_or_sections=2, axis=-1)
-            linear_out = self.e2(op.silu(x1) * x3, indptr)
-            # expert forward end
-
-            # MoE result post-processing
-            unpermuted = op_ext.scatter_output(flattened_indices, linear_out, self.dtype)
-            unflattened = op.reshape(
-                unpermuted, (num_tokens, self.num_experts_per_tok, unpermuted.shape[1])
-            )
-        expert_weights = op.reshape(expert_weights, (num_tokens, self.num_experts_per_tok, 1))
-        weighted_sum = self.sum(unflattened * expert_weights)
-        weighted_sum = op.reshape(
-            weighted_sum, (input_shape[0], input_shape[1], weighted_sum.shape[-1])
-        )
-        return weighted_sum
+            # cumsum: [num_tokens * total_experts]
+            cumsum = op_ext.moe.moe_cumsum(expert_indices, local_experts)
+            # indices: [num_tokens * experts_per_tok]
+            indices = op_ext.moe.get_indices(cumsum, expert_indices)
+            # indptr: [num_local_experts + 1]
+            indptr = op_ext.moe.get_indptr(cumsum, local_experts)
+            # x: [num_tokens * experts_per_tok, hidden_size]
+            x = op.take(x, indices / experts_per_tok, axis=0)
+            x = _expert_forward(x, indptr)
+            x = op_ext.moe.scatter_output(x, indices)
+        # x: [num_tokens, experts_per_tok, hidden_size]
+        x = x.reshape(num_tokens, experts_per_tok, hidden_size)
+        x = x * expert_weights.reshape(num_tokens, experts_per_tok, 1)
+        # x: [num_tokens, hidden_size]
+        x = op_ext.moe.moe_sum(x, dim=1)
+        x = x.reshape(batch_size, seq_len, hidden_size)
+        return x
 
 
 class MixtralDecoderLayer(nn.Module):
     """Mixtral decoder layer"""
 
     def __init__(self, config: MixtralConfig, rotary_embedding: RotaryEmbedding):
-        rms_norm_eps = config.rms_norm_eps
+        eps = config.rms_norm_eps
         self.self_attn = MistralAttention(config, rotary_embedding)
         self.moe = MixtralMoE(config)
-        self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False)
-        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False)
+        self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, eps, bias=False)
+        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, -1, eps, bias=False)
 
     def _set_tp():
         def _set(layer, hint):
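
Note: the batched branch of MixtralMoE.forward permutes token rows so that all rows routed to the same expert become contiguous, runs one grouped GEMM per expert, then scatters rows back and takes the gating-weighted sum. A hedged NumPy reference of that data flow (illustrative only: names are hypothetical, a single square weight per expert stands in for the e1_e3 → SiLU → e2 pipeline, and the op_ext.moe helpers are assumed to behave as their names suggest):

import numpy as np

def moe_forward_reference(x, gate_logits, w, experts_per_tok):
    # x: [num_tokens, hidden], gate_logits: [num_tokens, num_experts],
    # w: [num_experts, hidden, hidden] (square here for simplicity).
    num_tokens, hidden = x.shape
    num_experts = gate_logits.shape[1]
    # top-k routing, then softmax over the selected logits (op_ext.moe.topk + op.softmax)
    expert_indices = np.argsort(-gate_logits, axis=1)[:, :experts_per_tok]
    expert_weights = np.take_along_axis(gate_logits, expert_indices, axis=1)
    expert_weights = np.exp(expert_weights - expert_weights.max(axis=1, keepdims=True))
    expert_weights /= expert_weights.sum(axis=1, keepdims=True)
    # stable sort of (token, slot) pairs by expert id: this is the role of the
    # moe_cumsum / get_indices / get_indptr combination
    flat_expert = expert_indices.reshape(-1)                 # [num_tokens * experts_per_tok]
    order = np.argsort(flat_expert, kind="stable")
    indptr = np.zeros(num_experts + 1, dtype=np.int64)
    indptr[1:] = np.cumsum(np.bincount(flat_expert, minlength=num_experts))
    gathered = x[order // experts_per_tok]                   # gather token rows (op.take)
    out = np.empty_like(gathered)
    for e in range(num_experts):                             # grouped GEMM, one slice per expert
        lo, hi = indptr[e], indptr[e + 1]
        out[lo:hi] = gathered[lo:hi] @ w[e].T
    unpermuted = np.empty_like(out)
    unpermuted[order] = out                                  # scatter_output
    unpermuted = unpermuted.reshape(num_tokens, experts_per_tok, hidden)
    return (unpermuted * expert_weights[..., None]).sum(axis=1)   # moe_sum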

python/mlc_chat/nn/expert.py

Lines changed: 15 additions & 44 deletions
@@ -2,59 +2,30 @@
 from tvm.relax.frontend import nn
 from tvm.relax.frontend.nn import Tensor
 
-from mlc_chat import op as op_ext
+from mlc_chat.op import moe_matmul
 
 
 class MixtralExperts(nn.Module):
     """Mixtral experts"""
 
-    def __init__(self, num_local_experts, num_experts_per_tok, in_features, out_features):
+    def __init__(self, num_local_experts, in_features, out_features):
         self.num_local_experts = num_local_experts
-        self.num_experts_per_tok = num_experts_per_tok
         self.in_features = in_features
         self.out_features = out_features
         self.weight = nn.Parameter((num_local_experts, out_features, in_features))
         self.dtype = "float32"
 
-    def forward(  # pylint: disable=missing-function-docstring,invalid-name
-        self,
-        x: Tensor,
-        indptr: Tensor,
-        single_batch_decode: bool = False,
-    ):
-        assert x.ndim == 2
-        if single_batch_decode:
-            # single-batch decode
-            assert x.shape[1] == self.in_features
-            assert indptr.ndim == 1
-            if x.shape[0] == 1:
-                return op_ext.gemv_e1_e3(
-                    x,
-                    self.weight,
-                    indptr,
-                    self.in_features,
-                    self.out_features,
-                    self.num_experts_per_tok,
-                    self.num_local_experts,
-                    self.dtype,
-                )
-            return op_ext.gemv_e2(
-                x,
-                self.weight,
-                indptr,
-                self.in_features,
-                self.out_features,
-                self.num_experts_per_tok,
-                self.num_local_experts,
-                self.dtype,
-            )
+    def _forward_single(self, x: Tensor, indptr: Tensor):  # pylint: disable=invalid-name
+        assert x.ndim == 2 and indptr.ndim == 2
+        assert indptr.shape[0] == 1
+        return moe_matmul.gemv(x, self.weight, indptr)
 
-        return op_ext.group_gemm(
-            x,
-            self.weight,
-            indptr,
-            self.in_features,
-            self.out_features,
-            self.num_local_experts,
-            self.dtype,
-        )
+    def _forward_batched(self, x: Tensor, indptr: Tensor):  # pylint: disable=invalid-name
+        assert x.ndim == 2 and indptr.ndim == 1
+        return moe_matmul.group_gemm(x, self.weight, indptr)
+
+    def forward(self, x: Tensor, indptr: Tensor):  # pylint: disable=invalid-name,missing-docstring
+        assert x.ndim == 2 and indptr.ndim in [1, 2]
+        if indptr.ndim == 1:
+            return self._forward_batched(x, indptr)
+        return self._forward_single(x, indptr)
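
Note: forward() now dispatches on indptr.ndim — a 2-D indptr of shape [1, experts_per_tok] carries the expert ids chosen for a single token (GEMV path), while a 1-D indptr of length num_local_experts + 1 delimits contiguous row ranges for the grouped GEMM. A small NumPy sketch of the assumed semantics of the two moe_matmul kernels (illustrative only; the real implementations are TIR kernels):

import numpy as np

def gemv_reference(x, weight, indptr):
    # x: [1, in_features]; weight: [num_local_experts, out_features, in_features];
    # indptr: [1, experts_per_tok] holds the selected expert ids.
    return np.stack([x[0] @ weight[e].T for e in indptr[0]])   # [experts_per_tok, out_features]

def group_gemm_reference(x, weight, indptr):
    # x: [total_rows, in_features] with rows already grouped by expert;
    # indptr: [num_local_experts + 1] marks each expert's contiguous row range.
    out = np.empty((x.shape[0], weight.shape[1]), dtype=x.dtype)
    for e in range(weight.shape[0]):
        lo, hi = indptr[e], indptr[e + 1]
        out[lo:hi] = x[lo:hi] @ weight[e].T
    return out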

python/mlc_chat/op/__init__.py

Lines changed: 1 addition & 13 deletions
@@ -1,18 +1,6 @@
 """Extern module for compiler."""
+from . import moe, moe_matmul
 from .attention import attention
 from .extern import configure, enable, get_store
 from .gemm import faster_transformer_dequantize_gemm
-from .moe import (
-    gemv_e1_e3,
-    gemv_e2,
-    get_indices,
-    get_indptr,
-    group_dequantize_gemv_e1_e3,
-    group_dequantize_gemv_e2,
-    group_dequantize_group_gemm,
-    group_gemm,
-    scatter_output,
-    topk,
-    topk_mask,
-)
 from .position_embedding import llama_rope
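
Note: with the flat re-exports gone, call sites reference the sub-modules instead, as the other files in this commit do. A sketch of the updated import style:

from mlc_chat import op as op_ext
from mlc_chat.op import moe_matmul

# expert_weights, expert_indices = op_ext.moe.topk(gate, experts_per_tok)
# out = moe_matmul.group_gemm(x, weight, indptr)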

python/mlc_chat/op/attention.py

Lines changed: 12 additions & 8 deletions
@@ -16,11 +16,12 @@
 WARN_FLASHINFER_HEAD_DIM = False
 
 
-def attention(  # pylint: disable=invalid-name,too-many-locals
+def attention(  # pylint: disable=invalid-name,too-many-locals,too-many-statements
     q: nn.Tensor,
     k: nn.Tensor,
     v: nn.Tensor,
     casual_mask: nn.Tensor,
+    attn_score_scaling_factor: float = 1.0,
 ) -> nn.Tensor:
     """Attention with casual mask.
 
@@ -47,7 +48,7 @@ def attention(  # pylint: disable=invalid-name,too-many-locals
             v = v.repeat(h_q // h_kv, axis=1)
         q -> [b, h, s, d]
         k, v -> [b, h, t, d]
-        attn = q @ k^T / sqrt(d)  # [b, h, s, t]
+        attn = q @ k^T / sqrt(d) * attn_score_scaling_factor  # [b, h, s, t]
         attn = softmax_with_mask(attn, casual_mask, axis=-1)
         o = attn @ v  # [b, h, s, d]
         o -> [b, s, h * d]
@@ -67,27 +68,30 @@ def _fallback():
         if h_kv != h_q:
             k = k.repeat(h_q // h_kv, axis=2)
             v = v.repeat(h_q // h_kv, axis=2)
-        q = q.permute_dims([0, 2, 1, 3])
-        k = k.permute_dims([0, 2, 1, 3])
-        v = v.permute_dims([0, 2, 1, 3])
+        q = op.permute_dims(q, [0, 2, 1, 3])
+        k = op.permute_dims(k, [0, 2, 1, 3])
+        v = op.permute_dims(v, [0, 2, 1, 3])
         attn_weights = op.matmul(  # [b, h, s, t]
             q,  # [b, h, s, d]
-            k.permute_dims([0, 1, 3, 2]),  # [b, h, d, t]
+            op.permute_dims(k, [0, 1, 3, 2]),  # [b, h, d, t]
         ) / math.sqrt(d)
+        if attn_score_scaling_factor != 1.0:
+            attn_weights = attn_weights * attn_score_scaling_factor
         dtype = attn_weights.dtype
         attn_weights = attn_weights.maximum(tir.min_value(dtype)).minimum(casual_mask)
         if dtype == "float32":
             attn_weights = op.softmax(attn_weights, axis=-1)
         else:
             attn_weights = op.softmax(attn_weights.astype("float32"), axis=-1).astype(dtype)
         output = op.matmul(attn_weights, v)  # [b, h, s, d] <= [b, h, s, t] x [b, h, t, d]
-        output = output.permute_dims([0, 2, 1, 3])  # [b, s, h, d]
-        output = output.reshape([b, s, h_q * d])  # [b, s, h * d]
+        output = op.permute_dims(output, [0, 2, 1, 3])  # [b, s, h, d]
+        output = op.reshape(output, [b, s, h_q * d])  # [b, s, h * d]
         return output
 
     # FlashInfer Implementation
     if (
         _extern.get_store().flashinfer
+        and attn_score_scaling_factor == 1.0
         and q.dtype == "float16"
         and k.dtype == "float16"
         and v.dtype == "float16"