|
| 1 | +import logging |
| 2 | +import math |
| 3 | + |
| 4 | +import torch |
| 5 | +import triton |
| 6 | +import triton.language as tl |
| 7 | + |
| 8 | +from flag_gems.runtime import device, error, torch_device_fn |
| 9 | +from flag_gems.utils import triton_lang_extension as tle |
| 10 | + |
| 11 | +vendor_name = device.vendor_name |
| 12 | +device = device.name |
| 13 | +logger = logging.getLogger(__name__) |
| 14 | + |
| 15 | + |
| 16 | +# @triton.autotune( |
| 17 | +# configs=[ |
| 18 | +# triton.Config({"BLOCK_H": h, "BLOCK_N": n}, num_warps=w, num_stages=s) |
| 19 | +# for h in [32, 64, 128] |
| 20 | +# for n in [32, 64, 128] |
| 21 | +# for w in [4, 8] |
| 22 | +# for s in [1, 2] |
| 23 | +# ], |
| 24 | +# key=["head_num"] |
| 25 | +# ) |
| 26 | +@triton.heuristics( |
| 27 | + values={ |
| 28 | + "EVEN_H": lambda META: META["head_num"] % META["BLOCK_H"] == 0, |
| 29 | + } |
| 30 | +) |
| 31 | +@triton.jit |
| 32 | +def flash_mla_attn_kernel( |
| 33 | + Q_ptr, |
| 34 | + Kv_cache, |
| 35 | + Req_to_tokens, |
| 36 | + B_seq_len, |
| 37 | + O, |
| 38 | + sm_scale, |
| 39 | + head_num, |
| 40 | + stride_q_bs, |
| 41 | + stride_q_h, |
| 42 | + stride_kv_bs, |
| 43 | + stride_req_to_tokens_bs, |
| 44 | + stride_o_b, |
| 45 | + stride_o_h, |
| 46 | + stride_o_s, |
| 47 | + BLOCK_H: tl.constexpr, |
| 48 | + BLOCK_N: tl.constexpr, |
| 49 | + EVEN_H: tl.constexpr, |
| 50 | + PAGE_SIZE: tl.constexpr, |
| 51 | + HEAD_DIM_V: tl.constexpr, |
| 52 | + HEAD_DIM: tl.constexpr, |
| 53 | +): |
| 54 | + cur_head_id = tle.program_id(0) |
| 55 | + cur_batch_id = tle.program_id(1) |
| 56 | + Req_to_tokens += stride_req_to_tokens_bs * cur_batch_id |
| 57 | + |
| 58 | + cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H) |
| 59 | + |
| 60 | + offs_d_ckv = tl.arange(0, HEAD_DIM_V) |
| 61 | + offs_q_nope = ( |
| 62 | + cur_batch_id * stride_q_bs |
| 63 | + + cur_head[:, None] * stride_q_h |
| 64 | + + offs_d_ckv[None, :] |
| 65 | + ) |
| 66 | + |
| 67 | + offs_d_kpe = tl.arange(HEAD_DIM_V, HEAD_DIM) |
| 68 | + offs_q_pe = ( |
| 69 | + cur_batch_id * stride_q_bs |
| 70 | + + cur_head[:, None] * stride_q_h |
| 71 | + + offs_d_kpe[None, :] |
| 72 | + ) |
| 73 | + |
| 74 | + if EVEN_H: |
| 75 | + q_nope = tl.load(Q_ptr + offs_q_nope) |
| 76 | + q_pe = tl.load(Q_ptr + offs_q_pe) |
| 77 | + else: |
| 78 | + mask_head = cur_head < head_num |
| 79 | + q_nope = tl.load(Q_ptr + offs_q_nope, mask=mask_head[:, None]) |
| 80 | + q_pe = tl.load(Q_ptr + offs_q_pe, mask=mask_head[:, None]) |
| 81 | + |
| 82 | + e_max = tl.full([BLOCK_H], value=float("-inf"), dtype=tl.float32) |
| 83 | + e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) |
| 84 | + acc = tl.zeros([BLOCK_H, HEAD_DIM_V], dtype=tl.float32) |
| 85 | + |
| 86 | + cur_batch_seq_len = tl.load(B_seq_len + cur_batch_id) |
| 87 | + loop_time = cur_batch_seq_len // BLOCK_N |
| 88 | + remainder = cur_batch_seq_len % BLOCK_N |
| 89 | + offs_n = tl.arange(0, BLOCK_N) |
| 90 | + for i in range(0, loop_time): |
| 91 | + kv_page_number = tl.load(Req_to_tokens + offs_n // PAGE_SIZE) |
| 92 | + kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE |
| 93 | + offs_v_c = kv_loc[:, None] * stride_kv_bs + offs_d_ckv[None, :] |
| 94 | + v_c = tl.load(Kv_cache + offs_v_c) |
| 95 | + k_c = tl.trans(v_c) |
| 96 | + |
| 97 | + qk = tl.dot(q_nope, k_c) # qk_nope |
| 98 | + |
| 99 | + offs_k_pe = kv_loc[None, :] * stride_kv_bs + offs_d_kpe[:, None] |
| 100 | + k_pe = tl.load(Kv_cache + offs_k_pe) |
| 101 | + |
| 102 | + qk = tl.dot(q_pe, k_pe, acc=qk) # qk_rope |
| 103 | + qk *= sm_scale |
| 104 | + |
| 105 | + n_e_max = tl.maximum(tl.max(qk, 1), e_max) |
| 106 | + re_scale = tl.exp(e_max - n_e_max) |
| 107 | + p = tl.exp(qk - n_e_max[:, None]) |
| 108 | + acc *= re_scale[:, None] |
| 109 | + acc = tl.dot(p.to(v_c.dtype), v_c, acc=acc) |
| 110 | + |
| 111 | + e_sum = e_sum * re_scale + tl.sum(p, 1) |
| 112 | + e_max = n_e_max |
| 113 | + offs_n += BLOCK_N |
| 114 | + |
| 115 | + if remainder: |
| 116 | + mask_kvsplit = offs_n < cur_batch_seq_len |
| 117 | + kv_page_number = tl.load( |
| 118 | + Req_to_tokens + offs_n // PAGE_SIZE, |
| 119 | + mask=mask_kvsplit, |
| 120 | + other=0, |
| 121 | + ) |
| 122 | + kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE |
| 123 | + offs_v_c = kv_loc[:, None] * stride_kv_bs + offs_d_ckv[None, :] |
| 124 | + v_c = tl.load(Kv_cache + offs_v_c, mask=mask_kvsplit[:, None], other=0.0) |
| 125 | + k_c = tl.trans(v_c) |
| 126 | + |
| 127 | + qk = tl.dot(q_nope, k_c) # qk_nope |
| 128 | + |
| 129 | + offs_k_pe = kv_loc[None, :] * stride_kv_bs + offs_d_kpe[:, None] |
| 130 | + k_pe = tl.load(Kv_cache + offs_k_pe, mask=mask_kvsplit[None, :], other=0.0) |
| 131 | + |
| 132 | + qk = tl.dot(q_pe, k_pe, acc=qk) # qk_rope |
| 133 | + qk *= sm_scale |
| 134 | + |
| 135 | + qk = tl.where(mask_kvsplit[None, :], qk, float("-inf")) |
| 136 | + |
| 137 | + n_e_max = tl.maximum(tl.max(qk, 1), e_max) |
| 138 | + re_scale = tl.exp(e_max - n_e_max) |
| 139 | + p = tl.exp(qk - n_e_max[:, None]) |
| 140 | + acc *= re_scale[:, None] |
| 141 | + acc = tl.dot(p.to(v_c.dtype), v_c, acc=acc) |
| 142 | + |
| 143 | + e_sum = e_sum * re_scale + tl.sum(p, 1) |
| 144 | + |
| 145 | + offs_o = ( |
| 146 | + cur_batch_id * stride_o_b + cur_head[:, None] * stride_o_h + offs_d_ckv[None, :] |
| 147 | + ) |
| 148 | + if EVEN_H: |
| 149 | + tl.store( |
| 150 | + O + offs_o, |
| 151 | + acc / e_sum[:, None], |
| 152 | + ) |
| 153 | + else: |
| 154 | + tl.store(O + offs_o, acc / e_sum[:, None], mask=mask_head[:, None]) |
| 155 | + |
| 156 | + |
| 157 | +def flash_mla( |
| 158 | + q, |
| 159 | + block_table, |
| 160 | + blocked_k, |
| 161 | + max_seqlen_pad, |
| 162 | + block_size, |
| 163 | + b, |
| 164 | + s_q, |
| 165 | + cache_seqlens, |
| 166 | + h_q, |
| 167 | + h_kv, |
| 168 | + d, |
| 169 | + dv, |
| 170 | + causal, |
| 171 | +): |
| 172 | + logger.debug("GEMS_CAMBRICON FLASH MLA") |
| 173 | + assert causal, "causal False not supported" |
| 174 | + assert d > dv, "mla with rope dim should be larger than no rope dim" |
| 175 | + |
| 176 | + batch_size, s_q, head_num, d = list(q.shape) |
| 177 | + q = q.view([-1, head_num, d]).contiguous() |
| 178 | + blocked_k = blocked_k.view([-1, d]).contiguous() |
| 179 | + block_table = block_table.contiguous() |
| 180 | + cache_seqlens = cache_seqlens.contiguous() |
| 181 | + |
| 182 | + sm_scale = 1 / math.sqrt(d) |
| 183 | + |
| 184 | + o = torch.empty([b * s_q, h_q, dv], dtype=q.dtype, device=device) |
| 185 | + |
| 186 | + major, _ = torch_device_fn.get_device_capability(device) |
| 187 | + if major == 9: |
| 188 | + BLOCK_H = 64 |
| 189 | + num_stages = 3 |
| 190 | + elif major == 8: |
| 191 | + BLOCK_H = 32 |
| 192 | + num_stages = 2 |
| 193 | + elif major == 7 and vendor_name == "iluvatar": |
| 194 | + BLOCK_H = 32 |
| 195 | + num_stages = 1 |
| 196 | + elif vendor_name == "cambricon": |
| 197 | + BLOCK_H = 32 |
| 198 | + num_stages = 1 |
| 199 | + else: |
| 200 | + error.backend_not_support(device) |
| 201 | + BLOCK_N = 64 |
| 202 | + grid = ( |
| 203 | + triton.cdiv(head_num, BLOCK_H), |
| 204 | + batch_size, |
| 205 | + ) |
| 206 | + with torch_device_fn.device(device): |
| 207 | + flash_mla_attn_kernel[grid]( |
| 208 | + q, |
| 209 | + blocked_k, |
| 210 | + block_table, |
| 211 | + cache_seqlens, |
| 212 | + o, |
| 213 | + sm_scale, |
| 214 | + head_num, |
| 215 | + # stride |
| 216 | + q.stride(0), |
| 217 | + q.stride(1), |
| 218 | + blocked_k.stride(-2), |
| 219 | + block_table.stride(0), |
| 220 | + o.stride(0), |
| 221 | + o.stride(1), |
| 222 | + o.stride(2), |
| 223 | + BLOCK_H=BLOCK_H, |
| 224 | + BLOCK_N=BLOCK_N, |
| 225 | + PAGE_SIZE=block_size, |
| 226 | + HEAD_DIM_V=dv, |
| 227 | + HEAD_DIM=d, |
| 228 | + num_warps=8, |
| 229 | + num_stages=num_stages, |
| 230 | + ) |
| 231 | + |
| 232 | + return o.view([b, s_q, h_q, dv]) |
0 commit comments