Fix grid size in Triton decoding kernel #2134

Merged 2 commits on Nov 23, 2024

Changes from all commits
72 changes: 34 additions & 38 deletions python/sglang/srt/layers/attention/triton_ops/decode_attention.py
@@ -50,12 +50,13 @@ def _fwd_kernel_stage1(
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    SPLIT_K: tl.constexpr,
     logit_cap: tl.constexpr,
     Lk: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
-    start_n = tl.program_id(2)
+    split_k_id = tl.program_id(2)

     reduce_dtype = Att_Out.dtype.element_ty
     cur_kv_head = cur_head // kv_group_num
@@ -65,22 +66,18 @@ def _fwd_kernel_stage1(
     cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
     cur_batch_req_idx = tl.load(B_req_idx + cur_batch)

-    cur_batch_start_index = 0
-    cur_batch_end_index = cur_batch_seq_len
-
     off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
+    q = tl.load(Q + off_q).to(reduce_dtype)

-    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
-
-    block_stard_index = start_n * BLOCK_N
-    block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)
+    kv_len_per_split = tl.cdiv(cur_batch_seq_len, SPLIT_K)
+    split_k_start = kv_len_per_split * split_k_id
+    split_k_end = tl.minimum(split_k_start + kv_len_per_split, cur_batch_seq_len)

-    for start_mark in range(0, block_mask, 1):
-        q = tl.load(Q + off_q + start_mark).to(reduce_dtype)
-        offs_n_new = cur_batch_start_index + offs_n
+    for start_n in range(split_k_start, split_k_end, BLOCK_N):
+        offs_n = start_n + tl.arange(0, BLOCK_N)
         k_loc = tl.load(
-            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
-            mask=offs_n_new < cur_batch_end_index,
+            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n,
+            mask=offs_n < split_k_end,
             other=0,
         )
         offs_buf_k = (
@@ -90,7 +87,7 @@
         )
         k = tl.load(
             K_Buffer + offs_buf_k,
-            mask=(offs_n_new[:, None] < cur_batch_end_index) & (offs_d[None, :] < Lk),
+            mask=(offs_n[:, None] < split_k_end) & (offs_d[None, :] < Lk),
             other=0.0,
         ).to(reduce_dtype)
         att_value = tl.sum(q[None, :] * k, 1)
@@ -100,7 +97,7 @@
             att_value = logit_cap * tanh(att_value / logit_cap)

         off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n)
-        tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)
+        tl.store(Att_Out + off_o, att_value, mask=offs_n < split_k_end)


 @triton.jit
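
For readers skimming the hunk above, here is a small plain-Python sketch of the split-K indexing it introduces (the helper name and example sizes are invented for illustration; `tl.cdiv` and `tl.minimum` are mirrored by ceil division and `min`): each program along the third grid axis owns one contiguous chunk of a request's KV length and walks it in BLOCK_N-sized steps.

```python
# Sketch only: mirrors the index math in the new kernel, not part of the PR.
def split_k_tiles(seq_len: int, split_k: int, block_n: int):
    kv_len_per_split = -(-seq_len // split_k)  # tl.cdiv(cur_batch_seq_len, SPLIT_K)
    tiles = []
    for split_k_id in range(split_k):  # one program per (batch, head, split)
        split_k_start = kv_len_per_split * split_k_id
        split_k_end = min(split_k_start + kv_len_per_split, seq_len)
        for start_n in range(split_k_start, split_k_end, block_n):
            # offs_n = start_n + tl.arange(0, BLOCK_N), masked by offs_n < split_k_end
            tiles.append((start_n, min(start_n + block_n, split_k_end)))
    return tiles

# Example with assumed sizes (seq_len=4000, SPLIT_K=8, BLOCK_N=32):
covered = [i for lo, hi in split_k_tiles(4000, 8, 32) for i in range(lo, hi)]
assert covered == list(range(4000))  # every KV position is visited exactly once
```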
@@ -189,11 +186,12 @@ def _decode_att_m_fwd(
     logit_cap,
 ):
     BLOCK = 32
+    SPLIT_K = 8
Member: Is this parameter applicable to various cases?

Collaborator (Author): I tested it on ShareGPT; 8 is an optimal selection.

Collaborator (Author): That may need some tuning for different situations.

Member: Ref https://fireworks.ai/blog/why-gpus-on-demand

| Prompt Lengths (Tokens) | Fireworks Latency | vLLM Latency |
| --- | --- | --- |
| Long prompt (4000 input, 200 output) | 2117 ms (at 7.5 QPS) | 2877 ms (at 0.348 QPS) |
| Medium prompt (2000 input, 100 output) | 740 ms (at 1.33 QPS) | 1509 ms (at 0.663 QPS) |
| Short prompt (128 input, 4 output) | 43.3 ms (at 22.51 QPS) | 247 ms (at 4.056 QPS) |

May we also tune the Medium prompt and Long prompt cases?

Member: BTW, I think a 4k prompt has nothing to do with "long," even though the blog defines it as such. In reality, some cases are around 30k-50k.

Collaborator (Author): Tested the throughput (req/s) for these cases; split=8 is also good.

python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache --trust-remote-code --tp 1 --attention-backend triton
python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 2000 --random-output 100 --random-range-ratio 1 --num-prompts 1000
python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 4000 --random-output 200 --random-range-ratio 1 --num-prompts 1000

| Prompt Lengths (Tokens) | Split = 4 | Split = 8 | Split = 16 | Split = 32 |
| --- | --- | --- | --- | --- |
| Long prompt (4000 input, 200 output) | 5.32 | 5.34 | 5.35 | 5.32 |
| Medium prompt (2000 input, 100 output) | 14.14 | 14.15 | 14.06 | 13.94 |

Member: The throughput looks good. How about the latency?

Collaborator (Author): Median decode latency (ms):

python3 -m sglang.bench_one_batch --batch-size 1 --input 2000 --output 100 --model meta-llama/Llama-3.1-8B-Instruct --attention-backend triton
python3 -m sglang.bench_one_batch --batch-size 1 --input 4000 --output 200 --model meta-llama/Llama-3.1-8B-Instruct --attention-backend triton

| Prompt Lengths (Tokens) | Split = 4 | Split = 8 | Split = 16 | Split = 32 |
| --- | --- | --- | --- | --- |
| Long prompt (4000 input, 200 output) | 9.78 | 9.37 | 9.02 | 8.91 |
| Medium prompt (2000 input, 100 output) | 8.33 | 8.06 | 7.94 | 7.94 |

     Lk = k_buffer.shape[-1]

     batch, head_num = B_req_idx.shape[0], q.shape[1]

-    grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK))
+    grid = (batch, head_num, SPLIT_K)
     kv_group_num = q.shape[1] // k_buffer.shape[1]

     if kv_group_num == 1:
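
The grid change above is the core of the fix: the old third grid dimension, triton.cdiv(max_len_in_batch, BLOCK), grows with the longest sequence in the batch (and, per the removed block_mask logic, programs past a request's own length did no work), while the new dimension is a constant SPLIT_K. A rough sketch with assumed batch and head counts, just to show the scaling:

```python
# Assumed example sizes, only to illustrate how the launch grid scales.
import math

batch, head_num = 64, 32
BLOCK, SPLIT_K = 32, 8

for max_len_in_batch in (2000, 4000, 32000):
    old_programs = batch * head_num * math.ceil(max_len_in_batch / BLOCK)  # old grid z-dim
    new_programs = batch * head_num * SPLIT_K                              # new grid z-dim
    print(f"max_len={max_len_in_batch:>6}: old={old_programs:>9}  new={new_programs}")
```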
@@ -221,6 +219,7 @@ def _decode_att_m_fwd(
         kv_group_num=kv_group_num,
         BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_N=BLOCK,
+        SPLIT_K=SPLIT_K,
         logit_cap=logit_cap,
         num_warps=num_warps,
         num_stages=1,
@@ -292,13 +291,14 @@ def _fwd_grouped_kernel_stage1(
     BLOCK_DPE: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_H: tl.constexpr,
+    SPLIT_K: tl.constexpr,
     logit_cap: tl.constexpr,
     Lk: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head_id = tl.program_id(1)
     cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
-    start_n = tl.program_id(2)
+    split_k_id = tl.program_id(2)

     reduce_dtype = Att_Out.dtype.element_ty

@@ -315,30 +315,27 @@
     cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
     cur_batch_req_idx = tl.load(B_req_idx + cur_batch)

-    cur_batch_start_index = 0
-    cur_batch_end_index = cur_batch_seq_len
-
     offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
+    q = tl.load(
+        Q + offs_q, mask=(mask_h[:, None]) & (offs_d[None, :] < Lk), other=0.0
+    ).to(reduce_dtype)

     if BLOCK_DPE > 0:
         offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
         off_qpe = (
             cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
         )
+        qpe = tl.load(Q + off_qpe, mask=mask_h[:, None], other=0.0).to(reduce_dtype)

-    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
-
-    block_stard_index = start_n * BLOCK_N
-    block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)
-
-    for start_mark in range(0, block_mask, 1):
-        q = tl.load(
-            Q + offs_q + start_mark, mask=(mask_h[:, None]) & (offs_d[None, :] < Lk)
-        ).to(reduce_dtype)
-        offs_n_new = cur_batch_start_index + offs_n
+    kv_len_per_split = tl.cdiv(cur_batch_seq_len, SPLIT_K)
+    split_k_start = kv_len_per_split * split_k_id
+    split_k_end = tl.minimum(split_k_start + kv_len_per_split, cur_batch_seq_len)
+
+    for start_n in range(split_k_start, split_k_end, BLOCK_N):
+        offs_n = start_n + tl.arange(0, BLOCK_N)
         k_loc = tl.load(
-            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
-            mask=offs_n_new < cur_batch_end_index,
+            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n,
+            mask=offs_n < split_k_end,
             other=0,
         )
         offs_buf_k = (
@@ -348,22 +345,19 @@
         )
         k = tl.load(
             K_Buffer + offs_buf_k,
-            mask=(offs_n_new[None, :] < cur_batch_end_index) & (offs_d[:, None] < Lk),
+            mask=(offs_n[None, :] < split_k_end) & (offs_d[:, None] < Lk),
             other=0.0,
         ).to(reduce_dtype)
         qk = tl.dot(q, k)
         if BLOCK_DPE > 0:
-            qpe = tl.load(Q + off_qpe + start_mark, mask=mask_h[:, None]).to(
-                reduce_dtype
-            )
             offs_buf_kpe = (
                 k_loc[None, :] * stride_buf_kbs
                 + cur_kv_head * stride_buf_kh
                 + offs_dpe[:, None]
             )
             kpe = tl.load(
                 K_Buffer + offs_buf_kpe,
-                mask=offs_n_new[None, :] < cur_batch_end_index,
+                mask=offs_n[None, :] < split_k_end,
                 other=0.0,
             ).to(reduce_dtype)
             qk += tl.dot(qpe, kpe)
@@ -379,7 +373,7 @@
         tl.store(
             Att_Out + offs_o,
             qk,
-            mask=mask_h[:, None] & (offs_n_new[None, :] < cur_batch_end_index),
+            mask=mask_h[:, None] & (offs_n[None, :] < split_k_end),
         )
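
As a shape note on the grouped kernel above (again a sketch, with NumPy standing in for Triton and invented sizes): a block of BLOCK_H query heads shares one KV head, so each split produces its partial scores with a single dot product.

```python
# Assumed sizes; illustrates the shapes in qk = tl.dot(q, k) above.
import numpy as np

BLOCK_H, BLOCK_DMODEL, BLOCK_N = 16, 128, 32
q = np.random.randn(BLOCK_H, BLOCK_DMODEL).astype(np.float32)  # query rows for the head block
k = np.random.randn(BLOCK_DMODEL, BLOCK_N).astype(np.float32)  # key columns for this split's tokens

qk = q @ k  # (BLOCK_H, BLOCK_N) partial attention scores for one split
assert qk.shape == (BLOCK_H, BLOCK_N)
```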


@@ -497,10 +491,11 @@ def _decode_grouped_att_m_fwd(
     kv_group_num = q.shape[1] // k_buffer.shape[1]

     BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num)))
+    SPLIT_K = 8
     grid = (
         batch,
         triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
-        triton.cdiv(max_len_in_batch, BLOCK),
+        SPLIT_K,
     )

     num_warps = 4
@@ -532,6 +527,7 @@
         BLOCK_DPE=BLOCK_DPE,
         BLOCK_N=BLOCK,
         BLOCK_H=BLOCK_H,
+        SPLIT_K=SPLIT_K,
         logit_cap=logit_cap,
         num_warps=num_warps,
         num_stages=1,