
Commit 20ca84b

cambricon: update cambricon code based on master
1 parent b32a8f2 commit 20ca84b


73 files changed: +4431 −511 lines changed

benchmark/test_attention_perf.py

Lines changed: 0 additions & 1 deletion
@@ -76,7 +76,6 @@ def set_more_shapes(self):
 
 @pytest.mark.skipif(vendor_name == "kunlunxin", reason="RESULT TODOFIX")
 @pytest.mark.skipif(vendor_name == "hygon", reason="RuntimeError")
-@pytest.mark.skipif(flag_gems.vendor_name == "cambricon", reason="TypeError")
 @pytest.mark.flash_mla
 def test_perf_flash_mla():
     def flash_mla_kwargs(shape, dtype, device):

benchmark/test_special_perf.py

Lines changed: 0 additions & 1 deletion
@@ -390,7 +390,6 @@ def diagonal_backward_input_fn(shape, dtype, device):
 
 
 @pytest.mark.skipif(vendor_name == "kunlunxin", reason="RESULT TODOFIX")
-@pytest.mark.skipif(vendor_name == "cambricon", reason="TODOFIX")
 @pytest.mark.kron
 def test_perf_kron():
     class KronBenchmark(GenericBenchmark2DOnly):

benchmark/test_unary_pointwise_perf.py

Lines changed: 2 additions & 2 deletions
@@ -265,8 +265,8 @@ def set_more_shapes(self):
     def get_input_iter(self, cur_dtype) -> Generator:
         for shape in self.shapes:
             inp1 = generate_tensor_input(shape, cur_dtype, self.device)
-            shift_amount = torch.randint(
-                0, 8, shape, dtype=cur_dtype, device=self.device
+            shift_amount = torch.randint(0, 8, shape, dtype=cur_dtype, device="cpu").to(
+                self.device
             )
             yield inp1, shift_amount
 
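For reference, a minimal standalone sketch of the same pattern outside the benchmark harness; the shape, dtype, and device string below are illustrative, not taken from the suite:

import torch

shape, dtype, target_device = (1024, 1024), torch.int32, "cuda"  # illustrative values
# Draw the random shift amounts on the CPU, then copy the tensor to the target device,
# rather than calling randint with device=target_device directly.
shift_amount = torch.randint(0, 8, shape, dtype=dtype, device="cpu").to(target_device)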

src/flag_gems/runtime/backend/_cambricon/__init__.py

Lines changed: 1 addition & 2 deletions
@@ -25,11 +25,10 @@
 )
 
 CUSTOMIZED_UNUSED_OPS = (
-    "randperm",  # skip now
     "sort",  # skip now
-    "multinomial",  # skip now
     "_upsample_bicubic2d_aa",  # skip now
     "sort_stable",
+    "copy_",
 )
 
 __all__ = ["*"]
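The tuple now drops "randperm" and "multinomial" from the skip list and adds "copy_". As a rough illustration only (this is not flag_gems' actual registration code), a backend could use such a tuple to decide which ops it leaves to the default implementation:

# Hypothetical helper, for illustration only.
CUSTOMIZED_UNUSED_OPS = (
    "sort",  # skip now
    "_upsample_bicubic2d_aa",  # skip now
    "sort_stable",
    "copy_",
)


def select_vendor_ops(candidate_ops: dict) -> dict:
    """Keep only the ops this backend actually overrides."""
    return {
        name: fn
        for name, fn in candidate_ops.items()
        if name not in CUSTOMIZED_UNUSED_OPS
    }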
Lines changed: 4 additions & 1 deletion
@@ -1,17 +1,20 @@
 from .cross_entropy_loss import cross_entropy_loss
+from .flash_mla import flash_mla
 from .fused_add_rms_norm import fused_add_rms_norm
 from .gelu_and_mul import gelu_and_mul
 from .outer import outer
-from .silu_and_mul import silu_and_mul
+from .silu_and_mul import silu_and_mul, silu_and_mul_out
 from .skip_layernorm import skip_layer_norm
 from .weight_norm import weight_norm
 
 __all__ = [
     "skip_layer_norm",
     "fused_add_rms_norm",
     "silu_and_mul",
+    "silu_and_mul_out",
     "gelu_and_mul",
     "cross_entropy_loss",
     "outer",
     "weight_norm",
+    "flash_mla",
 ]
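If these exports live in the fused subpackage's __init__ (the surrounding file paths suggest src/flag_gems/runtime/backend/_cambricon/fused/), the new entry points would presumably be importable like this; treat the module path as an assumption:

from flag_gems.runtime.backend._cambricon.fused import flash_mla, silu_and_mul_out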
Lines changed: 232 additions & 0 deletions
@@ -0,0 +1,232 @@
+import logging
+import math
+
+import torch
+import triton
+import triton.language as tl
+
+from flag_gems.runtime import device, error, torch_device_fn
+from flag_gems.utils import triton_lang_extension as tle
+
+vendor_name = device.vendor_name
+device = device.name
+logger = logging.getLogger(__name__)
+
+
+# @triton.autotune(
+#     configs=[
+#         triton.Config({"BLOCK_H": h, "BLOCK_N": n}, num_warps=w, num_stages=s)
+#         for h in [32, 64, 128]
+#         for n in [32, 64, 128]
+#         for w in [4, 8]
+#         for s in [1, 2]
+#     ],
+#     key=["head_num"]
+# )
+@triton.heuristics(
+    values={
+        "EVEN_H": lambda META: META["head_num"] % META["BLOCK_H"] == 0,
+    }
+)
+@triton.jit
+def flash_mla_attn_kernel(
+    Q_ptr,
+    Kv_cache,
+    Req_to_tokens,
+    B_seq_len,
+    O,
+    sm_scale,
+    head_num,
+    stride_q_bs,
+    stride_q_h,
+    stride_kv_bs,
+    stride_req_to_tokens_bs,
+    stride_o_b,
+    stride_o_h,
+    stride_o_s,
+    BLOCK_H: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    EVEN_H: tl.constexpr,
+    PAGE_SIZE: tl.constexpr,
+    HEAD_DIM_V: tl.constexpr,
+    HEAD_DIM: tl.constexpr,
+):
+    cur_head_id = tle.program_id(0)
+    cur_batch_id = tle.program_id(1)
+    Req_to_tokens += stride_req_to_tokens_bs * cur_batch_id
+
+    cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H)
+
+    offs_d_ckv = tl.arange(0, HEAD_DIM_V)
+    offs_q_nope = (
+        cur_batch_id * stride_q_bs
+        + cur_head[:, None] * stride_q_h
+        + offs_d_ckv[None, :]
+    )
+
+    offs_d_kpe = tl.arange(HEAD_DIM_V, HEAD_DIM)
+    offs_q_pe = (
+        cur_batch_id * stride_q_bs
+        + cur_head[:, None] * stride_q_h
+        + offs_d_kpe[None, :]
+    )
+
+    if EVEN_H:
+        q_nope = tl.load(Q_ptr + offs_q_nope)
+        q_pe = tl.load(Q_ptr + offs_q_pe)
+    else:
+        mask_head = cur_head < head_num
+        q_nope = tl.load(Q_ptr + offs_q_nope, mask=mask_head[:, None])
+        q_pe = tl.load(Q_ptr + offs_q_pe, mask=mask_head[:, None])
+
+    e_max = tl.full([BLOCK_H], value=float("-inf"), dtype=tl.float32)
+    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
+    acc = tl.zeros([BLOCK_H, HEAD_DIM_V], dtype=tl.float32)
+
+    cur_batch_seq_len = tl.load(B_seq_len + cur_batch_id)
+    loop_time = cur_batch_seq_len // BLOCK_N
+    remainder = cur_batch_seq_len % BLOCK_N
+    offs_n = tl.arange(0, BLOCK_N)
+    for i in range(0, loop_time):
+        kv_page_number = tl.load(Req_to_tokens + offs_n // PAGE_SIZE)
+        kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
+        offs_v_c = kv_loc[:, None] * stride_kv_bs + offs_d_ckv[None, :]
+        v_c = tl.load(Kv_cache + offs_v_c)
+        k_c = tl.trans(v_c)
+
+        qk = tl.dot(q_nope, k_c)  # qk_nope
+
+        offs_k_pe = kv_loc[None, :] * stride_kv_bs + offs_d_kpe[:, None]
+        k_pe = tl.load(Kv_cache + offs_k_pe)
+
+        qk = tl.dot(q_pe, k_pe, acc=qk)  # qk_rope
+        qk *= sm_scale
+
+        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
+        re_scale = tl.exp(e_max - n_e_max)
+        p = tl.exp(qk - n_e_max[:, None])
+        acc *= re_scale[:, None]
+        acc = tl.dot(p.to(v_c.dtype), v_c, acc=acc)
+
+        e_sum = e_sum * re_scale + tl.sum(p, 1)
+        e_max = n_e_max
+        offs_n += BLOCK_N
+
+    if remainder:
+        mask_kvsplit = offs_n < cur_batch_seq_len
+        kv_page_number = tl.load(
+            Req_to_tokens + offs_n // PAGE_SIZE,
+            mask=mask_kvsplit,
+            other=0,
+        )
+        kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
+        offs_v_c = kv_loc[:, None] * stride_kv_bs + offs_d_ckv[None, :]
+        v_c = tl.load(Kv_cache + offs_v_c, mask=mask_kvsplit[:, None], other=0.0)
+        k_c = tl.trans(v_c)
+
+        qk = tl.dot(q_nope, k_c)  # qk_nope
+
+        offs_k_pe = kv_loc[None, :] * stride_kv_bs + offs_d_kpe[:, None]
+        k_pe = tl.load(Kv_cache + offs_k_pe, mask=mask_kvsplit[None, :], other=0.0)
+
+        qk = tl.dot(q_pe, k_pe, acc=qk)  # qk_rope
+        qk *= sm_scale
+
+        qk = tl.where(mask_kvsplit[None, :], qk, float("-inf"))
+
+        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
+        re_scale = tl.exp(e_max - n_e_max)
+        p = tl.exp(qk - n_e_max[:, None])
+        acc *= re_scale[:, None]
+        acc = tl.dot(p.to(v_c.dtype), v_c, acc=acc)
+
+        e_sum = e_sum * re_scale + tl.sum(p, 1)
+
+    offs_o = (
+        cur_batch_id * stride_o_b + cur_head[:, None] * stride_o_h + offs_d_ckv[None, :]
+    )
+    if EVEN_H:
+        tl.store(
+            O + offs_o,
+            acc / e_sum[:, None],
+        )
+    else:
+        tl.store(O + offs_o, acc / e_sum[:, None], mask=mask_head[:, None])
+
+
+def flash_mla(
+    q,
+    block_table,
+    blocked_k,
+    max_seqlen_pad,
+    block_size,
+    b,
+    s_q,
+    cache_seqlens,
+    h_q,
+    h_kv,
+    d,
+    dv,
+    causal,
+):
+    logger.debug("GEMS_CAMBRICON FLASH MLA")
+    assert causal, "causal False not supported"
+    assert d > dv, "mla with rope dim should be larger than no rope dim"
+
+    batch_size, s_q, head_num, d = list(q.shape)
+    q = q.view([-1, head_num, d]).contiguous()
+    blocked_k = blocked_k.view([-1, d]).contiguous()
+    block_table = block_table.contiguous()
+    cache_seqlens = cache_seqlens.contiguous()
+
+    sm_scale = 1 / math.sqrt(d)
+
+    o = torch.empty([b * s_q, h_q, dv], dtype=q.dtype, device=device)
+
+    major, _ = torch_device_fn.get_device_capability(device)
+    if major == 9:
+        BLOCK_H = 64
+        num_stages = 3
+    elif major == 8:
+        BLOCK_H = 32
+        num_stages = 2
+    elif major == 7 and vendor_name == "iluvatar":
+        BLOCK_H = 32
+        num_stages = 1
+    elif vendor_name == "cambricon":
+        BLOCK_H = 32
+        num_stages = 1
+    else:
+        error.backend_not_support(device)
+    BLOCK_N = 64
+    grid = (
+        triton.cdiv(head_num, BLOCK_H),
+        batch_size,
+    )
+    with torch_device_fn.device(device):
+        flash_mla_attn_kernel[grid](
+            q,
+            blocked_k,
+            block_table,
+            cache_seqlens,
+            o,
+            sm_scale,
+            head_num,
+            # stride
+            q.stride(0),
+            q.stride(1),
+            blocked_k.stride(-2),
+            block_table.stride(0),
+            o.stride(0),
+            o.stride(1),
+            o.stride(2),
+            BLOCK_H=BLOCK_H,
+            BLOCK_N=BLOCK_N,
+            PAGE_SIZE=block_size,
+            HEAD_DIM_V=dv,
+            HEAD_DIM=d,
+            num_warps=8,
+            num_stages=num_stages,
+        )
+
+    return o.view([b, s_q, h_q, dv])
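A hedged sketch of how this wrapper might be called. The shapes below are assumptions inferred from the code (q as [b, s_q, h_q, d], a paged KV cache whose last dimension is d, a per-request page table, and per-request cache lengths), the import path and device string are illustrative, and this is not a verified invocation from the test suite:

import torch
from flag_gems.runtime.backend._cambricon.fused import flash_mla  # assumed import path

b, s_q, h_q, h_kv, d, dv = 2, 1, 16, 1, 576, 512  # assumed MLA-style sizes
block_size, max_seqlen_pad = 64, 256              # pages of 64 tokens, padded cache length
num_pages = b * max_seqlen_pad // block_size
dev = "mlu"                                       # illustrative device string

q = torch.randn(b, s_q, h_q, d, dtype=torch.float16, device=dev)
blocked_k = torch.randn(num_pages, block_size, h_kv, d, dtype=torch.float16, device=dev)
block_table = torch.arange(num_pages, dtype=torch.int32, device=dev).view(b, -1)
cache_seqlens = torch.tensor([200, 130], dtype=torch.int32, device=dev)

out = flash_mla(
    q, block_table, blocked_k, max_seqlen_pad, block_size,
    b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal=True,
)
# The wrapper reshapes internally and returns a [b, s_q, h_q, dv] tensor.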

src/flag_gems/runtime/backend/_cambricon/fused/fused_add_rms_norm.py

Lines changed: 6 additions & 1 deletion
@@ -98,7 +98,12 @@ def fused_add_rms_norm(x, residual, normalized_shape, weight, eps=1e-5):
     Both `x` and `residual` tensors will be modified. Use with caution if these tensors
     are reused elsewhere or require gradients.
     """
-    logger.debug("GEMS_CAMBRICON FUSED_ADD_RMSNORM FORWARD")
+    logger.debug(
+        "GEMS_CAMBRICON FUSED_ADD_RMS_NORM FORWARD, [input shape]: %s, [residual shape]: %s, [weight shape]: %s",
+        x.size(),
+        residual.size(),
+        weight.size(),
+    )
     dim = x.ndim - len(normalized_shape)
     M = math.prod(x.shape[:dim])
     N = math.prod(normalized_shape)
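For context, a minimal hedged sketch of a call that would emit this debug line; the shapes and device string are illustrative, and both x and residual are updated in place as the docstring above warns:

import torch

hidden = 4096
x = torch.randn(8, hidden, dtype=torch.float16, device="mlu")  # illustrative
residual = torch.randn_like(x)
weight = torch.ones(hidden, dtype=torch.float16, device="mlu")

# normalized_shape matches weight's shape; x and residual are modified in place.
fused_add_rms_norm(x, residual, [hidden], weight, eps=1e-5)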

src/flag_gems/runtime/backend/_cambricon/fused/silu_and_mul.py

Lines changed: 5 additions & 0 deletions
@@ -46,3 +46,8 @@ def backward(ctx, grad_output):
 
 def silu_and_mul(A, B):
     return SiluAndMul.apply(A, B)
+
+
+def silu_and_mul_out(A, B, out):
+    silu_and_mul_kernel(A, B, out0=out)
+    return out
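A short usage sketch of the new out-variant, assuming A, B, and the preallocated out buffer share the same shape and dtype (shapes and device string below are illustrative):

import torch

A = torch.randn(8, 4096, dtype=torch.float16, device="mlu")  # illustrative
B = torch.randn_like(A)
out = torch.empty_like(A)

silu_and_mul_out(A, B, out)  # writes silu(A) * B into the caller-provided buffer
y = silu_and_mul(A, B)       # allocating variant, for comparison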
