Update MLA decode attention. (#696)
hiworldwzj authored Jan 4, 2025
1 parent 02effd7 commit 6ede09e
Showing 3 changed files with 75 additions and 39 deletions.
11 changes: 10 additions & 1 deletion lightllm/common/basemodel/cuda_graph.py
@@ -1,5 +1,6 @@
import os
import torch
import copy
from lightllm.utils.log_utils import init_logger
from lightllm.distributed import custom_comm_ops

@@ -27,10 +28,18 @@ def capture_decode(self, decode_func, input_ids, infer_state):
infer_state.max_len_in_batch = self.graph_max_len_in_batch
infer_state.total_token_num = self.graph_max_len_in_batch * batch_size
# warmup
# Some code in the inference path checks whether certain attributes already exist on
# infer_state and, on the first layer, does one-time initialization whose results are
# reused by the later layers (see, for example, the logic in
# lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding.py). The warmup call must
# therefore pass a shallow copy of infer_state; otherwise infer_state would already
# carry these attributes by the time it enters the CUDA graph capture, the
# initialization would not run again, and the tensors that get attached to the
# infer_state object on the fly would not be captured by the graph.
for _ in range(1):
torch.cuda.synchronize()
decode_func(input_ids, infer_state)
decode_func(input_ids, copy.copy(infer_state))  # infer_state must be passed as a shallow copy (see comment above)
torch.cuda.synchronize()

with custom_comm_ops.lightllm_capture_graph():
with torch.cuda.graph(graph_obj, pool=self.mempool):
predict_logics = decode_func(input_ids, infer_state)
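
The shallow copy in the warmup loop above is what keeps the lazily-created tensors inside the later capture. A minimal sketch of the pattern (CachedState and decode_step are hypothetical stand-ins for InferStateInfo and the model's decode path):

import copy
import torch

class CachedState:
    # Stands in for InferStateInfo: attributes are attached lazily on first use.
    pass

def decode_step(state):
    # A kernel wrapper may cache a tensor on the state object the first time it runs.
    if not hasattr(state, "block_seq"):
        state.block_seq = torch.zeros(1, dtype=torch.int64, device="cuda")
    return state.block_seq

state = CachedState()
decode_step(copy.copy(state))  # warmup: the cached tensor lands on a throwaway copy
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    out = decode_step(state)  # capture: the tensor is created, and thus captured, here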
78 changes: 61 additions & 17 deletions lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding.py
@@ -1,6 +1,8 @@
import os
import torch
import torch.multiprocessing as mp
import triton
import triton.language as tl
from typing import List
from lightllm.utils.log_utils import init_logger
from .gqa_flash_decoding_config import MlaDecodeAttentionKernelConfig
@@ -44,27 +46,19 @@ def gqa_token_decode_attention_flash_decoding(
out_dtype=torch.bfloat16,
)

BLOCK_N = run_config["BLOCK_N"]

from .gqa_flash_decoding_stage1 import flash_decode_stage1
from .gqa_flash_decoding_stage2 import flash_decode_stage2

o_tensor = alloc_tensor_func(q_nope.shape, q_nope.dtype, q_nope.device) if out is None else out

mid_o_block_seq = torch.empty([1], dtype=torch.int64, device="cuda")
mid_o_batch_start_index = alloc_tensor_func(
[
batch_size,
],
dtype=torch.int64,
device="cuda",
)

fake_decode_att_block_seq = torch.empty([0], dtype=torch.int64, device="cuda")
mid_o = torch.empty([q_head_num, 0, kv_lora_rank], dtype=torch.float32, device="cuda")
mid_o_logexpsum = torch.empty([q_head_num, 0], dtype=torch.float32, device="cuda")

vsm_count = flash_decode_stage1(
infer_state.total_token_num_tensor,
mid_o_block_seq,
mid_o_batch_start_index,
fake_decode_att_block_seq,
q_nope.view(calcu_shape1),
q_rope.view(calcu_shape2),
kv_nope,
@@ -79,13 +73,40 @@ def gqa_token_decode_attention_flash_decoding(
**run_config
)

if not hasattr(infer_state, "decode_att_block_seq"):
assert batch_size <= 2048
decode_att_block_seq = torch.empty(
[
1,
],
dtype=torch.int64,
device="cuda",
)
mid_o_batch_start_index = torch.empty(
[
batch_size,
],
dtype=torch.int64,
device="cuda",
)
_fwd_kernel_calcu_index_and_block_seq[(1,)](
infer_state.b_seq_len,
decode_att_block_seq,
mid_o_batch_start_index,
vsm_count,
batch_size,
BLOCK_N=BLOCK_N,
num_warps=4,
)

infer_state.decode_att_block_seq = decode_att_block_seq
infer_state.mid_o_batch_start_index = mid_o_batch_start_index

mid_o = torch.empty([q_head_num, vsm_count * 4 + batch_size, kv_lora_rank], dtype=torch.float32, device="cuda")
mid_o_logexpsum = torch.empty([q_head_num, vsm_count * 4 + batch_size], dtype=torch.float32, device="cuda")

flash_decode_stage1(
infer_state.total_token_num_tensor,
mid_o_block_seq,
mid_o_batch_start_index,
infer_state.decode_att_block_seq,
q_nope.view(calcu_shape1),
q_rope.view(calcu_shape2),
kv_nope,
@@ -101,12 +122,35 @@ def gqa_token_decode_attention_flash_decoding(
)

flash_decode_stage2(
mid_o_block_seq,
mid_o_batch_start_index,
infer_state.decode_att_block_seq,
infer_state.mid_o_batch_start_index,
mid_o,
mid_o_logexpsum,
infer_state.b_seq_len,
o_tensor.view(calcu_shape1),
**run_config
)
return o_tensor
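
The mid_o and mid_o_logexpsum buffers above reserve vsm_count * 4 + batch_size rows per head. That is enough under the block_seq choice made by the helper kernel below: block_seq exceeds total_token_num / (vsm_count * 4), so summing ceil(seq_len_i / block_seq) over the batch stays below vsm_count * 4 + batch_size. A small self-check of this bound (a sketch with arbitrary example values):

import torch

def mid_o_rows_needed(b_seq_len, vsm_count, block_n):
    total_token_num = int(b_seq_len.sum())
    block_seq = total_token_num // (vsm_count * 4) + 1
    block_seq = ((block_seq + block_n - 1) // block_n) * block_n
    # one row per (request, block_seq-sized chunk) pair
    return int(((b_seq_len + block_seq - 1) // block_seq).sum())

b_seq_len = torch.randint(1, 8192, (64,), dtype=torch.int64)
assert mid_o_rows_needed(b_seq_len, vsm_count=132, block_n=64) <= 132 * 4 + 64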


@triton.jit
def _fwd_kernel_calcu_index_and_block_seq(
b_seq_len_ptr,
mid_o_decode_att_block_seq_ptr,
mid_o_batch_start_index_ptr,
num_sm,
batch_size,
BLOCK_N: tl.constexpr,
):
b_seq_len = tl.load(b_seq_len_ptr + tl.arange(0, 2048), mask=tl.arange(0, 2048) < batch_size, other=0)
total_token_num = tl.sum(b_seq_len)

block_seq = tl.cast(total_token_num / (num_sm * 4), dtype=tl.int32) + 1
block_seq = tl.cdiv(block_seq, BLOCK_N) * BLOCK_N

block_seq_len = tl.cdiv(b_seq_len, block_seq)
cumsum_seq_len = tl.cumsum(block_seq_len)
batch_start_index = cumsum_seq_len - block_seq_len
tl.store(mid_o_batch_start_index_ptr + tl.arange(0, 2048), batch_start_index, mask=tl.arange(0, 2048) < batch_size)
tl.store(mid_o_decode_att_block_seq_ptr, block_seq)
return
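
For intuition, the indexing kernel above amounts to this host-side computation (a sketch; b_seq_len is the per-request sequence-length tensor and num_sm is the vsm_count returned by the first flash_decode_stage1 call):

import torch

def calcu_index_and_block_seq(b_seq_len, num_sm, block_n):
    # Aim for roughly num_sm * 4 work blocks across the whole batch,
    # rounded up to a multiple of BLOCK_N.
    total_token_num = int(b_seq_len.sum())
    block_seq = total_token_num // (num_sm * 4) + 1
    block_seq = ((block_seq + block_n - 1) // block_n) * block_n
    # Each request owns ceil(seq_len / block_seq) mid-output slots; an exclusive
    # prefix sum gives each request's start offset into the mid_o buffers.
    block_seq_len = (b_seq_len + block_seq - 1) // block_seq
    batch_start_index = torch.cumsum(block_seq_len, dim=0) - block_seq_len
    return block_seq, batch_start_index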
25 changes: 4 additions & 21 deletions lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py
@@ -35,9 +35,7 @@ def _fwd_kernel_flash_decode_stage1_padding(
stride_mid_od,
stride_mid_o_eh,
stride_mid_o_es,
total_token_ptr,
block_size_ptr,
batch_start_index_ptr,
num_sm,
head_group_num,
head_num,
@@ -51,15 +49,9 @@
):
# cur_kv_head = 0
sm_id = tl.program_id(0).to(tl.int64)
grid_id = sm_id
out_batch_start_index = tl.cast(0, tl.int64)
total_token_num = tl.load(total_token_ptr, eviction_policy="evict_last")
block_seq = tl.load(block_size_ptr, eviction_policy="evict_last")

block_seq = tl.cast(total_token_num / num_sm / 4, dtype=tl.int32) + 1
block_seq = tl.cdiv(block_seq, BLOCK_N) * BLOCK_N

if grid_id == 0:
tl.store(block_size_ptr, block_seq)
cur_q_head_offs = tl.arange(0, Q_HEAD_NUM)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_rope_d = tl.arange(0, BLOCK_ROPE_DMODEL)
@@ -163,19 +155,14 @@ _fwd_kernel_flash_decode_stage1_padding(
)
sm_id += num_sm

if grid_id == 0:
tl.store(batch_start_index_ptr + cur_batch, out_batch_start_index)

out_batch_start_index += cur_block_num // head_group_num
sm_id -= cur_block_num
return


@torch.no_grad()
def flash_decode_stage1(
total_token_num_tensor: torch.Tensor,
out_block_seq: torch.Tensor,
batch_start_index: torch.Tensor,
in_block_seq: torch.Tensor,
q_nope,
q_rope,
kv_nope,
@@ -227,9 +214,7 @@ def flash_decode_stage1(
*kv_rope.stride(),
*mid_out.stride(),
*mid_out_logsumexp.stride(),
total_token_num_tensor,
out_block_seq,
batch_start_index,
in_block_seq,
num_sm=1,
head_group_num=head_group_num,
head_num=q_head_num,
@@ -271,9 +256,7 @@ def flash_decode_stage1(
*kv_rope.stride(),
*mid_out.stride(),
*mid_out_logsumexp.stride(),
total_token_num_tensor,
out_block_seq,
batch_start_index,
in_block_seq,
num_sm=num_sm,
head_group_num=head_group_num,
head_num=q_head_num,
