-
Notifications
You must be signed in to change notification settings - Fork 166
cambricon: update cambricon code based on master #1126
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
kiddyjinjin
merged 1 commit into
flagos-ai:master
from
chenmiao1919:cambricon_merge_to_master_q4
Dec 8, 2025
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,17 +1,20 @@ | ||
| from .cross_entropy_loss import cross_entropy_loss | ||
| from .flash_mla import flash_mla | ||
| from .fused_add_rms_norm import fused_add_rms_norm | ||
| from .gelu_and_mul import gelu_and_mul | ||
| from .outer import outer | ||
| from .silu_and_mul import silu_and_mul | ||
| from .silu_and_mul import silu_and_mul, silu_and_mul_out | ||
| from .skip_layernorm import skip_layer_norm | ||
| from .weight_norm import weight_norm | ||
|
|
||
| __all__ = [ | ||
| "skip_layer_norm", | ||
| "fused_add_rms_norm", | ||
| "silu_and_mul", | ||
| "silu_and_mul_out", | ||
| "gelu_and_mul", | ||
| "cross_entropy_loss", | ||
| "outer", | ||
| "weight_norm", | ||
| "flash_mla", | ||
| ] |
232 changes: 232 additions & 0 deletions
232
src/flag_gems/runtime/backend/_cambricon/fused/flash_mla.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,232 @@ | ||
| import logging | ||
| import math | ||
|
|
||
| import torch | ||
| import triton | ||
| import triton.language as tl | ||
|
|
||
| from flag_gems.runtime import device, error, torch_device_fn | ||
| from flag_gems.utils import triton_lang_extension as tle | ||
|
|
||
| vendor_name = device.vendor_name | ||
| device = device.name | ||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| # @triton.autotune( | ||
| # configs=[ | ||
| # triton.Config({"BLOCK_H": h, "BLOCK_N": n}, num_warps=w, num_stages=s) | ||
| # for h in [32, 64, 128] | ||
| # for n in [32, 64, 128] | ||
| # for w in [4, 8] | ||
| # for s in [1, 2] | ||
| # ], | ||
| # key=["head_num"] | ||
| # ) | ||
| @triton.heuristics( | ||
| values={ | ||
| "EVEN_H": lambda META: META["head_num"] % META["BLOCK_H"] == 0, | ||
| } | ||
| ) | ||
| @triton.jit | ||
| def flash_mla_attn_kernel( | ||
| Q_ptr, | ||
| Kv_cache, | ||
| Req_to_tokens, | ||
| B_seq_len, | ||
| O, | ||
| sm_scale, | ||
| head_num, | ||
| stride_q_bs, | ||
| stride_q_h, | ||
| stride_kv_bs, | ||
| stride_req_to_tokens_bs, | ||
| stride_o_b, | ||
| stride_o_h, | ||
| stride_o_s, | ||
| BLOCK_H: tl.constexpr, | ||
| BLOCK_N: tl.constexpr, | ||
| EVEN_H: tl.constexpr, | ||
| PAGE_SIZE: tl.constexpr, | ||
| HEAD_DIM_V: tl.constexpr, | ||
| HEAD_DIM: tl.constexpr, | ||
| ): | ||
| cur_head_id = tle.program_id(0) | ||
| cur_batch_id = tle.program_id(1) | ||
| Req_to_tokens += stride_req_to_tokens_bs * cur_batch_id | ||
|
|
||
| cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H) | ||
|
|
||
| offs_d_ckv = tl.arange(0, HEAD_DIM_V) | ||
| offs_q_nope = ( | ||
| cur_batch_id * stride_q_bs | ||
| + cur_head[:, None] * stride_q_h | ||
| + offs_d_ckv[None, :] | ||
| ) | ||
|
|
||
| offs_d_kpe = tl.arange(HEAD_DIM_V, HEAD_DIM) | ||
| offs_q_pe = ( | ||
| cur_batch_id * stride_q_bs | ||
| + cur_head[:, None] * stride_q_h | ||
| + offs_d_kpe[None, :] | ||
| ) | ||
|
|
||
| if EVEN_H: | ||
| q_nope = tl.load(Q_ptr + offs_q_nope) | ||
| q_pe = tl.load(Q_ptr + offs_q_pe) | ||
| else: | ||
| mask_head = cur_head < head_num | ||
| q_nope = tl.load(Q_ptr + offs_q_nope, mask=mask_head[:, None]) | ||
| q_pe = tl.load(Q_ptr + offs_q_pe, mask=mask_head[:, None]) | ||
|
|
||
| e_max = tl.full([BLOCK_H], value=float("-inf"), dtype=tl.float32) | ||
| e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) | ||
| acc = tl.zeros([BLOCK_H, HEAD_DIM_V], dtype=tl.float32) | ||
|
|
||
| cur_batch_seq_len = tl.load(B_seq_len + cur_batch_id) | ||
| loop_time = cur_batch_seq_len // BLOCK_N | ||
| remainder = cur_batch_seq_len % BLOCK_N | ||
| offs_n = tl.arange(0, BLOCK_N) | ||
| for i in range(0, loop_time): | ||
| kv_page_number = tl.load(Req_to_tokens + offs_n // PAGE_SIZE) | ||
| kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE | ||
| offs_v_c = kv_loc[:, None] * stride_kv_bs + offs_d_ckv[None, :] | ||
| v_c = tl.load(Kv_cache + offs_v_c) | ||
| k_c = tl.trans(v_c) | ||
|
|
||
| qk = tl.dot(q_nope, k_c) # qk_nope | ||
|
|
||
| offs_k_pe = kv_loc[None, :] * stride_kv_bs + offs_d_kpe[:, None] | ||
| k_pe = tl.load(Kv_cache + offs_k_pe) | ||
|
|
||
| qk = tl.dot(q_pe, k_pe, acc=qk) # qk_rope | ||
| qk *= sm_scale | ||
|
|
||
| n_e_max = tl.maximum(tl.max(qk, 1), e_max) | ||
| re_scale = tl.exp(e_max - n_e_max) | ||
| p = tl.exp(qk - n_e_max[:, None]) | ||
| acc *= re_scale[:, None] | ||
| acc = tl.dot(p.to(v_c.dtype), v_c, acc=acc) | ||
|
|
||
| e_sum = e_sum * re_scale + tl.sum(p, 1) | ||
| e_max = n_e_max | ||
| offs_n += BLOCK_N | ||
|
|
||
| if remainder: | ||
| mask_kvsplit = offs_n < cur_batch_seq_len | ||
| kv_page_number = tl.load( | ||
| Req_to_tokens + offs_n // PAGE_SIZE, | ||
| mask=mask_kvsplit, | ||
| other=0, | ||
| ) | ||
| kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE | ||
| offs_v_c = kv_loc[:, None] * stride_kv_bs + offs_d_ckv[None, :] | ||
| v_c = tl.load(Kv_cache + offs_v_c, mask=mask_kvsplit[:, None], other=0.0) | ||
| k_c = tl.trans(v_c) | ||
|
|
||
| qk = tl.dot(q_nope, k_c) # qk_nope | ||
|
|
||
| offs_k_pe = kv_loc[None, :] * stride_kv_bs + offs_d_kpe[:, None] | ||
| k_pe = tl.load(Kv_cache + offs_k_pe, mask=mask_kvsplit[None, :], other=0.0) | ||
|
|
||
| qk = tl.dot(q_pe, k_pe, acc=qk) # qk_rope | ||
| qk *= sm_scale | ||
|
|
||
| qk = tl.where(mask_kvsplit[None, :], qk, float("-inf")) | ||
|
|
||
| n_e_max = tl.maximum(tl.max(qk, 1), e_max) | ||
| re_scale = tl.exp(e_max - n_e_max) | ||
| p = tl.exp(qk - n_e_max[:, None]) | ||
| acc *= re_scale[:, None] | ||
| acc = tl.dot(p.to(v_c.dtype), v_c, acc=acc) | ||
|
|
||
| e_sum = e_sum * re_scale + tl.sum(p, 1) | ||
|
|
||
| offs_o = ( | ||
| cur_batch_id * stride_o_b + cur_head[:, None] * stride_o_h + offs_d_ckv[None, :] | ||
| ) | ||
| if EVEN_H: | ||
| tl.store( | ||
| O + offs_o, | ||
| acc / e_sum[:, None], | ||
| ) | ||
| else: | ||
| tl.store(O + offs_o, acc / e_sum[:, None], mask=mask_head[:, None]) | ||
|
|
||
|
|
||
| def flash_mla( | ||
| q, | ||
| block_table, | ||
| blocked_k, | ||
| max_seqlen_pad, | ||
| block_size, | ||
| b, | ||
| s_q, | ||
| cache_seqlens, | ||
| h_q, | ||
| h_kv, | ||
| d, | ||
| dv, | ||
| causal, | ||
| ): | ||
| logger.debug("GEMS_CAMBRICON FLASH MLA") | ||
| assert causal, "causal False not supported" | ||
| assert d > dv, "mla with rope dim should be larger than no rope dim" | ||
|
|
||
| batch_size, s_q, head_num, d = list(q.shape) | ||
| q = q.view([-1, head_num, d]).contiguous() | ||
| blocked_k = blocked_k.view([-1, d]).contiguous() | ||
| block_table = block_table.contiguous() | ||
| cache_seqlens = cache_seqlens.contiguous() | ||
|
|
||
| sm_scale = 1 / math.sqrt(d) | ||
|
|
||
| o = torch.empty([b * s_q, h_q, dv], dtype=q.dtype, device=device) | ||
|
|
||
| major, _ = torch_device_fn.get_device_capability(device) | ||
| if major == 9: | ||
| BLOCK_H = 64 | ||
| num_stages = 3 | ||
| elif major == 8: | ||
| BLOCK_H = 32 | ||
| num_stages = 2 | ||
| elif major == 7 and vendor_name == "iluvatar": | ||
| BLOCK_H = 32 | ||
| num_stages = 1 | ||
| elif vendor_name == "cambricon": | ||
| BLOCK_H = 32 | ||
| num_stages = 1 | ||
| else: | ||
| error.backend_not_support(device) | ||
| BLOCK_N = 64 | ||
| grid = ( | ||
| triton.cdiv(head_num, BLOCK_H), | ||
| batch_size, | ||
| ) | ||
| with torch_device_fn.device(device): | ||
| flash_mla_attn_kernel[grid]( | ||
| q, | ||
| blocked_k, | ||
| block_table, | ||
| cache_seqlens, | ||
| o, | ||
| sm_scale, | ||
| head_num, | ||
| # stride | ||
| q.stride(0), | ||
| q.stride(1), | ||
| blocked_k.stride(-2), | ||
| block_table.stride(0), | ||
| o.stride(0), | ||
| o.stride(1), | ||
| o.stride(2), | ||
| BLOCK_H=BLOCK_H, | ||
| BLOCK_N=BLOCK_N, | ||
| PAGE_SIZE=block_size, | ||
| HEAD_DIM_V=dv, | ||
| HEAD_DIM=d, | ||
| num_warps=8, | ||
| num_stages=num_stages, | ||
| ) | ||
|
|
||
| return o.view([b, s_q, h_q, dv]) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there any difference between these two expressions?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since the randint function in our torch-mlu does not support specific data types (e.g., i16), we’ve made this temporary modification for now. Support for these types will be added in future.