add cuda int4kv copy kernel (#369)
Co-authored-by: wangzaijun <[email protected]>
hiworldwzj and wangzaijun authored Mar 21, 2024
1 parent 9fed5e9 commit c5d6794
Showing 1 changed file with 34 additions and 30 deletions.
64 changes: 34 additions & 30 deletions lightllm/models/llama/triton_kernel/ppl_int4kv_copy_kv.py
@@ -73,14 +73,14 @@ def _fwd_kernel_destindex_copy_quantize_int4_kv(

 @torch.no_grad()
 def destindex_copy_int4kv(K, DestLoc, Out, Out_scale):
-    seq_len = DestLoc.shape[0]
-    head_num = K.shape[1]
+    # seq_len = DestLoc.shape[0]
+    # head_num = K.shape[1]
     head_dim = K.shape[2]
     quant_group_dim = 8

     assert head_dim % quant_group_dim == 0, "head_dim must be divisible by quant_group_dim to copy quantized kv"
-    grid = (seq_len, head_num)
-    num_warps = 1
+    # grid = (seq_len, head_num)
+    # num_warps = 1

     group_size = head_dim // quant_group_dim
     group_dim = quant_group_dim
@@ -90,39 +90,43 @@ def destindex_copy_int4kv(K, DestLoc, Out, Out_scale):
         Out.shape[0], Out.shape[1], group_size, group_dim // 2
     )  # Out is int8: two int4 values are packed into one int8, so the last dim is group_dim // 2

-    _fwd_kernel_destindex_copy_quantize_int4_kv[grid](
-        K,
-        DestLoc,
-        Out,
-        Out_scale,
-        K.stride(0),
-        K.stride(1),
-        K.stride(2),
-        K.stride(3),
-        Out.stride(0),
-        Out.stride(1),
-        Out.stride(2),
-        Out.stride(3),
-        Out_scale.stride(0),
-        Out_scale.stride(1),
-        Out_scale.stride(2),
-        group_size,
-        BLOCK_GROUP_NUM=triton.next_power_of_2(group_size),
-        BLOCK_GROUP_DIM=group_dim,
-        num_warps=num_warps,
-        num_stages=1,
-    )
+    from lightllm_ppl_int4kv_flashdecoding_kernel import group8_copy_int4_kv
+
+    group8_copy_int4_kv(Out, Out_scale, K, DestLoc, 4)
+
+    # _fwd_kernel_destindex_copy_quantize_int4_kv[grid](
+    #     K,
+    #     DestLoc,
+    #     Out,
+    #     Out_scale,
+    #     K.stride(0),
+    #     K.stride(1),
+    #     K.stride(2),
+    #     K.stride(3),
+    #     Out.stride(0),
+    #     Out.stride(1),
+    #     Out.stride(2),
+    #     Out.stride(3),
+    #     Out_scale.stride(0),
+    #     Out_scale.stride(1),
+    #     Out_scale.stride(2),
+    #     group_size,
+    #     BLOCK_GROUP_NUM=triton.next_power_of_2(group_size),
+    #     BLOCK_GROUP_DIM=group_dim,
+    #     num_warps=num_warps,
+    #     num_stages=1,
+    # )
     return


 def test2():
     import time

-    src = torch.randn((1, 1, 16), dtype=torch.float16).cuda()
-    src[0, 0, :] = torch.tensor([-2, 1, 2, 0, 4, 5, 6, 7, -2, 1, 2, 0, 4, 5, 6, 7]).cuda()
+    src = torch.randn((1, 1, 8), dtype=torch.float16).cuda()
+    src[0, 0, :] = torch.tensor([1, -2, 2, 0, 4, 5, 6, 7]).cuda()
     dest_loc = torch.arange(0, 1, dtype=torch.int32).cuda()
-    value_dest = torch.randn((1, 1, 8), dtype=torch.float16).cuda().to(torch.int8)
-    scale_dest = torch.randn((1, 1, 2), dtype=torch.float16).cuda()
+    value_dest = torch.randn((1, 1, 4), dtype=torch.float16).cuda().to(torch.int8)
+    scale_dest = torch.randn((1, 1, 1), dtype=torch.float16).cuda()

     destindex_copy_int4kv(src, dest_loc, value_dest, scale_dest)

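For readers new to this layout, here is a minimal pure-PyTorch sketch of the group-wise int4 quantization and packing that the copy kernels above are assumed to implement: each group of quant_group_dim = 8 fp16 values gets one fp16 scale, every value is quantized to int4, and two int4 values are packed into one int8, so the packed dimension is head_dim // 2. The max-abs/7 scaling and the low-nibble-first packing order are illustrative assumptions rather than details read out of group8_copy_int4_kv, and destindex_copy_int4kv_reference is a hypothetical helper, not part of lightllm.

import torch

def destindex_copy_int4kv_reference(K, DestLoc, Out, Out_scale, quant_group_dim=8):
    # K:         [seq_len, head_num, head_dim]                       fp16 source values
    # Out:       [cache_size, head_num, head_dim // 2]               int8, two int4 per byte
    # Out_scale: [cache_size, head_num, head_dim // quant_group_dim] fp16, one scale per group
    seq_len, head_num, head_dim = K.shape
    group_size = head_dim // quant_group_dim
    k = K.reshape(seq_len, head_num, group_size, quant_group_dim).float()
    scale = k.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5) / 7.0  # assumed max-abs/7 scaling
    q = torch.clamp(torch.round(k / scale), -8, 7).to(torch.int8)     # int4 range [-8, 7]
    nib = (q & 0x0F).to(torch.uint8)                                  # two's-complement nibbles
    packed = nib[..., 0::2] | (nib[..., 1::2] << 4)                   # assumed low-nibble-first order
    packed = packed.reshape(seq_len, head_num, head_dim // 2)
    Out[DestLoc.long()] = packed.view(torch.int8)                     # reinterpret packed bytes as int8
    Out_scale[DestLoc.long()] = scale.to(Out_scale.dtype).reshape(seq_len, head_num, group_size)

With the tensors from test2 above (head_dim = 8, so one group per head), this fills value_dest of shape (1, 1, 4) in int8 and scale_dest of shape (1, 1, 1) in fp16, matching what the test allocates; dequantization multiplies each unpacked int4 back by its group's scale.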
