add mode ppl_int4kv_flashdecoding (#367)
hiworldwzj authored Mar 20, 2024
1 parent 986f93d commit 859a48b
Showing 6 changed files with 278 additions and 19 deletions.
4 changes: 4 additions & 0 deletions lightllm/common/mem_utils.py
@@ -1,6 +1,7 @@
from lightllm.common.mem_manager import MemoryManager
from lightllm.common.int8kv_mem_manager import INT8KVMemoryManager
from lightllm.common.ppl_int8kv_mem_manager import PPLINT8KVMemoryManager
from lightllm.common.ppl_int4kv_mem_manager import PPLINT4KVMemoryManager
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)
@@ -11,6 +12,9 @@ def select_mem_manager_class(mode):
if "ppl_int8kv" in mode or "ppl_int8kv_flashdecoding" in mode:
memory_manager_class = PPLINT8KVMemoryManager
logger.info(f"Model kv cache using mode {mode}")
elif "ppl_int4kv_flashdecoding" in mode:
memory_manager_class = PPLINT4KVMemoryManager
logger.info(f"Model kv cache using mode {mode}")
elif "triton_int8kv" in mode:
memory_manager_class = INT8KVMemoryManager
logger.info("Model kv cache using mode triton int8kv")
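
The dispatch above keys on the mode strings passed at startup; the new branch maps `ppl_int4kv_flashdecoding` to `PPLINT4KVMemoryManager`. A minimal usage sketch, assuming lightllm is installed, that `mode` is a list of mode strings (as the membership checks suggest), and that the collapsed tail of the function returns the selected class:

```python
# Hypothetical usage sketch -- not part of this commit.
from lightllm.common.mem_utils import select_mem_manager_class

mem_manager_class = select_mem_manager_class(["ppl_int4kv_flashdecoding"])
print(mem_manager_class.__name__)  # expected: PPLINT4KVMemoryManager
```
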
22 changes: 22 additions & 0 deletions lightllm/common/ppl_int4kv_mem_manager.py
@@ -0,0 +1,22 @@
import torch

from .mem_manager import MemoryManager


class PPLINT4KVMemoryManager(MemoryManager):
def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=True):
super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy=True)

def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
group_quant_size = 8
self.kv_buffer = [
torch.empty((size, 2 * head_num, head_dim // 2), dtype=torch.int8, device="cuda") for _ in range(layer_num)
]
self.scale_buffer = [
torch.empty((size, 2 * head_num, head_dim // group_quant_size), dtype=dtype, device="cuda")
for _ in range(layer_num)
]

def _free_buffers(self):
self.kv_buffer = None
self.scale_buffer = None
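
In this layout each K or V head stores two 4-bit values per int8 byte (`head_dim // 2` bytes) plus one fp16 scale per group of 8 elements. A back-of-the-envelope sketch of the per-token, per-layer footprint; the head counts are assumed example values, not taken from any particular model:

```python
# Illustrative per-token, per-layer KV-cache footprint; head_num/head_dim are
# assumed example values, not from the commit.
head_num, head_dim, group_quant_size = 32, 128, 8

int4_kv_bytes = 2 * head_num * (head_dim // 2) * 1                 # packed int4 pairs stored as int8
scale_bytes   = 2 * head_num * (head_dim // group_quant_size) * 2  # fp16 scales
fp16_kv_bytes = 2 * head_num * head_dim * 2                        # unquantized fp16 cache for comparison

print(int4_kv_bytes + scale_bytes, fp16_kv_bytes)  # 6144 vs 16384 bytes
```
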
30 changes: 30 additions & 0 deletions lightllm/models/llama/layer_infer/transformer_layer_infer.py
@@ -63,6 +63,11 @@ def _bind_attention(self):
LlamaTransformerLayerInfer._token_decode_attention_ppl_int8kv_flashdecoding, self
)
self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_ppl_int8kv, self)
elif "ppl_int4kv_flashdecoding" in self.mode:
self._token_attention_kernel = partial(
LlamaTransformerLayerInfer._token_decode_attention_ppl_int4kv_flashdecoding, self
)
self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_ppl_int4kv, self)
elif "ppl_fp16" in self.mode:
self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_ppl_fp16, self)
self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
@@ -298,6 +303,14 @@ def _copy_kv_to_mem_cache_ppl_int8kv(self, buffer, mem_index, mem_manager):
)
return

def _copy_kv_to_mem_cache_ppl_int4kv(self, buffer, mem_index, mem_manager):
from lightllm.models.llama.triton_kernel.ppl_int4kv_copy_kv import destindex_copy_int4kv

destindex_copy_int4kv(
buffer, mem_index, mem_manager.kv_buffer[self.layer_num_], mem_manager.scale_buffer[self.layer_num_]
)
return

def _token_decode_attention_normal(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None):
total_token_num = infer_state.total_token_num
batch_size = infer_state.batch_size
@@ -524,3 +537,20 @@ def _token_decode_attention_ppl_int8kv_flashdecoding(
return token_decode_attention_flash_decoding(
q, infer_state, self.tp_q_head_num_, self.head_dim_, cache_k, cache_k_scale, cache_v, cache_v_scale, out=out
)

def _token_decode_attention_ppl_int4kv_flashdecoding(
self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None
):
from lightllm.models.llama.triton_kernel.ppl_int4kv_flash_decoding import token_decode_attention_flash_decoding

cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :]
cache_k_scale = infer_state.mem_manager.scale_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :]
cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][
:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
]
cache_v_scale = infer_state.mem_manager.scale_buffer[self.layer_num_][
:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
]
return token_decode_attention_flash_decoding(
q, infer_state, self.tp_q_head_num_, self.head_dim_, cache_k, cache_k_scale, cache_v, cache_v_scale, out=out
)
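
The decode path slices the fused buffers along the head axis: the first `tp_k_head_num_` heads hold K and the next `tp_v_head_num_` heads hold V, for both the packed int4 data and the scales. A shape-only sketch of that slicing, with arbitrary illustrative sizes:

```python
# Shape-only illustration of the K/V split used above; sizes are arbitrary.
import torch

size, k_heads, v_heads, head_dim = 1024, 8, 8, 128
kv_buffer = torch.empty((size, k_heads + v_heads, head_dim // 2), dtype=torch.int8)

cache_k = kv_buffer[:, 0:k_heads, :]                  # first k_heads heads -> K
cache_v = kv_buffer[:, k_heads:k_heads + v_heads, :]  # remaining heads -> V
assert cache_k.shape == (size, k_heads, head_dim // 2)
assert cache_v.shape == (size, v_heads, head_dim // 2)
```
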
134 changes: 134 additions & 0 deletions lightllm/models/llama/triton_kernel/ppl_int4kv_copy_kv.py
@@ -0,0 +1,134 @@
import torch

import triton
import triton.language as tl


@triton.jit
def _fwd_kernel_destindex_copy_quantize_int4_kv(
K,
Dest_loc,
Out,
Out_scale,
stride_k_bs,
stride_k_h,
stride_k_g,
stride_k_d,
stride_o_bs,
stride_o_h,
stride_o_g,
stride_o_d,
stride_os_bs,
stride_os_h,
stride_os_g,
group_size,
BLOCK_GROUP_NUM: tl.constexpr,
BLOCK_GROUP_DIM: tl.constexpr,
):
cur_index = tl.program_id(0)
cur_head = tl.program_id(1)

offs_g = tl.arange(0, BLOCK_GROUP_NUM)
offs_d = tl.arange(0, BLOCK_GROUP_DIM // 2)

dest_index = tl.load(Dest_loc + cur_index)

src_data_0 = tl.load(
K + cur_index * stride_k_bs + cur_head * stride_k_h + offs_g[:, None] * stride_k_g + offs_d[None, :] * 2,
mask=offs_g[:, None] < group_size,
other=0.0,
)
src_data_1 = tl.load(
K + cur_index * stride_k_bs + cur_head * stride_k_h + offs_g[:, None] * stride_k_g + offs_d[None, :] * 2 + 1,
mask=offs_g[:, None] < group_size,
other=0.0,
)

abs_data_0 = tl.abs(src_data_0)
abs_data_1 = tl.abs(src_data_1)

data_scale = (tl.maximum(tl.max(abs_data_0, axis=1), tl.max(abs_data_1, axis=1)) / 7.0).to(tl.float16)
q_src_data_0 = (src_data_0 / data_scale[:, None]).to(tl.int8)
q_src_data_0 = tl.where(q_src_data_0 > 7, 7, q_src_data_0)
q_src_data_0 = tl.where(q_src_data_0 < -7, -7, q_src_data_0)

q_src_data_1 = (src_data_1 / data_scale[:, None]).to(tl.int8)
q_src_data_1 = tl.where(q_src_data_1 > 7, 7, q_src_data_1)
q_src_data_1 = tl.where(q_src_data_1 < -7, -7, q_src_data_1)

low_4 = ((q_src_data_0 & 0x80) >> 4) | (q_src_data_0 & 0xF)
high_4 = (((q_src_data_1 & 0x80) >> 4) | (q_src_data_1 & 0xF)) << 4

# tl.device_print(low_4)
# tl.device_print(high_4)

out_data = low_4 | high_4

o_ptrs = Out + dest_index * stride_o_bs + cur_head * stride_o_h + offs_g[:, None] * stride_o_g + offs_d[None, :]
os_ptrs = Out_scale + dest_index * stride_os_bs + cur_head * stride_os_h + offs_g
tl.store(o_ptrs, out_data, mask=offs_g[:, None] < group_size)
tl.store(os_ptrs, data_scale, mask=offs_g < group_size)
return


@torch.no_grad()
def destindex_copy_int4kv(K, DestLoc, Out, Out_scale):
seq_len = DestLoc.shape[0]
head_num = K.shape[1]
head_dim = K.shape[2]
quant_group_dim = 8

assert head_dim % quant_group_dim == 0, "head_dim must be divisible by the int4 quant group size (8)"
grid = (seq_len, head_num)
num_warps = 1

group_size = head_dim // quant_group_dim
group_dim = quant_group_dim

K = K.view((K.shape[0], K.shape[1], group_size, group_dim))
Out = Out.view(
Out.shape[0], Out.shape[1], group_size, group_dim // 2
) # Out is int8: two int4 values are packed into one int8, hence group_dim // 2

_fwd_kernel_destindex_copy_quantize_int4_kv[grid](
K,
DestLoc,
Out,
Out_scale,
K.stride(0),
K.stride(1),
K.stride(2),
K.stride(3),
Out.stride(0),
Out.stride(1),
Out.stride(2),
Out.stride(3),
Out_scale.stride(0),
Out_scale.stride(1),
Out_scale.stride(2),
group_size,
BLOCK_GROUP_NUM=triton.next_power_of_2(group_size),
BLOCK_GROUP_DIM=group_dim,
num_warps=num_warps,
num_stages=1,
)
return


def test2():
import time

src = torch.randn((1, 1, 16), dtype=torch.float16).cuda()
src[0, 0, :] = torch.tensor([-2, 1, 2, 0, 4, 5, 6, 7, -2, 1, 2, 0, 4, 5, 6, 7]).cuda()
dest_loc = torch.arange(0, 1, dtype=torch.int32).cuda()
value_dest = torch.randn((1, 1, 8), dtype=torch.float16).cuda().to(torch.int8)
scale_dest = torch.randn((1, 1, 2), dtype=torch.float16).cuda()

destindex_copy_int4kv(src, dest_loc, value_dest, scale_dest)

print(value_dest)
print(scale_dest)


if __name__ == "__main__":
test2()
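
For reference, the same quantize-and-pack step can be written in plain PyTorch. This is an illustrative sanity check against the Triton kernel above, not part of the commit; small numeric differences are possible where the kernel computes the scale in fp16.

```python
# Pure-PyTorch reference of the per-group abs-max int4 quantization and nibble
# packing above (illustrative only, not part of the commit).
import torch


def ref_quantize_pack_int4(K, group_dim=8):
    """K: (seq, heads, head_dim) fp16 -> packed int8 of shape (seq, heads, head_dim // 2) plus fp16 scales."""
    seq, heads, head_dim = K.shape
    groups = head_dim // group_dim
    k = K.view(seq, heads, groups, group_dim).float()
    scale = k.abs().amax(dim=-1, keepdim=True) / 7.0             # per-group abs-max / 7
    q = (k / scale).to(torch.int8).clamp(-7, 7).to(torch.int32)  # truncate toward zero, clamp like the kernel
    even, odd = q[..., 0::2], q[..., 1::2]                       # element pairs sharing one output byte
    low = ((even & 0x80) >> 4) | (even & 0xF)                    # low nibble from even elements
    high = (((odd & 0x80) >> 4) | (odd & 0xF)) << 4              # high nibble from odd elements
    packed = (low | high).to(torch.int8).reshape(seq, heads, head_dim // 2)
    return packed, scale.squeeze(-1).to(torch.float16)


if __name__ == "__main__":
    src = torch.randn((1, 1, 16), dtype=torch.float16)
    print(ref_quantize_pack_int4(src))
```
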
44 changes: 44 additions & 0 deletions lightllm/models/llama/triton_kernel/ppl_int4kv_flash_decoding.py
@@ -0,0 +1,44 @@
import torch


def token_decode_attention_flash_decoding(
q, infer_state, q_head_num, head_dim, cache_k, cache_k_scale, cache_v, cache_v_scale, out=None
):
BLOCK_SEQ = 256
batch_size = infer_state.batch_size
max_len_in_batch = infer_state.max_len_in_batch
calcu_shape1 = (batch_size, q_head_num, head_dim)

from lightllm_ppl_int4kv_flashdecoding_kernel import group8_int4kv_flashdecoding_stage1
from .flash_decoding_stage2 import flash_decode_stage2

o_tensor = torch.empty_like(q) if out is None else out

if getattr(infer_state, "mid_o", None) is None:
infer_state.mid_o = torch.empty(
[batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float16, device="cuda"
)
infer_state.mid_o_logexpsum = torch.empty(
[batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float16, device="cuda"
)

mid_o = infer_state.mid_o
mid_o_logexpsum = infer_state.mid_o_logexpsum
group8_int4kv_flashdecoding_stage1(
BLOCK_SEQ,
mid_o,
mid_o_logexpsum,
1.0 / (head_dim ** 0.5),
q.view(calcu_shape1),
cache_k,
cache_k_scale,
cache_v,
cache_v_scale,
infer_state.req_manager.req_to_token_indexs,
infer_state.b_req_idx,
infer_state.b_seq_len,
infer_state.max_len_in_batch,
)

flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.b_seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ)
return o_tensor
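
The wrapper follows the usual two-stage flash-decoding scheme: stage 1 (the `group8_int4kv_flashdecoding_stage1` kernel) produces a partial attention output and a log-sum-exp for each `BLOCK_SEQ`-sized chunk of the sequence, and `flash_decode_stage2` reduces the chunks. A generic sketch of that reduction, for illustration only; the real stage-2 Triton kernel lives elsewhere in lightllm and may differ in details:

```python
# Generic split-sequence softmax combine used by flash decoding; illustrative
# only, assuming mid_o holds per-chunk normalized outputs and mid_o_logexpsum
# the per-chunk log-sum-exp of the attention scores.
import torch


def combine_chunks(mid_o, mid_o_logexpsum):
    """mid_o: (batch, heads, chunks, head_dim); mid_o_logexpsum: (batch, heads, chunks)."""
    lse = mid_o_logexpsum.float()
    m = lse.max(dim=-1, keepdim=True).values   # stabilize with the max log-sum-exp
    w = torch.exp(lse - m)                     # per-chunk weights
    out = (w.unsqueeze(-1) * mid_o.float()).sum(dim=2)
    out = out / w.sum(dim=-1, keepdim=True)
    return out.to(mid_o.dtype)                 # (batch, heads, head_dim)
```
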
63 changes: 44 additions & 19 deletions lightllm/models/llama/triton_kernel/ppl_quant_copy_kv.py
@@ -6,32 +6,46 @@

@triton.jit
def _fwd_kernel_destindex_copy_quantize_kv(
K, Dest_loc, Out, Out_scale,
stride_k_bs, stride_k_h, stride_k_g, stride_k_d,
stride_o_bs, stride_o_h, stride_o_g, stride_o_d,
stride_os_bs, stride_os_h, stride_os_g,
K,
Dest_loc,
Out,
Out_scale,
stride_k_bs,
stride_k_h,
stride_k_g,
stride_k_d,
stride_o_bs,
stride_o_h,
stride_o_g,
stride_o_d,
stride_os_bs,
stride_os_h,
stride_os_g,
group_size,
BLOCK_GROUP_NUM: tl.constexpr,
BLOCK_GROUP_DIM: tl.constexpr
BLOCK_GROUP_DIM: tl.constexpr,
):
cur_index = tl.program_id(0)
cur_head = tl.program_id(1)

offs_g = tl.arange(0, BLOCK_GROUP_NUM)
offs_d = tl.arange(0, BLOCK_GROUP_DIM)

dest_index = tl.load(Dest_loc + cur_index)

src_data = tl.load(K + cur_index * stride_k_bs + cur_head * stride_k_h + offs_g[:, None] * stride_k_g + offs_d[None, :],
mask=offs_g[:, None] < group_size, other=0.0)
src_data = tl.load(
K + cur_index * stride_k_bs + cur_head * stride_k_h + offs_g[:, None] * stride_k_g + offs_d[None, :],
mask=offs_g[:, None] < group_size,
other=0.0,
)
abs_data = tl.abs(src_data)
data_scale = (tl.max(abs_data, axis=1) / 127.).to(tl.float16)
data_scale = (tl.max(abs_data, axis=1) / 127.0).to(tl.float16)
q_src_data = (src_data / data_scale[:, None]).to(tl.int8)
o_ptrs = Out + dest_index * stride_o_bs + cur_head * stride_o_h + offs_g[:, None] * stride_o_g + offs_d[None, :]

o_ptrs = Out + dest_index * stride_o_bs + cur_head * stride_o_h + offs_g[:, None] * stride_o_g + offs_d[None, :]
os_ptrs = Out_scale + dest_index * stride_os_bs + cur_head * stride_os_h + offs_g
tl.store(o_ptrs, q_src_data, mask=offs_g[:, None]<group_size)
tl.store(os_ptrs, data_scale)
tl.store(o_ptrs, q_src_data, mask=offs_g[:, None] < group_size)
tl.store(os_ptrs, data_scale, mask=offs_g < group_size)
return


@@ -53,13 +67,24 @@ def destindex_copy_quantize_kv(K, DestLoc, Out, Out_scale):
Out = Out.view(Out.shape[0], Out.shape[1], group_size, group_dim)

_fwd_kernel_destindex_copy_quantize_kv[grid](
K, DestLoc, Out, Out_scale,
K.stride(0), K.stride(1), K.stride(2), K.stride(3),
Out.stride(0), Out.stride(1), Out.stride(2), Out.stride(3),
Out_scale.stride(0), Out_scale.stride(1), Out_scale.stride(2),
K,
DestLoc,
Out,
Out_scale,
K.stride(0),
K.stride(1),
K.stride(2),
K.stride(3),
Out.stride(0),
Out.stride(1),
Out.stride(2),
Out.stride(3),
Out_scale.stride(0),
Out_scale.stride(1),
Out_scale.stride(2),
group_size,
BLOCK_GROUP_NUM=triton.next_power_of_2(group_size),
BLOCK_GROUP_DIM=group_dim,
BLOCK_GROUP_DIM=group_dim,
num_warps=num_warps,
num_stages=1,
)
@@ -93,5 +118,5 @@ def test2():
print("cos ", cos(src.flatten().to(torch.float32), (value_dest * scale_dest).flatten().to(torch.float32)))


if __name__ == '__main__':
if __name__ == "__main__":
test2()
