Commit

small improve
PannenetsF authored and fanyunqian committed Jan 23, 2025
1 parent 5a3edef commit f91ef7f
Showing 1 changed file with 33 additions and 23 deletions.
56 changes: 33 additions & 23 deletions lightllm/models/llama/triton_kernel/gqa_flash_decoding_vsm.py
@@ -44,12 +44,12 @@ def try_to_get_best_config(
return config
else:
config = {
"BLOCK_N": 16,
"BLOCK_N": 64,
"BLOCK_Q_HEAD": 16,
"stage1_num_warps": 4,
"stage1_num_stages": 2,
"stage2_num_warps": 4,
"stage2_num_stages": 2,
"stage2_num_stages": 1,
}
return config
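
Note: BLOCK_N is the number of KV tokens each inner iteration of the stage-1 kernel covers, so raising the fallback from 16 to 64 cuts the chunk count per KV block. A minimal sketch of that effect (the 256-token block_size is a hypothetical example, not a value taken from this file):

```python
# Hypothetical block_size; only illustrates how BLOCK_N changes the inner-loop trip count.
def chunks_per_block(block_size: int, block_n: int) -> int:
    return (block_size + block_n - 1) // block_n  # same as tl.cdiv

print(chunks_per_block(256, 16))  # 16 chunks per KV block with the old default
print(chunks_per_block(256, 64))  # 4 chunks per KV block with the new default
```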

@@ -150,38 +150,45 @@ def _kernel_gqa_token_decode_attention_flash_decoding_vsm_stage1(
mid_o_logexpsum: [q_head_num, total_seq_block_num]
"""
sm_id = tl.program_id(0).to(tl.int64)
block_size = tl.load(block_size, eviction_policy="evict_last")
block_size = tl.load(block_size)

out_batch_start_index = tl.cast(0, tl.int64)
q_head_off = tl.arange(0, Q_HEAD_NUM)
d_off = tl.arange(0, BLOCK_DMODEL)

for cur_batch in tl.range(0, batch_size, 1):
cur_req_idx = tl.load(b_req_idx + cur_batch, eviction_policy="evict_last")
cur_seq_len = tl.load(b_seq_len + cur_batch, eviction_policy="evict_last")
for cur_batch in range(0, batch_size):
cur_req_idx = tl.load(b_req_idx + cur_batch)
cur_seq_len = tl.load(b_seq_len + cur_batch)

cur_num_of_blocks = tl.cdiv(cur_seq_len, block_size)
cur_num_of_kv_head_pairs = cur_num_of_blocks * kv_head_num

loop_sm_id = sm_id
while loop_sm_id < cur_num_of_kv_head_pairs:
cur_block_idx = loop_sm_id // kv_head_num
cur_kv_head_idx = loop_sm_id % kv_head_num
# loop_sm_id = sm_id
while sm_id < cur_num_of_kv_head_pairs:
cur_block_idx = sm_id % cur_num_of_blocks
cur_kv_head_idx = sm_id // cur_num_of_blocks
# cur_block_idx = sm_id // kv_head_num
# cur_kv_head_idx = sm_id % kv_head_num

cur_q_start = cur_kv_head_idx * gqa_group_size
cur_q_range = cur_q_start + q_head_off
cur_q_range = cur_kv_head_idx * gqa_group_size + q_head_off
cur_q_mask = q_head_off < gqa_group_size
q_off = cur_batch * stride_q_bs + cur_q_range[:, None] * stride_q_h + d_off[None, :]
q_tensor = tl.load(q + q_off, mask=cur_q_mask[:, None], other=0.0) # shape: [Q_HEAD_NUM, BLOCK_DMODEL]

cur_kv_start = cur_block_idx * block_size
cur_kv_end = tl.minimum(cur_kv_start + block_size, cur_seq_len)

q_off = cur_batch * stride_q_bs + cur_q_range[:, None] * stride_q_h + d_off[None, :]
q_tensor = tl.load(
q + q_off,
mask=cur_q_mask[:, None],
other=0.0,
) # shape: [Q_HEAD_NUM, BLOCK_DMODEL]

sum_exp = tl.zeros([Q_HEAD_NUM], dtype=tl.float32)
max_exp = tl.zeros([Q_HEAD_NUM], dtype=tl.float32) - float("inf")
accumu = tl.zeros([Q_HEAD_NUM, BLOCK_DMODEL], dtype=tl.float32)

for chunk_idx in tl.range(0, tl.cdiv(cur_kv_end - cur_kv_start, BLOCK_N), 1, num_stages=NUM_STAGES):
cur_total_chunk = tl.cdiv(tl.minimum(cur_kv_start + block_size, cur_seq_len) - cur_kv_start, BLOCK_N)

for chunk_idx in tl.range(0, cur_total_chunk, 1, num_stages=NUM_STAGES):
cur_chunk_start = cur_kv_start + chunk_idx * BLOCK_N
cur_chunk_range = cur_chunk_start + tl.arange(0, BLOCK_N)
cur_chunk_mask = cur_chunk_range < cur_seq_len
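
Note: stage 1 launches a fixed number of programs (num_sm) that stride over all (KV block, KV head) pairs of each batch element. Judging from the commented-out lines, this commit switches the index decoding from head-major to block-major, so consecutive programs read consecutive KV blocks of the same head. A sketch of the resulting work assignment with toy, purely illustrative sizes:

```python
# Toy sizes, purely illustrative of the new block-major work distribution.
num_sm, kv_head_num, cur_num_of_blocks = 4, 2, 3

for sm_id in range(num_sm):
    work = sm_id
    while work < cur_num_of_blocks * kv_head_num:
        cur_block_idx = work % cur_num_of_blocks     # block-major decoding, as in the new loop
        cur_kv_head_idx = work // cur_num_of_blocks
        print(f"program {sm_id} -> kv_head {cur_kv_head_idx}, block {cur_block_idx}")
        work += num_sm
```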
@@ -196,10 +196,10 @@ def _kernel_gqa_token_decode_attention_flash_decoding_vsm_stage1(
k_off = (
cur_kv_loc[None, :] * stride_k_bs + cur_kv_head_idx * stride_k_h + d_off[:, None]
) # shape: [BLOCK_DMODEL, BLOCK_N]
v_off = cur_kv_loc[:, None] * stride_v_bs + cur_kv_head_idx * stride_v_h + d_off[None, :]
k_tensor = tl.load(k + k_off, mask=cur_chunk_mask[None, :], other=0.0)

att_tensor = tl.dot(q_tensor, k_tensor) # shape: [Q_HEAD_NUM, BLOCK_N]
v_off = cur_kv_loc[:, None] * stride_v_bs + cur_kv_head_idx * stride_v_h + d_off[None, :]
v_tensor = tl.load(v + v_off, mask=cur_chunk_mask[:, None], other=0.0) # shape: [BLOCK_N, BLOCK_DMODEL]
att_tensor *= softmax_scale
att_tensor = tl.where(cur_chunk_mask[None, :], att_tensor, float("-inf"))
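
Note: the tl.where(..., float("-inf")) line is what keeps the padded positions of a partially filled chunk out of the result: after subtracting the running max, exp(-inf) is exactly zero, so those slots contribute nothing to sum_exp or the accumulator. A one-line check of that identity:

```python
import numpy as np

att = np.array([1.0, 2.0, -np.inf])   # last slot is a masked (out-of-range) position
print(np.exp(att - att.max()))        # [0.3679, 1.0, 0.0] -> the masked slot adds nothing
```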

@@ -209,7 +216,8 @@ def _kernel_gqa_token_decode_attention_flash_decoding_vsm_stage1(
exp_logic = tl.exp(att_tensor - new_max[:, None])
log_scale = tl.exp(max_exp - new_max)
accumu *= log_scale[:, None]
accumu += tl.dot(exp_logic, v_tensor.to(accumu.dtype))
v_tensor = tl.load(v + v_off, mask=cur_chunk_mask[:, None], other=0.0) # shape: [BLOCK_N, BLOCK_DMODEL]
accumu += tl.dot(exp_logic.to(v_tensor.dtype), v_tensor)

sum_exp = sum_exp * log_scale + tl.sum(exp_logic, axis=1)
max_exp = new_max
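
Note: this is the standard online-softmax (flash-attention style) update: when a chunk raises the running max, the previously accumulated values are rescaled by exp(old_max - new_max) before the new chunk is added. A minimal NumPy sketch of one update step, reduced to a single query head for readability:

```python
import numpy as np

def online_softmax_step(att, v, max_exp, sum_exp, accumu):
    """att: [BLOCK_N] scaled scores of one chunk, v: [BLOCK_N, D] chunk values."""
    new_max = max(att.max(), max_exp)
    exp_logic = np.exp(att - new_max)
    log_scale = np.exp(max_exp - new_max)           # rescales the older partial sums
    accumu = accumu * log_scale + exp_logic @ v
    sum_exp = sum_exp * log_scale + exp_logic.sum()
    return new_max, sum_exp, accumu
```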
@@ -223,12 +231,14 @@ def _kernel_gqa_token_decode_attention_flash_decoding_vsm_stage1(
cur_q_range * stride_mid_o_logexpsum_h
+ (out_batch_start_index + cur_block_idx) * stride_mid_o_logexpsum_seq
)
max_exp = max_exp + tl.log(sum_exp)
tl.store(
mid_o_logexpsum + off_mid_o_logexpsum,
max_exp + tl.log(sum_exp),
max_exp,
mask=cur_q_mask,
)
loop_sm_id += num_sm
sm_id += num_sm
sm_id -= cur_num_of_kv_head_pairs
out_batch_start_index += cur_num_of_blocks
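
Note: mid_o_logexpsum therefore holds max_exp + log(sum_exp), i.e. the log-sum-exp of this block's attention logits. Assuming the block's partial output stored just above (outside this hunk) is normalized by its own sum_exp, as is usual for flash decoding, stage 2 can merge the per-block partials with a stable weighted average like the sketch below; this illustrates the math, it is not the stage-2 kernel itself:

```python
import numpy as np

def merge_partials(mid_o, logexpsum):
    """mid_o: [num_blocks, D] per-block normalized outputs, logexpsum: [num_blocks]."""
    w = np.exp(logexpsum - logexpsum.max())  # numerically stable per-block weights
    w /= w.sum()
    return w @ mid_o                         # [D]
```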


@@ -276,7 +286,7 @@ def gqa_token_decode_attention_flash_decoding_vsm_stage1(
*mid_o.stride(),
*mid_o_logexpsum.stride(),
BLOCK_N=run_config["BLOCK_N"],
Q_HEAD_NUM=max(run_config["BLOCK_Q_HEAD"], triton.next_power_of_2(q_head_num)),
Q_HEAD_NUM=triton.next_power_of_2(gqa_group_size),
BLOCK_DMODEL=q.shape[-1],
NUM_STAGES=run_config["stage1_num_stages"],
num_stages=run_config["stage1_num_stages"],
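
Note: the Q_HEAD_NUM change matches the kernel's per-program view: each program only handles the gqa_group_size query heads that share one KV head, so the tile only needs to cover that group padded to a power of two rather than all query heads. For example, with hypothetical head counts:

```python
import triton

q_head_num, kv_head_num = 32, 8              # hypothetical sizes
gqa_group_size = q_head_num // kv_head_num   # 4 query heads per KV head
print(triton.next_power_of_2(gqa_group_size))  # Q_HEAD_NUM = 4 instead of >= 32
```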
@@ -424,7 +434,7 @@ def gqa_token_decode_attention_flash_decoding_vsm(
out_dtype=q.dtype,
)

if not out:
if out is None:
out = alloc_tensor_func(q.shape, dtype=q.dtype, device=q.device)
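
Note: the guard change is more than style. `out` is an optional tensor, and truth-testing a tensor is not a None check; a short PyTorch sketch of the failure modes the old `if not out:` runs into:

```python
import torch

out = None
print(not out)       # True -> allocate; both versions agree here

out = torch.zeros(1)
print(not out)       # True: a single zero element is falsy, so the old check
                     # would discard a caller-provided buffer

out = torch.zeros(4)
# bool(out) raises "Boolean value of Tensor with more than one element is ambiguous",
# which is what the old check would hit for a real output tensor.
print(out is None)   # False: the explicit None check is unambiguous
```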

num_vsm = emstimate_stage1_vsm(
