[Update] KV merge and FFN up/gate merge. (#320)
Make inference faster by merging the K and V cache buffers into one KV buffer and fusing the FFN up/gate projection weights.

---------

Co-authored-by: wxd000000 <[email protected]>
Co-authored-by: FlyingLaird <[email protected]>
Co-authored-by: wangzaijun <[email protected]>
5 people authored Jan 25, 2024
1 parent 4a9824b commit 00aeedc
Showing 46 changed files with 2,114 additions and 1,243 deletions.
10 changes: 5 additions & 5 deletions docs/AddNewModel_CN.md
100644 → 100755
@@ -442,11 +442,11 @@ class BloomTransformerLayerInfer(TransformerLayerInferTpl):
                              alpha=1.0, out=cache_v.view(-1, self.tp_v_head_num_ * self.head_dim_))
         return q
 
-    def _context_attention_kernel(self, q, k, v, infer_state:InferStateInfo, layer_weight: BloomTransformerLayerWeight)->torch.Tensor:
+    def _context_attention_kernel(self, q, kv, infer_state:InferStateInfo, layer_weight: BloomTransformerLayerWeight)->torch.Tensor:
         o_tensor = torch.empty_like(q)
         context_attention_fwd(q.view(-1, self.tp_q_head_num_, self.head_dim_),
-                              k.view(-1, self.tp_k_head_num_, self.head_dim_),
-                              v.view(-1, self.tp_v_head_num_, self.head_dim_),
+                              kv[:, 0: self.tp_k_head_num_, :],
+                              kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :],
                              o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_),
                              layer_weight.tp_alibi,
                              infer_state.b_start_loc,
@@ -457,8 +457,8 @@ class BloomTransformerLayerInfer(TransformerLayerInferTpl):
     def _token_attention_kernel(self, q, infer_state:InferStateInfo, layer_weight: BloomTransformerLayerWeight)->torch.Tensor:
         o_tensor = torch.empty_like(q)
         token_attention_fwd(q.view(-1, self.tp_q_head_num_, self.head_dim_),
-                            infer_state.mem_manager.key_buffer[self.layer_num_],
-                            infer_state.mem_manager.value_buffer[self.layer_num_],
+                            infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0: self.tp_k_head_num_, :],
+                            infer_state.mem_manager.kv_buffer[self.layer_num_][:, self.tp_k_head_num_: self.tp_k_head_num_ + self.tp_v_head_num_, :],
                            o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_),
                            layer_weight.tp_alibi,
                            infer_state.b_loc,
9 changes: 3 additions & 6 deletions lightllm/common/basemodel/basemodel.py
100644 → 100755
@@ -178,8 +178,7 @@ def _prefill(self, batch_size, total_token_num, max_len_in_batch, input_ids, b_r
             infer_state.mem_is_contiguous = False
             alloc_mem = self.mem_manager.alloc(infer_state.total_token_num)
             infer_state.mem_index = alloc_mem
-            infer_state.key_buffer = torch.empty((infer_state.total_token_num, self.tp_k_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
-            infer_state.value_buffer = torch.empty((infer_state.total_token_num, self.tp_v_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
+            infer_state.kv_buffer = torch.empty((infer_state.total_token_num, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
 
         init_req_to_token_indexes(self.req_manager.req_to_token_indexs, b_req_idx, b_seq_len,
                                   max_len_in_batch, infer_state.mem_index)
@@ -214,8 +213,7 @@ def _decode(self, batch_size, total_token_num, max_len_in_batch, input_ids, b_re
             infer_state.mem_is_contiguous = False
             alloc_mem = self.mem_manager.alloc(batch_size)
             infer_state.mem_index = alloc_mem
-            infer_state.key_buffer = torch.empty((batch_size, self.tp_k_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
-            infer_state.value_buffer = torch.empty((batch_size, self.tp_v_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
+            infer_state.kv_buffer = torch.empty((batch_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
         copy_kv_index_to_req(self.req_manager.req_to_token_indexs, b_req_idx, b_seq_len, infer_state.mem_index)
 
         infer_state.init_some_extra_state(self, input_ids)
@@ -272,8 +270,7 @@ def splitfuse_forward(
             infer_state.mem_is_contiguous = False
             alloc_mem = self.mem_manager.alloc(alloc_size)
             infer_state.mem_index = alloc_mem
-            infer_state.key_buffer = torch.empty((alloc_size, self.tp_k_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
-            infer_state.value_buffer = torch.empty((alloc_size, self.tp_v_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
+            infer_state.kv_buffer = torch.empty((alloc_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
 
         # decode part
         if decode_req_num != 0:
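The three hunks above follow the same pattern: the two separate K and V allocations become one buffer whose head dimension holds the K heads first and the V heads after them. A minimal sketch of the idea, with toy shapes and CPU tensors so it runs anywhere (not LightLLM code):

```python
import torch

# Toy sizes standing in for total_token_num, tp_k_head_num_, tp_v_head_num_, head_dim_.
total_token_num, k_heads, v_heads, head_dim = 16, 8, 8, 128

# One allocation instead of two.
kv_buffer = torch.empty((total_token_num, k_heads + v_heads, head_dim), dtype=torch.float16)

# K and V are recovered as views into the same storage, so no copy is needed
# when an attention kernel wants them separately.
cache_k = kv_buffer[:, 0:k_heads, :]
cache_v = kv_buffer[:, k_heads:k_heads + v_heads, :]
assert cache_k.data_ptr() == kv_buffer.data_ptr()  # same memory, zero-copy slice
```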
3 changes: 1 addition & 2 deletions lightllm/common/basemodel/infer_struct.py
100644 → 100755
@@ -23,8 +23,7 @@ def __init__(self):
         self.mem_index = None
         self.mem_start = None
         self.mem_end = None
-        self.key_buffer = None
-        self.value_buffer = None
+        self.kv_buffer = None
 
         self.is_splitfuse = False
         self.return_all_prompt_logprobs = False
41 changes: 19 additions & 22 deletions lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py
100644 → 100755
@@ -31,28 +31,25 @@ def _ffn_norm(self, input, infer_state:InferStateInfo, layer_weight)->torch.Tens
 
     def _pre_cache_kv(self, infer_state:InferStateInfo, layer_weight)->Tuple[torch.Tensor, torch.Tensor]:
         if infer_state.mem_is_contiguous:
-            cache_k = infer_state.mem_manager.key_buffer[self.layer_num_][infer_state.mem_start:infer_state.mem_end, :, :]
-            cache_v = infer_state.mem_manager.value_buffer[self.layer_num_][infer_state.mem_start:infer_state.mem_end, :, :]
+            cache_kv = infer_state.mem_manager.kv_buffer[self.layer_num_][infer_state.mem_start:infer_state.mem_end, :, :]
         else:
-            cache_k = infer_state.key_buffer
-            cache_v = infer_state.value_buffer
-        return cache_k, cache_v
+            cache_kv = infer_state.kv_buffer
+        return cache_kv
 
-    def _get_qkv(self, input, cache_k, cache_v, infer_state:InferStateInfo, layer_weight)->Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def _get_qkv(self, input, cache_kv, infer_state:InferStateInfo, layer_weight)->Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         raise Exception("need to impl")
 
-    def _post_cache_kv(self, cache_k, cache_v, infer_state:InferStateInfo, layer_weight):
+    def _post_cache_kv(self, cache_kv, infer_state:InferStateInfo, layer_weight):
         mem_manager = infer_state.mem_manager
         if not infer_state.mem_is_contiguous:
-            self._copy_kv_to_mem_cache(cache_k, cache_v, infer_state.mem_index, mem_manager)
+            self._copy_kv_to_mem_cache(cache_kv, infer_state.mem_index, mem_manager)
         return
 
-    def _copy_kv_to_mem_cache(self, key_buffer, value_buffer, mem_index, mem_manager):
-        destindex_copy_kv(key_buffer, mem_index, mem_manager.key_buffer[self.layer_num_])
-        destindex_copy_kv(value_buffer, mem_index, mem_manager.value_buffer[self.layer_num_])
+    def _copy_kv_to_mem_cache(self, buffer, mem_index, mem_manager):
+        destindex_copy_kv(buffer, mem_index, mem_manager.kv_buffer[self.layer_num_])
         return
 
-    def _context_attention_kernel(self, q, k, v, infer_state:InferStateInfo, layer_weight, out=None)->torch.Tensor:
+    def _context_attention_kernel(self, q, kv, infer_state:InferStateInfo, layer_weight, out=None)->torch.Tensor:
         raise Exception("need to impl")
 
     def _token_attention_kernel(self, q, infer_state:InferStateInfo, layer_weight, out=None)->torch.Tensor:
@@ -71,11 +68,11 @@ def _ffn(self, input, infer_state:InferStateInfo, layer_weight)->torch.Tensor:
     @mark_cost_time("trans context flash forward time cost")  # dont to remove this, will make performence down, did not know why
     def _context_attention(self, input_embding, infer_state: InferStateInfo, layer_weight):
         input1 = self._att_norm(input_embding, infer_state, layer_weight)
-        cache_k, cache_v = self._pre_cache_kv(infer_state, layer_weight)
-        q, cache_k, cache_v = self._get_qkv(input1, cache_k, cache_v, infer_state, layer_weight)
+        cache_kv = self._pre_cache_kv(infer_state, layer_weight)
+        q, cache_kv = self._get_qkv(input1, cache_kv, infer_state, layer_weight)
         input1 = None
-        self._post_cache_kv(cache_k, cache_v, infer_state, layer_weight)
-        o = self._context_attention_kernel(q, cache_k, cache_v, infer_state, layer_weight)
+        self._post_cache_kv(cache_kv, infer_state, layer_weight)
+        o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight)
         q = None
         o = self._get_o(o, infer_state, layer_weight)
         if self.world_size_ > 1:
@@ -96,10 +93,10 @@ def _context_ffn(self, input_embdings, infer_state: InferStateInfo, layer_weight
     # this impl dont to use @mark_cost_time
     def _token_attention(self, input_embding, infer_state: InferStateInfo, layer_weight):
         input1 = self._att_norm(input_embding, infer_state, layer_weight)
-        cache_k, cache_v = self._pre_cache_kv(infer_state, layer_weight)
-        q, cache_k, cache_v = self._get_qkv(input1, cache_k, cache_v, infer_state, layer_weight)
+        cache_kv = self._pre_cache_kv(infer_state, layer_weight)
+        q, cache_kv = self._get_qkv(input1, cache_kv, infer_state, layer_weight)
         input1 = None
-        self._post_cache_kv(cache_k, cache_v, infer_state, layer_weight)
+        self._post_cache_kv(cache_kv, infer_state, layer_weight)
         o = self._token_attention_kernel(q, infer_state, layer_weight)
         q = None
         o = self._get_o(o, infer_state, layer_weight)
@@ -121,10 +118,10 @@ def _token_ffn(self, input_embdings, infer_state: InferStateInfo, layer_weight):
     # @mark_cost_time("trans context flash forward time cost")  # dont to remove this, will make performence down, did not know why
     def _splitfuse_attention(self, input_embding, infer_state: SplitFuseInferStateInfo, layer_weight):
         input1 = self._att_norm(input_embding, infer_state, layer_weight)
-        cache_k, cache_v = self._pre_cache_kv(infer_state, layer_weight)
-        q, cache_k, cache_v = self._get_qkv(input1, cache_k, cache_v, infer_state, layer_weight)
+        cache_kv = self._pre_cache_kv(infer_state, layer_weight)
+        q, cache_kv = self._get_qkv(input1, cache_kv, infer_state, layer_weight)
         input1 = None
-        self._post_cache_kv(cache_k, cache_v, infer_state, layer_weight)
+        self._post_cache_kv(cache_kv, infer_state, layer_weight)
         o = self._splitfuse_attention_kernel(q, infer_state, layer_weight)
         q = None
         o = self._get_o(o, infer_state, layer_weight)
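In the template, _copy_kv_to_mem_cache now issues a single destindex_copy_kv call over the merged buffer instead of one call each for K and V. As a rough mental model, plain PyTorch indexing standing in for the Triton kernel (illustrative shapes, not LightLLM code):

```python
import torch

# Per-layer pool from the memory manager: (pool_size, 2 * head_num, head_dim).
layer_kv_pool = torch.zeros((1024, 16, 128), dtype=torch.float16)

# Freshly computed K/V for 4 new tokens, already laid out as [K heads | V heads].
cache_kv = torch.randn((4, 16, 128)).to(torch.float16)

# Slots handed out by the allocator (infer_state.mem_index).
mem_index = torch.tensor([10, 11, 12, 13])

# Rough equivalent of destindex_copy_kv(cache_kv, mem_index, layer_kv_pool):
# one scatter now stores both K and V for every new token.
layer_kv_pool[mem_index] = cache_kv
```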
@@ -36,8 +36,7 @@ def _pre_cache_kv(self, infer_state:InferStateInfo, layer_weight)->Tuple[torch.T
         '''
         Release kv buffer to save memory, since we allocate while kv projection.
         '''
-        infer_state.key_buffer = None
-        infer_state.value_buffer = None
-        return None, None
+        infer_state.kv_buffer = None
+        return None
 
 
@@ -31,6 +31,5 @@ def _pre_cache_kv(self, infer_state:InferStateInfo, layer_weight)->Tuple[torch.T
         '''
         Release kv buffer to save memory, since we allocate while kv projection.
         '''
-        infer_state.key_buffer = None
-        infer_state.value_buffer = None
-        return None, None
+        infer_state.kv_buffer = None
+        return None
20 changes: 20 additions & 0 deletions lightllm/common/basemodel/layer_weights/base_layer_weight.py
@@ -1,10 +1,12 @@
 import torch
 import numpy as np
+import threading
 
 
 class BaseLayerWeight:
     def __init__(self):
         self.tp_rank_ = None
+        self.lock = threading.Lock()
 
     def load_hf_weights(self, weights):
         """
@@ -30,3 +32,21 @@ def _cuda(self, cpu_tensor):
             return cpu_tensor.contiguous().to(self.data_type_).cuda()
         else:
             return cpu_tensor.contiguous().to(self.data_type_).cuda(self.tp_rank_)
+
+    def _try_cat_to(self, source_tensor_names, dest_name, cat_dim, handle_func=None):
+        if all(hasattr(self, src_name) for src_name in source_tensor_names) and not hasattr(self, dest_name):
+            with self.lock:
+                if all(hasattr(self, src_name) for src_name in source_tensor_names) and not hasattr(self, dest_name):
+                    assert all(
+                        not getattr(self, name, None).is_cuda for name in source_tensor_names
+                    ), "all not cuda tensor"
+                    tensors = [getattr(self, name, None) for name in source_tensor_names]
+                    ans = torch.cat(tensors, dim=cat_dim)
+                    if handle_func is not None:
+                        ans = handle_func(ans)
+                    else:
+                        ans = self._cuda(ans)
+                    setattr(self, dest_name, ans)
+                    for name in source_tensor_names:
+                        delattr(self, name)
+        return
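The new _try_cat_to helper enables the weight-side half of this commit: once two projection matrices have both been loaded, they can be concatenated into one fused matrix and the originals dropped, so a single GEMM replaces two. A hedged usage sketch: the subclass, its attribute names (gate_proj, up_proj, gate_up_proj), and the shapes are illustrative only, and the default path assumes a CUDA device because it routes through self._cuda.

```python
import torch
from lightllm.common.basemodel.layer_weights.base_layer_weight import BaseLayerWeight


class ToyFfnWeight(BaseLayerWeight):
    def __init__(self):
        super().__init__()                     # creates self.lock used by _try_cat_to
        self.data_type_ = torch.float16
        # CPU tensors, as the "all not cuda tensor" assert requires.
        self.gate_proj = torch.randn(4096, 11008).to(torch.float16)
        self.up_proj = torch.randn(4096, 11008).to(torch.float16)


w = ToyFfnWeight()
w._try_cat_to(["gate_proj", "up_proj"], "gate_up_proj", cat_dim=1)

# Afterwards: w.gate_up_proj is a (4096, 22016) CUDA tensor, and the separate
# gate_proj / up_proj attributes have been deleted to free host memory.
```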
@@ -3,11 +3,12 @@
 
 class TransformerLayerWeight(BaseLayerWeight):
     def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode):
+        super().__init__()
         self.layer_num_ = layer_num
         self.tp_rank_ = tp_rank
         self.world_size_ = world_size
         self.data_type_ = data_type
         self.network_config_ = network_config
         self.mode = mode
         self.init_static_params()
-        return
+        return
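The only functional change here is the added super().__init__() call; it matters because BaseLayerWeight's constructor now creates the threading.Lock that _try_cat_to acquires. A trimmed sketch of that dependency (not the full classes):

```python
import threading


class BaseLayerWeight:                  # trimmed to the lines relevant here
    def __init__(self):
        self.tp_rank_ = None
        self.lock = threading.Lock()    # new in this commit


class TransformerLayerWeight(BaseLayerWeight):
    def __init__(self):
        # Without this call, self.lock never exists and any later
        # _try_cat_to invocation would raise AttributeError.
        super().__init__()
```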
3 changes: 1 addition & 2 deletions lightllm/common/basemodel/splitfuse_infer_struct.py
100644 → 100755
@@ -34,8 +34,7 @@ def __init__(self):
         self.mem_start = None
         self.mem_end = None
         self.mem_index = None
-        self.key_buffer = None
-        self.value_buffer = None
+        self.kv_buffer = None
 
         self.parrall_stream = torch.cuda.Stream()
         self.start_event = torch.cuda.Event()
12 changes: 4 additions & 8 deletions lightllm/common/int8kv_mem_manager.py
100644 → 100755
@@ -8,14 +8,10 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=True)
         super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy=True)
 
     def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
-        self.key_buffer = [torch.empty((size, head_num, head_dim), dtype=torch.int8, device="cuda") for _ in range(layer_num)]
-        self.value_buffer = [torch.empty((size, head_num, head_dim), dtype=torch.int8, device="cuda") for _ in range(layer_num)]
-        self.key_scale_buffer = [torch.empty((size, head_num, 1), dtype=dtype, device="cuda") for _ in range(layer_num)]
-        self.value_scale_buffer = [torch.empty((size, head_num, 1), dtype=dtype, device="cuda") for _ in range(layer_num)]
+        self.kv_buffer = [torch.empty((size, 2 * head_num, head_dim), dtype=torch.int8, device="cuda") for _ in range(layer_num)]
+        self.scale_buffer = [torch.empty((size, 2 * head_num, 1), dtype=dtype, device="cuda") for _ in range(layer_num)]
 
     def _free_buffers(self):
-        self.key_buffer = None
-        self.value_buffer = None
-        self.key_scale_buffer = None
-        self.value_scale_buffer = None
+        self.kv_buffer = None
+        self.scale_buffer = None
 
6 changes: 2 additions & 4 deletions lightllm/common/mem_manager.py
100644 → 100755
@@ -20,12 +20,10 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False
         self._init_buffers(size, dtype, head_num, head_dim, layer_num)
 
     def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
-        self.key_buffer = [torch.empty((size, head_num, head_dim), dtype=dtype, device="cuda") for _ in range(layer_num)]
-        self.value_buffer = [torch.empty((size, head_num, head_dim), dtype=dtype, device="cuda") for _ in range(layer_num)]
+        self.kv_buffer = [torch.empty((size, 2 * head_num, head_dim), dtype=dtype, device="cuda") for _ in range(layer_num)]
 
     def _free_buffers(self):
-        self.key_buffer = None
-        self.value_buffer = None
+        self.kv_buffer = None
 
     @torch.no_grad()
     def alloc(self, need_size):
12 changes: 4 additions & 8 deletions lightllm/common/ppl_int8kv_mem_manager.py
100644 → 100755
@@ -9,13 +9,9 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=True)
 
     def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
         group_quant_size = 8
-        self.key_buffer = [torch.empty((size, head_num, head_dim), dtype=torch.int8, device="cuda") for _ in range(layer_num)]
-        self.value_buffer = [torch.empty((size, head_num, head_dim), dtype=torch.int8, device="cuda") for _ in range(layer_num)]
-        self.key_scale_buffer = [torch.empty((size, head_num, head_dim // group_quant_size), dtype=dtype, device="cuda") for _ in range(layer_num)]
-        self.value_scale_buffer = [torch.empty((size, head_num, head_dim // group_quant_size), dtype=dtype, device="cuda") for _ in range(layer_num)]
+        self.kv_buffer = [torch.empty((size, 2 * head_num, head_dim), dtype=torch.int8, device="cuda") for _ in range(layer_num)]
+        self.scale_buffer = [torch.empty((size, 2 * head_num, head_dim // group_quant_size), dtype=dtype, device="cuda") for _ in range(layer_num)]
 
     def _free_buffers(self):
-        self.key_buffer = None
-        self.value_buffer = None
-        self.key_scale_buffer = None
-        self.value_scale_buffer = None
+        self.kv_buffer = None
+        self.scale_buffer = None
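The two int8 managers above merge their quantized caches and their scale buffers the same way: one buffer now covers both K and V heads, with one scale per head (int8kv) or one per group of 8 channels (ppl_int8kv). A rough dequantization sketch for the per-head case, with illustrative shapes and not LightLLM's actual kernels:

```python
import torch

tokens, head_num, head_dim = 4, 8, 128

# Merged int8 cache and its per-token, per-head scales ([K heads | V heads]).
kv_int8 = torch.randint(-128, 128, (tokens, 2 * head_num, head_dim), dtype=torch.int8)
scale = torch.rand(tokens, 2 * head_num, 1).to(torch.float16)

# One broadcasted multiply dequantizes K and V together.
kv_fp16 = kv_int8.to(torch.float16) * scale
```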
24 changes: 12 additions & 12 deletions lightllm/models/baichuan13b/layer_infer/transformer_layer_infer.py
100644 → 100755
@@ -10,26 +10,26 @@
 
 
 class Baichuan13bTransformerLayerInfer(LlamaTransformerLayerInfer):
-    """
-    """
+    """ """
 
     def __init__(self, layer_num, tp_rank, world_size, network_config, mode):
         super().__init__(layer_num, tp_rank, world_size, network_config, mode)
         self._bind_func()
         return
 
     def _bind_func(self):
         """
-            baichuan13b only support normal mode.
+        baichuan13b only support normal mode.
         """
         self._context_attention_kernel = partial(BloomTransformerLayerInfer._context_attention_kernel, self)
         self._token_attention_kernel = partial(BloomTransformerLayerInfer._token_attention_kernel, self)
         return
-    def _get_qkv(self, input, cache_k, cache_v, infer_state, layer_weight: BaiChuan13bTransformerLayerWeight) -> torch.Tensor:
+
+    def _get_qkv(self, input, cache_kv, infer_state, layer_weight: BaiChuan13bTransformerLayerWeight) -> torch.Tensor:
         q = torch.mm(input.view(-1, self.embed_dim_), layer_weight.q_weight_)
-        torch.mm(input.view(-1, self.embed_dim_), layer_weight.k_weight_,
-                 out=cache_k.view(-1, self.tp_k_head_num_ * self.head_dim_))
-        torch.mm(input.view(-1, self.embed_dim_), layer_weight.v_weight_,
-                 out=cache_v.view(-1, self.tp_v_head_num_ * self.head_dim_))
-        return q, cache_k, cache_v
-
+        torch.mm(
+            input.view(-1, self.embed_dim_),
+            layer_weight.kv_weight_,
+            out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_),
+        )
+        return q, cache_kv
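Here the two K and V GEMMs collapse into one against layer_weight.kv_weight_, which is presumably the concatenation of the old k_weight_ and v_weight_ along the output dimension, with the result written straight into the merged cache_kv buffer. A self-contained sketch of why the fused form is equivalent (toy shapes, not Baichuan-13B's real configuration):

```python
import torch

tokens, embed_dim, k_dim, v_dim = 4, 512, 512, 512
x = torch.randn(tokens, embed_dim)
k_weight = torch.randn(embed_dim, k_dim)
v_weight = torch.randn(embed_dim, v_dim)

# Old path: two kernel launches.
k = torch.mm(x, k_weight)
v = torch.mm(x, v_weight)

# New path: one launch against the fused weight, then slices of the result.
kv_weight = torch.cat([k_weight, v_weight], dim=1)
kv = torch.mm(x, kv_weight)

assert torch.allclose(kv[:, :k_dim], k, atol=1e-4)
assert torch.allclose(kv[:, k_dim:], v, atol=1e-4)
```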