
Commit

small fix
PannenetsF committed Jan 21, 2025
1 parent dceb079 commit fcaab70
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions lightllm/models/llama/triton_kernel/gqa_flash_decoding_bib.py
@@ -116,7 +116,6 @@ def _kernel_gqa_flash_decoding_bib_stage1(
     d_off = tl.arange(0, HEAD_DIM)
 
     cur_batch = tl.load(chunk2batch_tensor + chunk_idx)
-    cur_req = tl.load(b_req_idx_tensor + cur_batch)
     cur_seq_len = tl.load(b_seq_len_tensor + cur_batch)
     cur_start = tl.load(chunk2start_tensor + chunk_idx)
     cur_end = tl.minimum(cur_start + CHUNK_SIZE, cur_seq_len)
@@ -133,9 +132,10 @@ def _kernel_gqa_flash_decoding_bib_stage1(
     max_exp = tl.zeros([Q_GROUP_SIZE], dtype=tl.float32) - float("inf")
     accum = tl.zeros([Q_GROUP_SIZE, HEAD_DIM], dtype=tl.float32)
 
-    for block_idx in tl.range(0, cur_block_num, 1):
+    for block_idx in range(0, cur_block_num, 1):
         block_range = cur_start + block_idx * BLOCK_N + tl.arange(0, BLOCK_N) # shape [BLOCK_N]
         block_mask = block_range < cur_end # shape [BLOCK_N]
+        cur_req = tl.load(b_req_idx_tensor + cur_batch)
         cur_kv_loc = tl.load(
             req_to_token_idx_tensor
             + cur_req * req_to_token_idx_stride_bs
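For context, the cur_kv_loc load above gathers physical KV-cache slots through a per-request indirection table. A minimal NumPy sketch of that lookup, with illustrative names (gather_kv_locs is not part of lightllm's API):

import numpy as np

def gather_kv_locs(req_to_token_idx, cur_req, block_range, block_mask):
    # req_to_token_idx: [max_request_num, max_seq_len] table mapping a
    # request's logical token position to its physical KV-cache slot
    safe_range = np.where(block_mask, block_range, 0)  # zero out masked lanes
    return req_to_token_idx[cur_req, safe_range]       # physical slot ids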
@@ -156,10 +156,10 @@ def _kernel_gqa_flash_decoding_bib_stage1(
 
         exp_logic = tl.exp(att - new_max[:, None])
         log_scale = tl.exp(max_exp - new_max)
-        accum *= log_scale[:, None]
 
         v_off = cur_kv_loc[:, None] * v_stride_token + kv_head_idx * v_stride_h + d_off[None, :] * v_stride_d
         v = tl.load(v_tensor + v_off, mask=block_mask[:, None], other=0.0)
+        accum *= log_scale[:, None]
         accum += tl.dot(exp_logic.to(v.dtype), v)
 
         sum_exp = sum_exp * log_scale + tl.sum(exp_logic, axis=1)
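Moving accum *= log_scale[:, None] below the v load appears to be pure code motion: accum is still rescaled before the new tile is added, so the result is unchanged. A hedged NumPy sketch of the streaming-softmax update this loop performs (semantics reconstructed from the visible lines, not the full kernel):

import numpy as np

def online_softmax_step(accum, sum_exp, max_exp, att, v):
    # att: [Q_GROUP_SIZE, BLOCK_N] scores for one tile; v: [BLOCK_N, HEAD_DIM]
    new_max = np.maximum(max_exp, att.max(axis=1))
    exp_logic = np.exp(att - new_max[:, None])
    log_scale = np.exp(max_exp - new_max)         # rescales old partials
    accum = accum * log_scale[:, None] + exp_logic @ v
    sum_exp = sum_exp * log_scale + exp_logic.sum(axis=1)
    return accum, sum_exp, new_max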
@@ -168,10 +168,10 @@ def _kernel_gqa_flash_decoding_bib_stage1(
     off_mid_o = (
         chunk_idx * mid_o_stride_chunk + cur_q_range[:, None] * mid_o_stride_h + d_off[None, :] * mid_o_stride_d
     ) # shape [Q_GROUP_SIZE, HEAD_DIM]
-    tl.store(mid_o_tensor + off_mid_o, accum, mask=cur_q_mask[:, None])
     off_mid_o_logexpsum = (
         chunk_idx * mid_o_logexpsum_stride_chunk + cur_q_range * mid_o_logexpsum_stride_h
     ) # shape [Q_GROUP_SIZE, 1]
+    tl.store(mid_o_tensor + off_mid_o, accum, mask=cur_q_mask[:, None])
     tl.store(mid_o_logexpsum_tensor + off_mid_o_logexpsum, sum_exp, mask=cur_q_mask)
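These two stores emit per-chunk partials for a stage-2 reduction. As a hedged sketch of how such partials merge in standard flash decoding (assuming each chunk's running max is also recoverable, which this diff does not show):

import numpy as np

def merge_partials(accums, sums, maxes):
    # accums: [C, Q, D]; sums, maxes: [C, Q]  (C = chunks per kv head)
    m = maxes.max(axis=0)                           # global max per q head
    scale = np.exp(maxes - m)                       # [C, Q]
    num = (accums * scale[:, :, None]).sum(axis=0)  # rescaled numerators
    den = (sums * scale).sum(axis=0)                # rescaled denominators
    return num / den[:, None]                       # final attention output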


@@ -197,6 +197,7 @@ def gqa_flash_decoding_bib_stage1(
     grid size: [chunk_num, kv_head_num]
     """
     grid = (chunk_num, k.shape[1])
+    assert chunk_size >= run_config["BLOCK_N"] and chunk_size % run_config["BLOCK_N"] == 0
     _kernel_gqa_flash_decoding_bib_stage1[grid](
         q,
         k,
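The new assert guards the tiling contract: cur_end is clamped to cur_start + CHUNK_SIZE while the inner loop advances in BLOCK_N steps, so a chunk must be a positive multiple of BLOCK_N to be covered exactly. An illustrative check with made-up values:

chunk_size, block_n = 256, 64  # example values, not taken from this commit
assert chunk_size >= block_n and chunk_size % block_n == 0
blocks_per_chunk = chunk_size // block_n  # tiles the inner loop will walk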
@@ -333,7 +334,7 @@ def gqa_flash_decoding_bib(q, k, v, infer_state, out=None, alloc_tensor_func=tor
         out_dtype=q.dtype,
     )
     if not hasattr(infer_state, "bib_info"):
-        chunk_size = run_config["BLOCK_N"]
+        chunk_size = run_config.get("CHUNK_SIZE", run_config["BLOCK_N"])
 
         # TODO: impl in triton
         b_seq_len = infer_state.b_seq_len
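The "TODO: impl in triton" comment suggests the chunk bookkeeping is built on the host for now. A hypothetical sketch of what that precomputation could look like (build_chunk_info and its outputs are illustrative, mirroring the chunk2batch/chunk2start kernel arguments):

import torch

def build_chunk_info(b_seq_len: torch.Tensor, chunk_size: int):
    chunk2batch, chunk2start = [], []
    for batch_idx, seq_len in enumerate(b_seq_len.tolist()):
        for start in range(0, seq_len, chunk_size):
            chunk2batch.append(batch_idx)  # which request this chunk serves
            chunk2start.append(start)      # first token the chunk covers
    return (torch.tensor(chunk2batch, dtype=torch.int32),
            torch.tensor(chunk2start, dtype=torch.int32))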
