1 change: 0 additions & 1 deletion benchmark/test_attention_perf.py
@@ -76,7 +76,6 @@ def set_more_shapes(self):

@pytest.mark.skipif(vendor_name == "kunlunxin", reason="RESULT TODOFIX")
@pytest.mark.skipif(vendor_name == "hygon", reason="RuntimeError")
@pytest.mark.skipif(flag_gems.vendor_name == "cambricon", reason="TypeError")
@pytest.mark.flash_mla
def test_perf_flash_mla():
def flash_mla_kwargs(shape, dtype, device):
1 change: 0 additions & 1 deletion benchmark/test_special_perf.py
@@ -402,7 +402,6 @@ def diagonal_backward_input_fn(shape, dtype, device):
vendor_name == "kunlunxin" and SkipVersion("torch", "<2.5"),
reason="only support torch >= 2.5.",
)
@pytest.mark.skipif(vendor_name == "cambricon", reason="TODOFIX")
@pytest.mark.kron
def test_perf_kron():
class KronBenchmark(GenericBenchmark2DOnly):
4 changes: 2 additions & 2 deletions benchmark/test_unary_pointwise_perf.py
@@ -266,8 +266,8 @@ def set_more_shapes(self):
def get_input_iter(self, cur_dtype) -> Generator:
for shape in self.shapes:
inp1 = generate_tensor_input(shape, cur_dtype, self.device)
shift_amount = torch.randint(
0, 8, shape, dtype=cur_dtype, device=self.device
shift_amount = torch.randint(0, 8, shape, dtype=cur_dtype, device="cpu").to(
self.device
Collaborator:
Is there any difference between these two expressions?

Contributor Author:
Since the randint function in our torch-mlu does not support certain data types (e.g., i16), we've made this temporary modification for now; support for these types will be added in the future. (A small sketch of the two expressions follows this hunk.)

)
yield inp1, shift_amount

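A minimal sketch of the two expressions discussed in the thread above, assuming a narrow integer dtype such as torch.int16 and a generic target device (all names here are illustrative, not taken from the diff):

import torch

shape, dtype, dev = (4, 8), torch.int16, "cpu"   # "cpu" stands in for the accelerator device

# Before: sample directly on the target device.
# shift_amount = torch.randint(0, 8, shape, dtype=dtype, device=dev)

# After: sample on CPU, then move, for backends whose randint does not yet
# support narrow integer dtypes.
shift_amount = torch.randint(0, 8, shape, dtype=dtype, device="cpu").to(dev)

Both versions produce a tensor of the same shape and dtype on the target device; only where the sampling happens differs.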
5 changes: 1 addition & 4 deletions src/flag_gems/runtime/backend/_cambricon/__init__.py
@@ -25,11 +25,8 @@
)

CUSTOMIZED_UNUSED_OPS = (
"randperm", # skip now
"sort", # skip now
"multinomial", # skip now
"_upsample_bicubic2d_aa", # skip now
"sort_stable",
"copy_",
)

__all__ = ["*"]
5 changes: 4 additions & 1 deletion src/flag_gems/runtime/backend/_cambricon/fused/__init__.py
@@ -1,17 +1,20 @@
from .cross_entropy_loss import cross_entropy_loss
from .flash_mla import flash_mla
from .fused_add_rms_norm import fused_add_rms_norm
from .gelu_and_mul import gelu_and_mul
from .outer import outer
from .silu_and_mul import silu_and_mul
from .silu_and_mul import silu_and_mul, silu_and_mul_out
from .skip_layernorm import skip_layer_norm
from .weight_norm import weight_norm

__all__ = [
"skip_layer_norm",
"fused_add_rms_norm",
"silu_and_mul",
"silu_and_mul_out",
"gelu_and_mul",
"cross_entropy_loss",
"outer",
"weight_norm",
"flash_mla",
]
232 changes: 232 additions & 0 deletions src/flag_gems/runtime/backend/_cambricon/fused/flash_mla.py
@@ -0,0 +1,232 @@
import logging
import math

import torch
import triton
import triton.language as tl

from flag_gems.runtime import device, error, torch_device_fn
from flag_gems.utils import triton_lang_extension as tle

vendor_name = device.vendor_name
device = device.name
logger = logging.getLogger(__name__)


# @triton.autotune(
# configs=[
# triton.Config({"BLOCK_H": h, "BLOCK_N": n}, num_warps=w, num_stages=s)
# for h in [32, 64, 128]
# for n in [32, 64, 128]
# for w in [4, 8]
# for s in [1, 2]
# ],
# key=["head_num"]
# )
@triton.heuristics(
values={
"EVEN_H": lambda META: META["head_num"] % META["BLOCK_H"] == 0,
}
)
@triton.jit
def flash_mla_attn_kernel(
Q_ptr,
Kv_cache,
Req_to_tokens,
B_seq_len,
O,
sm_scale,
head_num,
stride_q_bs,
stride_q_h,
stride_kv_bs,
stride_req_to_tokens_bs,
stride_o_b,
stride_o_h,
stride_o_s,
BLOCK_H: tl.constexpr,
BLOCK_N: tl.constexpr,
EVEN_H: tl.constexpr,
PAGE_SIZE: tl.constexpr,
HEAD_DIM_V: tl.constexpr,
HEAD_DIM: tl.constexpr,
):
cur_head_id = tle.program_id(0)
cur_batch_id = tle.program_id(1)
Req_to_tokens += stride_req_to_tokens_bs * cur_batch_id

cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H)

offs_d_ckv = tl.arange(0, HEAD_DIM_V)
offs_q_nope = (
cur_batch_id * stride_q_bs
+ cur_head[:, None] * stride_q_h
+ offs_d_ckv[None, :]
)

offs_d_kpe = tl.arange(HEAD_DIM_V, HEAD_DIM)
offs_q_pe = (
cur_batch_id * stride_q_bs
+ cur_head[:, None] * stride_q_h
+ offs_d_kpe[None, :]
)

if EVEN_H:
q_nope = tl.load(Q_ptr + offs_q_nope)
q_pe = tl.load(Q_ptr + offs_q_pe)
else:
mask_head = cur_head < head_num
q_nope = tl.load(Q_ptr + offs_q_nope, mask=mask_head[:, None])
q_pe = tl.load(Q_ptr + offs_q_pe, mask=mask_head[:, None])

e_max = tl.full([BLOCK_H], value=float("-inf"), dtype=tl.float32)
e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
acc = tl.zeros([BLOCK_H, HEAD_DIM_V], dtype=tl.float32)

cur_batch_seq_len = tl.load(B_seq_len + cur_batch_id)
loop_time = cur_batch_seq_len // BLOCK_N
remainder = cur_batch_seq_len % BLOCK_N
offs_n = tl.arange(0, BLOCK_N)
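    # Online softmax over full BLOCK_N tiles: map token positions to paged-KV rows
    # via Req_to_tokens, then fold each tile into the running max, sum, and accumulator.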
for i in range(0, loop_time):
kv_page_number = tl.load(Req_to_tokens + offs_n // PAGE_SIZE)
kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
offs_v_c = kv_loc[:, None] * stride_kv_bs + offs_d_ckv[None, :]
v_c = tl.load(Kv_cache + offs_v_c)
k_c = tl.trans(v_c)

qk = tl.dot(q_nope, k_c) # qk_nope

offs_k_pe = kv_loc[None, :] * stride_kv_bs + offs_d_kpe[:, None]
k_pe = tl.load(Kv_cache + offs_k_pe)

qk = tl.dot(q_pe, k_pe, acc=qk) # qk_rope
qk *= sm_scale

n_e_max = tl.maximum(tl.max(qk, 1), e_max)
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None])
acc *= re_scale[:, None]
acc = tl.dot(p.to(v_c.dtype), v_c, acc=acc)

e_sum = e_sum * re_scale + tl.sum(p, 1)
e_max = n_e_max
offs_n += BLOCK_N

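    # Tail tile: same update, with loads and scores masked beyond the true sequence length.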
if remainder:
mask_kvsplit = offs_n < cur_batch_seq_len
kv_page_number = tl.load(
Req_to_tokens + offs_n // PAGE_SIZE,
mask=mask_kvsplit,
other=0,
)
kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
offs_v_c = kv_loc[:, None] * stride_kv_bs + offs_d_ckv[None, :]
v_c = tl.load(Kv_cache + offs_v_c, mask=mask_kvsplit[:, None], other=0.0)
k_c = tl.trans(v_c)

qk = tl.dot(q_nope, k_c) # qk_nope

offs_k_pe = kv_loc[None, :] * stride_kv_bs + offs_d_kpe[:, None]
k_pe = tl.load(Kv_cache + offs_k_pe, mask=mask_kvsplit[None, :], other=0.0)

qk = tl.dot(q_pe, k_pe, acc=qk) # qk_rope
qk *= sm_scale

qk = tl.where(mask_kvsplit[None, :], qk, float("-inf"))

n_e_max = tl.maximum(tl.max(qk, 1), e_max)
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None])
acc *= re_scale[:, None]
acc = tl.dot(p.to(v_c.dtype), v_c, acc=acc)

e_sum = e_sum * re_scale + tl.sum(p, 1)

offs_o = (
cur_batch_id * stride_o_b + cur_head[:, None] * stride_o_h + offs_d_ckv[None, :]
)
if EVEN_H:
tl.store(
O + offs_o,
acc / e_sum[:, None],
)
else:
tl.store(O + offs_o, acc / e_sum[:, None], mask=mask_head[:, None])


def flash_mla(
q,
block_table,
blocked_k,
max_seqlen_pad,
block_size,
b,
s_q,
cache_seqlens,
h_q,
h_kv,
d,
dv,
causal,
):
logger.debug("GEMS_CAMBRICON FLASH MLA")
assert causal, "causal False not supported"
assert d > dv, "mla with rope dim should be larger than no rope dim"

batch_size, s_q, head_num, d = list(q.shape)
q = q.view([-1, head_num, d]).contiguous()
blocked_k = blocked_k.view([-1, d]).contiguous()
block_table = block_table.contiguous()
cache_seqlens = cache_seqlens.contiguous()

sm_scale = 1 / math.sqrt(d)

o = torch.empty([b * s_q, h_q, dv], dtype=q.dtype, device=device)

major, _ = torch_device_fn.get_device_capability(device)
if major == 9:
BLOCK_H = 64
num_stages = 3
elif major == 8:
BLOCK_H = 32
num_stages = 2
elif major == 7 and vendor_name == "iluvatar":
BLOCK_H = 32
num_stages = 1
elif vendor_name == "cambricon":
BLOCK_H = 32
num_stages = 1
else:
error.backend_not_support(device)
BLOCK_N = 64
grid = (
triton.cdiv(head_num, BLOCK_H),
batch_size,
)
with torch_device_fn.device(device):
flash_mla_attn_kernel[grid](
q,
blocked_k,
block_table,
cache_seqlens,
o,
sm_scale,
head_num,
# stride
q.stride(0),
q.stride(1),
blocked_k.stride(-2),
block_table.stride(0),
o.stride(0),
o.stride(1),
o.stride(2),
BLOCK_H=BLOCK_H,
BLOCK_N=BLOCK_N,
PAGE_SIZE=block_size,
HEAD_DIM_V=dv,
HEAD_DIM=d,
num_warps=8,
num_stages=num_stages,
)

return o.view([b, s_q, h_q, dv])
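For reference, a minimal decode-style usage sketch of the flash_mla wrapper above. All shapes, the dtype, the cache length, and the device string are hypothetical placeholders (s_q = 1, d = 512 no-rope dims + 64 rope dims, one KV head); the direct import is only for illustration and follows the fused package export shown earlier.

import math
import torch
from flag_gems.runtime.backend._cambricon.fused import flash_mla

dev = "mlu"                                # assumed device string; adjust for your platform
b, s_q, h_q, h_kv = 2, 1, 16, 1            # decode step: one query token per request
d, dv = 576, 512                           # d includes the rope dims, so d > dv
block_size = 64
cache_len = 1000                           # hypothetical KV-cache length per request
max_seqlen_pad = math.ceil(cache_len / block_size) * block_size
num_pages = b * max_seqlen_pad // block_size

q = torch.randn(b, s_q, h_q, d, dtype=torch.bfloat16, device=dev)
blocked_k = torch.randn(num_pages, block_size, h_kv, d, dtype=torch.bfloat16, device=dev)
block_table = torch.arange(num_pages, dtype=torch.int32, device=dev).view(b, -1)
cache_seqlens = torch.full((b,), cache_len, dtype=torch.int32, device=dev)

out = flash_mla(
    q, block_table, blocked_k, max_seqlen_pad, block_size,
    b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal=True,
)
# out has shape [b, s_q, h_q, dv]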
src/flag_gems/runtime/backend/_cambricon/fused/fused_add_rms_norm.py
@@ -98,7 +98,12 @@ def fused_add_rms_norm(x, residual, normalized_shape, weight, eps=1e-5):
Both `x` and `residual` tensors will be modified. Use with caution if these tensors
are reused elsewhere or require gradients.
"""
logger.debug("GEMS_CAMBRICON FUSED_ADD_RMSNORM FORWARD")
logger.debug(
"GEMS_CAMBRICON FUSED_ADD_RMS_NORM FORWARD, [input shape]: %s, [residual shape]: %s, [weight shape]: %s",
x.size(),
residual.size(),
weight.size(),
)
dim = x.ndim - len(normalized_shape)
M = math.prod(x.shape[:dim])
N = math.prod(normalized_shape)
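As a quick illustration of the M/N flattening computed above (shapes are hypothetical, not from the diff):

import math
import torch

x = torch.randn(2, 8, 128)                 # hypothetical (batch, seq, hidden) input
normalized_shape = (128,)
dim = x.ndim - len(normalized_shape)        # 2: leading dims are flattened into rows
M = math.prod(x.shape[:dim])                # 16 rows to normalize
N = math.prod(normalized_shape)             # 128 elements per row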
src/flag_gems/runtime/backend/_cambricon/fused/silu_and_mul.py
@@ -46,3 +46,8 @@ def backward(ctx, grad_output):

def silu_and_mul(A, B):
return SiluAndMul.apply(A, B)


def silu_and_mul_out(A, B, out):
silu_and_mul_kernel(A, B, out0=out)
return out
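A minimal usage sketch of the new out-variant; the tensor shapes and the device string are illustrative assumptions:

import torch

A = torch.randn(8, 128, device="mlu")       # assumed device string; adjust per platform
B = torch.randn(8, 128, device="mlu")
out = torch.empty_like(A)

silu_and_mul_out(A, B, out)                 # fills `out` with the fused SiLU-and-multiply result and returns it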