From f91ef7f536d18999a43297f17f7ca651300d8705 Mon Sep 17 00:00:00 2001
From: yunqian
Date: Mon, 13 Jan 2025 12:35:25 +0000
Subject: [PATCH] small improve

---
 .../triton_kernel/gqa_flash_decoding_vsm.py | 56 +++++++++++--------
 1 file changed, 33 insertions(+), 23 deletions(-)

diff --git a/lightllm/models/llama/triton_kernel/gqa_flash_decoding_vsm.py b/lightllm/models/llama/triton_kernel/gqa_flash_decoding_vsm.py
index dbbe0cc17..e35885b7c 100644
--- a/lightllm/models/llama/triton_kernel/gqa_flash_decoding_vsm.py
+++ b/lightllm/models/llama/triton_kernel/gqa_flash_decoding_vsm.py
@@ -44,12 +44,12 @@ def try_to_get_best_config(
         return config
     else:
         config = {
-            "BLOCK_N": 16,
+            "BLOCK_N": 64,
             "BLOCK_Q_HEAD": 16,
             "stage1_num_warps": 4,
             "stage1_num_stages": 2,
             "stage2_num_warps": 4,
-            "stage2_num_stages": 2,
+            "stage2_num_stages": 1,
         }
         return config
 
@@ -150,38 +150,45 @@ def _kernel_gqa_token_decode_attention_flash_decoding_vsm_stage1(
     mid_o_logexpsum: [q_head_num, total_seq_block_num]
     """
     sm_id = tl.program_id(0).to(tl.int64)
-    block_size = tl.load(block_size, eviction_policy="evict_last")
+    block_size = tl.load(block_size)
     out_batch_start_index = tl.cast(0, tl.int64)
     q_head_off = tl.arange(0, Q_HEAD_NUM)
     d_off = tl.arange(0, BLOCK_DMODEL)
 
-    for cur_batch in tl.range(0, batch_size, 1):
-        cur_req_idx = tl.load(b_req_idx + cur_batch, eviction_policy="evict_last")
-        cur_seq_len = tl.load(b_seq_len + cur_batch, eviction_policy="evict_last")
+    for cur_batch in range(0, batch_size):
+        cur_req_idx = tl.load(b_req_idx + cur_batch)
+        cur_seq_len = tl.load(b_seq_len + cur_batch)
         cur_num_of_blocks = tl.cdiv(cur_seq_len, block_size)
         cur_num_of_kv_head_pairs = cur_num_of_blocks * kv_head_num
 
-        loop_sm_id = sm_id
-        while loop_sm_id < cur_num_of_kv_head_pairs:
-            cur_block_idx = loop_sm_id // kv_head_num
-            cur_kv_head_idx = loop_sm_id % kv_head_num
+        # loop_sm_id = sm_id
+        while sm_id < cur_num_of_kv_head_pairs:
+            cur_block_idx = sm_id % cur_num_of_blocks
+            cur_kv_head_idx = sm_id // cur_num_of_blocks
+            # cur_block_idx = sm_id // kv_head_num
+            # cur_kv_head_idx = sm_id % kv_head_num
 
-            cur_q_start = cur_kv_head_idx * gqa_group_size
-            cur_q_range = cur_q_start + q_head_off
+            cur_q_range = cur_kv_head_idx * gqa_group_size + q_head_off
             cur_q_mask = q_head_off < gqa_group_size
-            q_off = cur_batch * stride_q_bs + cur_q_range[:, None] * stride_q_h + d_off[None, :]
-            q_tensor = tl.load(q + q_off, mask=cur_q_mask[:, None], other=0.0)  # shape: [Q_HEAD_NUM, BLOCK_DMODEL]
             cur_kv_start = cur_block_idx * block_size
-            cur_kv_end = tl.minimum(cur_kv_start + block_size, cur_seq_len)
+
+            q_off = cur_batch * stride_q_bs + cur_q_range[:, None] * stride_q_h + d_off[None, :]
+            q_tensor = tl.load(
+                q + q_off,
+                mask=cur_q_mask[:, None],
+                other=0.0,
+            )  # shape: [Q_HEAD_NUM, BLOCK_DMODEL]
 
             sum_exp = tl.zeros([Q_HEAD_NUM], dtype=tl.float32)
             max_exp = tl.zeros([Q_HEAD_NUM], dtype=tl.float32) - float("inf")
             accumu = tl.zeros([Q_HEAD_NUM, BLOCK_DMODEL], dtype=tl.float32)
 
-            for chunk_idx in tl.range(0, tl.cdiv(cur_kv_end - cur_kv_start, BLOCK_N), 1, num_stages=NUM_STAGES):
+            cur_total_chunk = tl.cdiv(tl.minimum(cur_kv_start + block_size, cur_seq_len) - cur_kv_start, BLOCK_N)
+
+            for chunk_idx in tl.range(0, cur_total_chunk, 1, num_stages=NUM_STAGES):
                 cur_chunk_start = cur_kv_start + chunk_idx * BLOCK_N
                 cur_chunk_range = cur_chunk_start + tl.arange(0, BLOCK_N)
                 cur_chunk_mask = cur_chunk_range < cur_seq_len
@@ -196,10 +203,10 @@ def _kernel_gqa_token_decode_attention_flash_decoding_vsm_stage1(
                 k_off = (
                     cur_kv_loc[None, :] * stride_k_bs + cur_kv_head_idx * stride_k_h + d_off[:, None]
                 )  # shape: [BLOCK_DMODEL, BLOCK_N]
+                v_off = cur_kv_loc[:, None] * stride_v_bs + cur_kv_head_idx * stride_v_h + d_off[None, :]
                 k_tensor = tl.load(k + k_off, mask=cur_chunk_mask[None, :], other=0.0)
+
                 att_tensor = tl.dot(q_tensor, k_tensor)  # shape: [Q_HEAD_NUM, BLOCK_N]
-                v_off = cur_kv_loc[:, None] * stride_v_bs + cur_kv_head_idx * stride_v_h + d_off[None, :]
-                v_tensor = tl.load(v + v_off, mask=cur_chunk_mask[:, None], other=0.0)  # shape: [BLOCK_N, BLOCK_DMODEL]
                 att_tensor *= softmax_scale
                 att_tensor = tl.where(cur_chunk_mask[None, :], att_tensor, float("-inf"))
 
@@ -209,7 +216,8 @@ def _kernel_gqa_token_decode_attention_flash_decoding_vsm_stage1(
                 exp_logic = tl.exp(att_tensor - new_max[:, None])
                 log_scale = tl.exp(max_exp - new_max)
                 accumu *= log_scale[:, None]
-                accumu += tl.dot(exp_logic, v_tensor.to(accumu.dtype))
+                v_tensor = tl.load(v + v_off, mask=cur_chunk_mask[:, None], other=0.0)  # shape: [BLOCK_N, BLOCK_DMODEL]
+                accumu += tl.dot(exp_logic.to(v_tensor.dtype), v_tensor)
                 sum_exp = sum_exp * log_scale + tl.sum(exp_logic, axis=1)
                 max_exp = new_max
 
@@ -223,12 +231,14 @@ def _kernel_gqa_token_decode_attention_flash_decoding_vsm_stage1(
                 cur_q_range * stride_mid_o_logexpsum_h
                 + (out_batch_start_index + cur_block_idx) * stride_mid_o_logexpsum_seq
             )
+            max_exp = max_exp + tl.log(sum_exp)
             tl.store(
                 mid_o_logexpsum + off_mid_o_logexpsum,
-                max_exp + tl.log(sum_exp),
+                max_exp,
                 mask=cur_q_mask,
             )
-            loop_sm_id += num_sm
+            sm_id += num_sm
+        sm_id -= cur_num_of_kv_head_pairs
         out_batch_start_index += cur_num_of_blocks
@@ -276,7 +286,7 @@ def gqa_token_decode_attention_flash_decoding_vsm_stage1(
         *mid_o.stride(),
         *mid_o_logexpsum.stride(),
         BLOCK_N=run_config["BLOCK_N"],
-        Q_HEAD_NUM=max(run_config["BLOCK_Q_HEAD"], triton.next_power_of_2(q_head_num)),
+        Q_HEAD_NUM=triton.next_power_of_2(gqa_group_size),
         BLOCK_DMODEL=q.shape[-1],
         NUM_STAGES=run_config["stage1_num_stages"],
         num_stages=run_config["stage1_num_stages"],
@@ -424,7 +434,7 @@ def gqa_token_decode_attention_flash_decoding_vsm(
         out_dtype=q.dtype,
     )
 
-    if not out:
+    if out is None:
        out = alloc_tensor_func(q.shape, dtype=q.dtype, device=q.device)
 
    num_vsm = emstimate_stage1_vsm(
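
Note on the stage1 work partitioning (not part of the patch): the rewritten loop gives each program id a strided set of (kv_head, seq_block) work units per batch, picking the block with sm_id % cur_num_of_blocks and the kv head with sm_id // cur_num_of_blocks, and it carries the leftover sm_id into the next batch instead of resetting it. The host-side Python sketch below mirrors that indexing to check that every work unit is still covered exactly once; the helper name work_units_for_program and its parameters are illustrative only and are not part of lightllm.

# Host-side sketch of the new stage1 scheduling; illustrative only, not lightllm API.
def work_units_for_program(sm_id, num_sm, seq_lens, block_size, kv_head_num):
    """Yield the (batch, kv_head_idx, block_idx) units handled by one program id."""
    for cur_batch, cur_seq_len in enumerate(seq_lens):
        cur_num_of_blocks = -(-cur_seq_len // block_size)  # ceil division
        cur_num_of_kv_head_pairs = cur_num_of_blocks * kv_head_num
        while sm_id < cur_num_of_kv_head_pairs:
            cur_block_idx = sm_id % cur_num_of_blocks    # new scheme: blocks vary fastest
            cur_kv_head_idx = sm_id // cur_num_of_blocks
            yield (cur_batch, cur_kv_head_idx, cur_block_idx)
            sm_id += num_sm
        sm_id -= cur_num_of_kv_head_pairs                # carry the offset into the next batch


if __name__ == "__main__":
    # Sanity check: 4 programs together cover every (batch, kv_head, block) unit exactly once.
    num_sm, block_size, kv_head_num = 4, 64, 2
    seq_lens = [100, 300, 50]
    covered = [u for p in range(num_sm)
               for u in work_units_for_program(p, num_sm, seq_lens, block_size, kv_head_num)]
    assert len(covered) == len(set(covered)) == sum(
        -(-s // block_size) * kv_head_num for s in seq_lens
    )

Presumably the new split keeps adjacent program ids on adjacent KV blocks of the same head rather than on different heads, which favors more contiguous KV reads, but the patch itself does not state the motivation.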