From 00aeedc349f0d233a33018948b378d88734211f9 Mon Sep 17 00:00:00 2001
From: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com>
Date: Thu, 25 Jan 2024 18:52:23 +0800
Subject: [PATCH] [Update] KV Merge And FFn up gate merge. (#320)

make inference faster.

---------

Co-authored-by: wxd000000 <874217876@qq.com>
Co-authored-by: wxd000000
Co-authored-by: FlyingLaird
Co-authored-by: wangzaijun
---
 docs/AddNewModel_CN.md | 10 +-
 lightllm/common/basemodel/basemodel.py | 9 +-
 lightllm/common/basemodel/infer_struct.py | 3 +-
 .../transformer_layer_infer_template.py | 41 +-
 ...ransformer_layer_infer_template_awquant.py | 5 +-
 ...transformer_layer_infer_template_wquant.py | 5 +-
 .../layer_weights/base_layer_weight.py | 20 +
 .../layer_weights/transformer_layer_weight.py | 3 +-
 .../basemodel/splitfuse_infer_struct.py | 3 +-
 lightllm/common/int8kv_mem_manager.py | 12 +-
 lightllm/common/mem_manager.py | 6 +-
 lightllm/common/ppl_int8kv_mem_manager.py | 12 +-
 .../layer_infer/transformer_layer_infer.py | 24 +-
 .../layer_infer/transformer_layer_infer.py | 31 +-
 .../layer_weights/transformer_layer_weight.py | 23 +-
 .../layer_infer/transformer_layer_infer.py | 127 +++--
 .../layer_weights/transformer_layer_weight.py | 144 ++---
 .../layer_infer/transformer_layer_infer.py | 36 +-
 .../layer_weights/transformer_layer_weight.py | 99 ++--
 .../chatglm2/triton_kernel/rotary_emb.py | 127 ++++-
 .../layer_infer/transformer_layer_infer.py | 46 +-
 .../layer_weights/transformer_layer_weight.py | 84 +--
 .../layer_weights/transformer_layer_weight.py | 90 ++--
 .../layer_infer/transformer_layer_infer.py | 59 ++-
 .../layer_weights/transformer_layer_weight.py | 89 ++--
 .../layer_infer/transformer_layer_infer.py | 493 +++++++++++-------
 .../layer_weights/transformer_layer_weight.py | 86 +--
 .../models/llama/triton_kernel/rotary_emb.py | 136 ++++-
 .../llama/triton_kernel/silu_and_mul.py | 98 ++++
 .../token_attention_nopad_att1.py | 136 +++--
 .../token_attention_nopad_reduceV.py | 148 ++++--
 .../layer_infer/transformer_layer_infer.py | 182 ++++---
 .../layer_weights/transformer_layer_weight.py | 65 +--
 .../layer_infer/transformer_layer_infer.py | 115 ++--
 .../layer_weights/transformer_layer_weight.py | 103 ++--
 .../layer_infer/transformer_layer_infer.py | 12 +-
 .../layer_weights/transformer_layer_weight.py | 89 ++--
 .../layer_infer/transformer_layer_infer.py | 37 +-
 .../layer_weights/transformer_layer_weight.py | 80 +--
 .../layer_infer/transformer_layer_infer.py | 46 +-
 .../layer_weights/transformer_layer_weight.py | 88 ++--
 .../layer_weights/transformer_layer_weight.py | 73 ++-
 .../layer_infer/transformer_layer_infer.py | 85 +--
 .../layer_weights/transformer_layer_weight.py | 89 ++--
 .../layer_weights/transformer_layer_weight.py | 59 ++-
 test/model/test_settings/process_utils.py | 29 +-
 46 files changed, 2114 insertions(+), 1243 deletions(-)
 mode change 100644 => 100755 docs/AddNewModel_CN.md
 mode change 100644 => 100755 lightllm/common/basemodel/basemodel.py
 mode change 100644 => 100755 lightllm/common/basemodel/infer_struct.py
 mode change 100644 => 100755 lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py
 mode change 100644 => 100755 lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template_awquant.py
 mode change 100644 => 100755 lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template_wquant.py
 mode change 100644 => 100755 lightllm/common/basemodel/splitfuse_infer_struct.py
 mode change 100644 => 100755 lightllm/common/int8kv_mem_manager.py
 mode change 100644 => 100755 lightllm/common/mem_manager.py
 mode change 100644 => 100755 lightllm/common/ppl_int8kv_mem_manager.py
 mode change 100644 => 100755 lightllm/models/baichuan13b/layer_infer/transformer_layer_infer.py
 mode change 100644 => 100755 lightllm/models/baichuan2_7b/layer_infer/transformer_layer_infer.py
 mode change 100644 => 100755 lightllm/models/bloom/layer_infer/transformer_layer_infer.py
 mode change 100644 => 100755 lightllm/models/chatglm2/layer_infer/transformer_layer_infer.py
 mode change 100644 => 100755 lightllm/models/chatglm2/layer_weights/transformer_layer_weight.py
 mode change 100644 => 100755 lightllm/models/chatglm2/triton_kernel/rotary_emb.py
 mode change 100644 => 100755 lightllm/models/internlm/layer_infer/transformer_layer_infer.py
 mode change 100644 => 100755 lightllm/models/internlm/layer_weights/transformer_layer_weight.py
 mode change 100644 => 100755 lightllm/models/internlm2/layer_weights/transformer_layer_weight.py
 mode change 100644 => 100755 lightllm/models/internlm_wquant/layer_infer/transformer_layer_infer.py
 mode change 100644 => 100755 lightllm/models/internlm_wquant/layer_weights/transformer_layer_weight.py
 mode change 100644 => 100755 lightllm/models/llama/layer_infer/transformer_layer_infer.py
 mode change 100644 => 100755 lightllm/models/llama/triton_kernel/rotary_emb.py
 create mode 100644 lightllm/models/llama/triton_kernel/silu_and_mul.py
 mode change 100644 => 100755 lightllm/models/llama_awquant/layer_infer/transformer_layer_infer.py
 mode change 100644 => 100755 lightllm/models/llama_awquant/layer_weights/transformer_layer_weight.py
 mode change 100644 => 100755 lightllm/models/llama_wquant/layer_infer/transformer_layer_infer.py
 mode change 100644 => 100755 lightllm/models/mistral/layer_infer/transformer_layer_infer.py
 mode change 100644 => 100755 lightllm/models/qwen/layer_infer/transformer_layer_infer.py
 mode change 100644 => 100755 lightllm/models/qwen/layer_weights/transformer_layer_weight.py
 mode change 100644 => 100755 lightllm/models/qwen_wquant/layer_infer/transformer_layer_infer.py
 mode change 100644 => 100755 lightllm/models/starcoder_wquant/layer_infer/transformer_layer_infer.py

diff --git a/docs/AddNewModel_CN.md b/docs/AddNewModel_CN.md
old mode 100644
new mode 100755
index 4230d3790..45798a176
--- a/docs/AddNewModel_CN.md
+++ b/docs/AddNewModel_CN.md
@@ -442,11 +442,11 @@ class BloomTransformerLayerInfer(TransformerLayerInferTpl):
                       alpha=1.0, out=cache_v.view(-1, self.tp_v_head_num_ * self.head_dim_))
         return q
 
-    def _context_attention_kernel(self, q, k, v, infer_state:InferStateInfo, layer_weight: BloomTransformerLayerWeight)->torch.Tensor:
+    def _context_attention_kernel(self, q, kv, infer_state:InferStateInfo, layer_weight: BloomTransformerLayerWeight)->torch.Tensor:
         o_tensor = torch.empty_like(q)
         context_attention_fwd(q.view(-1, self.tp_q_head_num_, self.head_dim_),
-                              k.view(-1, self.tp_k_head_num_, self.head_dim_),
-                              v.view(-1, self.tp_v_head_num_, self.head_dim_),
+                              kv[:, 0: self.tp_k_head_num_, :],
+                              kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :],
                               o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_),
                               layer_weight.tp_alibi,
                               infer_state.b_start_loc,
@@ -457,8 +457,8 @@ class BloomTransformerLayerInfer(TransformerLayerInferTpl):
     def _token_attention_kernel(self, q, infer_state:InferStateInfo, layer_weight: BloomTransformerLayerWeight)->torch.Tensor:
         o_tensor = torch.empty_like(q)
         token_attention_fwd(q.view(-1, self.tp_q_head_num_, self.head_dim_),
-                            infer_state.mem_manager.key_buffer[self.layer_num_],
-                            infer_state.mem_manager.value_buffer[self.layer_num_],
+                            infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0: self.tp_k_head_num_, :],
+                            infer_state.mem_manager.kv_buffer[self.layer_num_][:, self.tp_k_head_num_: self.tp_k_head_num_+ self.tp_v_head_num_, :],
                             o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_),
                             layer_weight.tp_alibi,
                             infer_state.b_loc,
diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py
old mode 100644
new mode 100755
index 61a04a752..02831d74b
--- a/lightllm/common/basemodel/basemodel.py
+++ b/lightllm/common/basemodel/basemodel.py
@@ -178,8 +178,7 @@ def _prefill(self, batch_size, total_token_num, max_len_in_batch, input_ids, b_r
             infer_state.mem_is_contiguous = False
             alloc_mem = self.mem_manager.alloc(infer_state.total_token_num)
             infer_state.mem_index = alloc_mem
-            infer_state.key_buffer = torch.empty((infer_state.total_token_num, self.tp_k_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
-            infer_state.value_buffer = torch.empty((infer_state.total_token_num, self.tp_v_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
+            infer_state.kv_buffer = torch.empty((infer_state.total_token_num, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
 
         init_req_to_token_indexes(self.req_manager.req_to_token_indexs, b_req_idx, b_seq_len,
                                   max_len_in_batch, infer_state.mem_index)
@@ -214,8 +213,7 @@ def _decode(self, batch_size, total_token_num, max_len_in_batch, input_ids, b_re
             infer_state.mem_is_contiguous = False
             alloc_mem = self.mem_manager.alloc(batch_size)
             infer_state.mem_index = alloc_mem
-            infer_state.key_buffer = torch.empty((batch_size, self.tp_k_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
-            infer_state.value_buffer = torch.empty((batch_size, self.tp_v_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
+            infer_state.kv_buffer = torch.empty((batch_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
             copy_kv_index_to_req(self.req_manager.req_to_token_indexs, b_req_idx, b_seq_len, infer_state.mem_index)
 
         infer_state.init_some_extra_state(self, input_ids)
@@ -272,8 +270,7 @@ def splitfuse_forward(
             infer_state.mem_is_contiguous = False
             alloc_mem = self.mem_manager.alloc(alloc_size)
             infer_state.mem_index = alloc_mem
-            infer_state.key_buffer = torch.empty((alloc_size, self.tp_k_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
-            infer_state.value_buffer = torch.empty((alloc_size, self.tp_v_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
+            infer_state.kv_buffer = torch.empty((alloc_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
 
         # decode 部分
         if decode_req_num != 0:
diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py
old mode 100644
new mode 100755
index e4fdbcecf..34c068279
--- a/lightllm/common/basemodel/infer_struct.py
+++ b/lightllm/common/basemodel/infer_struct.py
@@ -23,8 +23,7 @@ def __init__(self):
         self.mem_index = None
         self.mem_start = None
         self.mem_end = None
-        self.key_buffer = None
-        self.value_buffer = None
+        self.kv_buffer = None
         self.is_splitfuse = False
         self.return_all_prompt_logprobs = False
diff --git a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py
old mode 100644
new mode 100755
index 26f9ddef6..0ea39eb5c
---
a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py +++ b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py @@ -31,28 +31,25 @@ def _ffn_norm(self, input, infer_state:InferStateInfo, layer_weight)->torch.Tens def _pre_cache_kv(self, infer_state:InferStateInfo, layer_weight)->Tuple[torch.Tensor, torch.Tensor]: if infer_state.mem_is_contiguous: - cache_k = infer_state.mem_manager.key_buffer[self.layer_num_][infer_state.mem_start:infer_state.mem_end, :, :] - cache_v = infer_state.mem_manager.value_buffer[self.layer_num_][infer_state.mem_start:infer_state.mem_end, :, :] + cache_kv = infer_state.mem_manager.kv_buffer[self.layer_num_][infer_state.mem_start:infer_state.mem_end, :, :] else: - cache_k = infer_state.key_buffer - cache_v = infer_state.value_buffer - return cache_k, cache_v + cache_kv = infer_state.kv_buffer + return cache_kv - def _get_qkv(self, input, cache_k, cache_v, infer_state:InferStateInfo, layer_weight)->Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def _get_qkv(self, input, cache_kv, infer_state:InferStateInfo, layer_weight)->Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: raise Exception("need to impl") - def _post_cache_kv(self, cache_k, cache_v, infer_state:InferStateInfo, layer_weight): + def _post_cache_kv(self, cache_kv, infer_state:InferStateInfo, layer_weight): mem_manager = infer_state.mem_manager if not infer_state.mem_is_contiguous: - self._copy_kv_to_mem_cache(cache_k, cache_v, infer_state.mem_index, mem_manager) + self._copy_kv_to_mem_cache(cache_kv, infer_state.mem_index, mem_manager) return - def _copy_kv_to_mem_cache(self, key_buffer, value_buffer, mem_index, mem_manager): - destindex_copy_kv(key_buffer, mem_index, mem_manager.key_buffer[self.layer_num_]) - destindex_copy_kv(value_buffer, mem_index, mem_manager.value_buffer[self.layer_num_]) + def _copy_kv_to_mem_cache(self, buffer, mem_index, mem_manager): + destindex_copy_kv(buffer, mem_index, mem_manager.kv_buffer[self.layer_num_]) return - def _context_attention_kernel(self, q, k, v, infer_state:InferStateInfo, layer_weight, out=None)->torch.Tensor: + def _context_attention_kernel(self, q, kv, infer_state:InferStateInfo, layer_weight, out=None)->torch.Tensor: raise Exception("need to impl") def _token_attention_kernel(self, q, infer_state:InferStateInfo, layer_weight, out=None)->torch.Tensor: @@ -71,11 +68,11 @@ def _ffn(self, input, infer_state:InferStateInfo, layer_weight)->torch.Tensor: @mark_cost_time("trans context flash forward time cost") # dont to remove this, will make performence down, did not know why def _context_attention(self, input_embding, infer_state: InferStateInfo, layer_weight): input1 = self._att_norm(input_embding, infer_state, layer_weight) - cache_k, cache_v = self._pre_cache_kv(infer_state, layer_weight) - q, cache_k, cache_v = self._get_qkv(input1, cache_k, cache_v, infer_state, layer_weight) + cache_kv = self._pre_cache_kv(infer_state, layer_weight) + q, cache_kv = self._get_qkv(input1, cache_kv, infer_state, layer_weight) input1 = None - self._post_cache_kv(cache_k, cache_v, infer_state, layer_weight) - o = self._context_attention_kernel(q, cache_k, cache_v, infer_state, layer_weight) + self._post_cache_kv(cache_kv, infer_state, layer_weight) + o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight) q = None o = self._get_o(o, infer_state, layer_weight) if self.world_size_ > 1: @@ -96,10 +93,10 @@ def _context_ffn(self, input_embdings, infer_state: InferStateInfo, layer_weight # this impl 
dont to use @mark_cost_time def _token_attention(self, input_embding, infer_state: InferStateInfo, layer_weight): input1 = self._att_norm(input_embding, infer_state, layer_weight) - cache_k, cache_v = self._pre_cache_kv(infer_state, layer_weight) - q, cache_k, cache_v = self._get_qkv(input1, cache_k, cache_v, infer_state, layer_weight) + cache_kv = self._pre_cache_kv(infer_state, layer_weight) + q, cache_kv = self._get_qkv(input1, cache_kv, infer_state, layer_weight) input1 = None - self._post_cache_kv(cache_k, cache_v, infer_state, layer_weight) + self._post_cache_kv(cache_kv, infer_state, layer_weight) o = self._token_attention_kernel(q, infer_state, layer_weight) q = None o = self._get_o(o, infer_state, layer_weight) @@ -121,10 +118,10 @@ def _token_ffn(self, input_embdings, infer_state: InferStateInfo, layer_weight): # @mark_cost_time("trans context flash forward time cost") # dont to remove this, will make performence down, did not know why def _splitfuse_attention(self, input_embding, infer_state: SplitFuseInferStateInfo, layer_weight): input1 = self._att_norm(input_embding, infer_state, layer_weight) - cache_k, cache_v = self._pre_cache_kv(infer_state, layer_weight) - q, cache_k, cache_v = self._get_qkv(input1, cache_k, cache_v, infer_state, layer_weight) + cache_kv = self._pre_cache_kv(infer_state, layer_weight) + q, cache_kv = self._get_qkv(input1, cache_kv, infer_state, layer_weight) input1 = None - self._post_cache_kv(cache_k, cache_v, infer_state, layer_weight) + self._post_cache_kv(cache_kv, infer_state, layer_weight) o = self._splitfuse_attention_kernel(q, infer_state, layer_weight) q = None o = self._get_o(o, infer_state, layer_weight) diff --git a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template_awquant.py b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template_awquant.py old mode 100644 new mode 100755 index 7115ec7af..3806f6da3 --- a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template_awquant.py +++ b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template_awquant.py @@ -36,8 +36,7 @@ def _pre_cache_kv(self, infer_state:InferStateInfo, layer_weight)->Tuple[torch.T ''' Release kv buffer to save memory, since we allocate while kv projection. ''' - infer_state.key_buffer = None - infer_state.value_buffer = None - return None, None + infer_state.kv_buffer = None + return None diff --git a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template_wquant.py b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template_wquant.py old mode 100644 new mode 100755 index e64a080d7..024e32a90 --- a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template_wquant.py +++ b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template_wquant.py @@ -31,6 +31,5 @@ def _pre_cache_kv(self, infer_state:InferStateInfo, layer_weight)->Tuple[torch.T ''' Release kv buffer to save memory, since we allocate while kv projection. 
''' - infer_state.key_buffer = None - infer_state.value_buffer = None - return None, None \ No newline at end of file + infer_state.kv_buffer = None + return None \ No newline at end of file diff --git a/lightllm/common/basemodel/layer_weights/base_layer_weight.py b/lightllm/common/basemodel/layer_weights/base_layer_weight.py index 88ce3447c..b73456ed4 100644 --- a/lightllm/common/basemodel/layer_weights/base_layer_weight.py +++ b/lightllm/common/basemodel/layer_weights/base_layer_weight.py @@ -1,10 +1,12 @@ import torch import numpy as np +import threading class BaseLayerWeight: def __init__(self): self.tp_rank_ = None + self.lock = threading.Lock() def load_hf_weights(self, weights): """ @@ -30,3 +32,21 @@ def _cuda(self, cpu_tensor): return cpu_tensor.contiguous().to(self.data_type_).cuda() else: return cpu_tensor.contiguous().to(self.data_type_).cuda(self.tp_rank_) + + def _try_cat_to(self, source_tensor_names, dest_name, cat_dim, handle_func=None): + if all(hasattr(self, src_name) for src_name in source_tensor_names) and not hasattr(self, dest_name): + with self.lock: + if all(hasattr(self, src_name) for src_name in source_tensor_names) and not hasattr(self, dest_name): + assert all( + not getattr(self, name, None).is_cuda for name in source_tensor_names + ), "all not cuda tensor" + tensors = [getattr(self, name, None) for name in source_tensor_names] + ans = torch.cat(tensors, dim=cat_dim) + if handle_func is not None: + ans = handle_func(ans) + else: + ans = self._cuda(ans) + setattr(self, dest_name, ans) + for name in source_tensor_names: + delattr(self, name) + return diff --git a/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py b/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py index 61e7e37f6..bde7cb314 100644 --- a/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py +++ b/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py @@ -3,6 +3,7 @@ class TransformerLayerWeight(BaseLayerWeight): def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode): + super().__init__() self.layer_num_ = layer_num self.tp_rank_ = tp_rank self.world_size_ = world_size @@ -10,4 +11,4 @@ def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mo self.network_config_ = network_config self.mode = mode self.init_static_params() - return \ No newline at end of file + return diff --git a/lightllm/common/basemodel/splitfuse_infer_struct.py b/lightllm/common/basemodel/splitfuse_infer_struct.py old mode 100644 new mode 100755 index fa73696ac..694361394 --- a/lightllm/common/basemodel/splitfuse_infer_struct.py +++ b/lightllm/common/basemodel/splitfuse_infer_struct.py @@ -34,8 +34,7 @@ def __init__(self): self.mem_start = None self.mem_end = None self.mem_index = None - self.key_buffer = None - self.value_buffer = None + self.kv_buffer = None self.parrall_stream = torch.cuda.Stream() self.start_event = torch.cuda.Event() diff --git a/lightllm/common/int8kv_mem_manager.py b/lightllm/common/int8kv_mem_manager.py old mode 100644 new mode 100755 index bd9479d34..538b44d91 --- a/lightllm/common/int8kv_mem_manager.py +++ b/lightllm/common/int8kv_mem_manager.py @@ -8,14 +8,10 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=True) super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy=True) def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): - self.key_buffer = [torch.empty((size, head_num, head_dim), dtype=torch.int8, device="cuda") for _ in 
range(layer_num)] - self.value_buffer = [torch.empty((size, head_num, head_dim), dtype=torch.int8, device="cuda") for _ in range(layer_num)] - self.key_scale_buffer = [torch.empty((size, head_num, 1), dtype=dtype, device="cuda") for _ in range(layer_num)] - self.value_scale_buffer = [torch.empty((size, head_num, 1), dtype=dtype, device="cuda") for _ in range(layer_num)] + self.kv_buffer = [torch.empty((size, 2 * head_num, head_dim), dtype=torch.int8, device="cuda") for _ in range(layer_num)] + self.scale_buffer = [torch.empty((size, 2 * head_num, 1), dtype=dtype, device="cuda") for _ in range(layer_num)] def _free_buffers(self): - self.key_buffer = None - self.value_buffer = None - self.key_scale_buffer = None - self.value_scale_buffer = None + self.kv_buffer = None + self.scale_buffer = None diff --git a/lightllm/common/mem_manager.py b/lightllm/common/mem_manager.py old mode 100644 new mode 100755 index 4c359979a..f967323c7 --- a/lightllm/common/mem_manager.py +++ b/lightllm/common/mem_manager.py @@ -20,12 +20,10 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False self._init_buffers(size, dtype, head_num, head_dim, layer_num) def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): - self.key_buffer = [torch.empty((size, head_num, head_dim), dtype=dtype, device="cuda") for _ in range(layer_num)] - self.value_buffer = [torch.empty((size, head_num, head_dim), dtype=dtype, device="cuda") for _ in range(layer_num)] + self.kv_buffer = [torch.empty((size, 2 * head_num, head_dim), dtype=dtype, device="cuda") for _ in range(layer_num)] def _free_buffers(self): - self.key_buffer = None - self.value_buffer = None + self.kv_buffer = None @torch.no_grad() def alloc(self, need_size): diff --git a/lightllm/common/ppl_int8kv_mem_manager.py b/lightllm/common/ppl_int8kv_mem_manager.py old mode 100644 new mode 100755 index f04f7edde..9840102e1 --- a/lightllm/common/ppl_int8kv_mem_manager.py +++ b/lightllm/common/ppl_int8kv_mem_manager.py @@ -9,13 +9,9 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=True) def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): group_quant_size = 8 - self.key_buffer = [torch.empty((size, head_num, head_dim), dtype=torch.int8, device="cuda") for _ in range(layer_num)] - self.value_buffer = [torch.empty((size, head_num, head_dim), dtype=torch.int8, device="cuda") for _ in range(layer_num)] - self.key_scale_buffer = [torch.empty((size, head_num, head_dim // group_quant_size), dtype=dtype, device="cuda") for _ in range(layer_num)] - self.value_scale_buffer = [torch.empty((size, head_num, head_dim // group_quant_size), dtype=dtype, device="cuda") for _ in range(layer_num)] + self.kv_buffer = [torch.empty((size, 2 * head_num, head_dim), dtype=torch.int8, device="cuda") for _ in range(layer_num)] + self.scale_buffer = [torch.empty((size, 2 * head_num, head_dim // group_quant_size), dtype=dtype, device="cuda") for _ in range(layer_num)] def _free_buffers(self): - self.key_buffer = None - self.value_buffer = None - self.key_scale_buffer = None - self.value_scale_buffer = None + self.kv_buffer = None + self.scale_buffer = None \ No newline at end of file diff --git a/lightllm/models/baichuan13b/layer_infer/transformer_layer_infer.py b/lightllm/models/baichuan13b/layer_infer/transformer_layer_infer.py old mode 100644 new mode 100755 index c681c18c5..18d408a4b --- a/lightllm/models/baichuan13b/layer_infer/transformer_layer_infer.py +++ 
b/lightllm/models/baichuan13b/layer_infer/transformer_layer_infer.py @@ -10,26 +10,26 @@ class Baichuan13bTransformerLayerInfer(LlamaTransformerLayerInfer): - """ - """ + """ """ + def __init__(self, layer_num, tp_rank, world_size, network_config, mode): super().__init__(layer_num, tp_rank, world_size, network_config, mode) self._bind_func() return - + def _bind_func(self): """ - baichuan13b only support normal mode. + baichuan13b only support normal mode. """ self._context_attention_kernel = partial(BloomTransformerLayerInfer._context_attention_kernel, self) self._token_attention_kernel = partial(BloomTransformerLayerInfer._token_attention_kernel, self) return - - def _get_qkv(self, input, cache_k, cache_v, infer_state, layer_weight: BaiChuan13bTransformerLayerWeight) -> torch.Tensor: + + def _get_qkv(self, input, cache_kv, infer_state, layer_weight: BaiChuan13bTransformerLayerWeight) -> torch.Tensor: q = torch.mm(input.view(-1, self.embed_dim_), layer_weight.q_weight_) - torch.mm(input.view(-1, self.embed_dim_), layer_weight.k_weight_, - out=cache_k.view(-1, self.tp_k_head_num_ * self.head_dim_)) - torch.mm(input.view(-1, self.embed_dim_), layer_weight.v_weight_, - out=cache_v.view(-1, self.tp_v_head_num_ * self.head_dim_)) - return q, cache_k, cache_v - \ No newline at end of file + torch.mm( + input.view(-1, self.embed_dim_), + layer_weight.kv_weight_, + out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_), + ) + return q, cache_kv diff --git a/lightllm/models/baichuan2_7b/layer_infer/transformer_layer_infer.py b/lightllm/models/baichuan2_7b/layer_infer/transformer_layer_infer.py old mode 100644 new mode 100755 index 6fc3ad3b1..c22a7f232 --- a/lightllm/models/baichuan2_7b/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/baichuan2_7b/layer_infer/transformer_layer_infer.py @@ -6,23 +6,26 @@ class Baichuan2_7bTransformerLayerInfer(LlamaTransformerLayerInfer): - """ - """ + """ """ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): super().__init__(layer_num, tp_rank, world_size, network_config, mode) return - - def _get_qkv(self, input, cache_k, cache_v, infer_state:LlamaInferStateInfo, layer_weight:LlamaTransformerLayerWeight)->torch.Tensor: - q = torch.mm(input.view(-1, self.embed_dim_), layer_weight.q_weight_) + + def _get_qkv( + self, input, cache_kv: torch.Tensor, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight + ) -> torch.Tensor: + q = torch.mm(input.view(-1, self.embed_dim_), layer_weight.q_weight_).view( + -1, self.tp_q_head_num_, self.head_dim_ + ) + torch.mm( + input.view(-1, self.embed_dim_), + layer_weight.kv_weight_, + out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_), + ) q_ = q.float() - rotary_emb_fwd(q_.view(-1, self.tp_q_head_num_, self.head_dim_).float(), infer_state.position_cos, infer_state.position_sin) + cache_k_ = cache_kv[:, 0 : self.tp_k_head_num_, :].float() + rotary_emb_fwd(q_, cache_k_, infer_state.position_cos, infer_state.position_sin) + cache_kv[:, 0 : self.tp_k_head_num_, :].copy_(cache_k_) q.copy_(q_) - torch.mm(input.view(-1, self.embed_dim_), layer_weight.k_weight_, - out=cache_k.view(-1, self.tp_k_head_num_ * self.head_dim_)) - cache_k_ = cache_k.float() - rotary_emb_fwd(cache_k_, infer_state.position_cos, infer_state.position_sin) - cache_k.copy_(cache_k_) - torch.mm(input.view(-1, self.embed_dim_), layer_weight.v_weight_, - out=cache_v.view(-1, self.tp_v_head_num_ * self.head_dim_)) - return q, cache_k, cache_v + return 
q, cache_kv diff --git a/lightllm/models/baichuan7b/layer_weights/transformer_layer_weight.py b/lightllm/models/baichuan7b/layer_weights/transformer_layer_weight.py index 390670139..fe4aa7369 100644 --- a/lightllm/models/baichuan7b/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/baichuan7b/layer_weights/transformer_layer_weight.py @@ -4,11 +4,12 @@ from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight + class BaiChuan7bTransformerLayerWeight(LlamaTransformerLayerWeight): def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[]): super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode) return - + def _load_qkvo_weights(self, weights): # input layernorm params if f"model.layers.{self.layer_num_}.input_layernorm.weight" in weights: @@ -21,17 +22,19 @@ def _load_qkvo_weights(self, weights): qkv_weights = weights[f"model.layers.{self.layer_num_}.self_attn.W_pack.weight"] split_size = qkv_weights.shape[0] // 3 q_weights, k_weights, v_weights = torch.split(qkv_weights, split_size, dim=0) - - self.q_weight_ = q_weights[split_n_embed * self.tp_rank_: split_n_embed * (self.tp_rank_ + 1), :] + self.q_weight_ = q_weights[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :] self.q_weight_ = self._cuda(self.q_weight_.transpose(0, 1)) - self.k_weight_ = k_weights[split_n_embed * self.tp_rank_: split_n_embed * (self.tp_rank_ + 1), :] - self.k_weight_ = self._cuda(self.k_weight_.transpose(0, 1)) - self.v_weight_ = v_weights[split_n_embed * self.tp_rank_: split_n_embed * (self.tp_rank_ + 1), :] - self.v_weight_ = self._cuda(self.v_weight_.transpose(0, 1)) - + k_weight_ = k_weights[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :] + self.k_weight_ = k_weight_.transpose(0, 1) + v_weight_ = v_weights[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :] + self.v_weight_ = v_weight_.transpose(0, 1) + + self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1) + # attention output dense params if f"model.layers.{self.layer_num_}.self_attn.o_proj.weight" in weights: - self.o_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.weight"][:,split_n_embed * self.tp_rank_: split_n_embed * (self.tp_rank_ + 1)] + self.o_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.weight"][ + :, split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1) + ] self.o_weight_ = self._cuda(self.o_weight_.transpose(0, 1)) return - diff --git a/lightllm/models/bloom/layer_infer/transformer_layer_infer.py b/lightllm/models/bloom/layer_infer/transformer_layer_infer.py old mode 100644 new mode 100755 index b3da48092..5dea5f975 --- a/lightllm/models/bloom/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/bloom/layer_infer/transformer_layer_infer.py @@ -14,8 +14,7 @@ class BloomTransformerLayerInfer(TransformerLayerInferTpl): - """ - """ + """ """ def __init__(self, layer_num, tp_rank, world_size, network_config, mode): super().__init__(layer_num, tp_rank, world_size, network_config, mode) @@ -27,68 +26,92 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode): self.head_dim_ = network_config["n_embed"] // network_config["num_attention_heads"] self.embed_dim_ = network_config["n_embed"] return - - def _att_norm(self, input, infer_state:InferStateInfo, layer_weight: BloomTransformerLayerWeight)->torch.Tensor: + + def _att_norm(self, input, infer_state: InferStateInfo, layer_weight: 
BloomTransformerLayerWeight) -> torch.Tensor: return layernorm_forward( input.view(-1, self.embed_dim_), weight=layer_weight.att_norm_weight_, bias=layer_weight.att_norm_bias_, - eps=self.eps_) - - def _ffn_norm(self, input, infer_state:InferStateInfo, layer_weight: BloomTransformerLayerWeight)->torch.Tensor: + eps=self.eps_, + ) + + def _ffn_norm(self, input, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight) -> torch.Tensor: return layernorm_forward( input.view(-1, self.embed_dim_), weight=layer_weight.ffn_norm_weight_, bias=layer_weight.ffn_norm_bias_, - eps=self.eps_) - - def _get_qkv(self, input, cache_k, cache_v, infer_state:InferStateInfo, layer_weight: BloomTransformerLayerWeight)->torch.Tensor: - q = torch.addmm(layer_weight.q_bias_, input.view(-1, self.embed_dim_), layer_weight.q_weight_, beta=1.0, alpha=1.0) - torch.addmm(layer_weight.k_bias_, input.view(-1, self.embed_dim_), layer_weight.k_weight_, beta=1.0, - alpha=1.0, out=cache_k.view(-1, self.tp_k_head_num_ * self.head_dim_)) - torch.addmm(layer_weight.v_bias_, input.view(-1, self.embed_dim_), layer_weight.v_weight_, beta=1.0, - alpha=1.0, out=cache_v.view(-1, self.tp_v_head_num_ * self.head_dim_)) - return q, cache_k, cache_v - - def _context_attention_kernel(self, q, k, v, infer_state:InferStateInfo, layer_weight: BloomTransformerLayerWeight, out=None)->torch.Tensor: - o_tensor = torch.empty_like(q) if out is None else out - context_attention_fwd(q.view(-1, self.tp_q_head_num_, self.head_dim_), - k.view(-1, self.tp_k_head_num_, self.head_dim_), - v.view(-1, self.tp_v_head_num_, self.head_dim_), - o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), - layer_weight.tp_alibi, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.max_len_in_batch) + eps=self.eps_, + ) + + def _get_qkv( + self, input, cache_kv, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight + ) -> torch.Tensor: + q = torch.addmm( + layer_weight.q_bias_, input.view(-1, self.embed_dim_), layer_weight.q_weight_, beta=1.0, alpha=1.0 + ) + torch.addmm( + layer_weight.kv_bias_, + input.view(-1, self.embed_dim_), + layer_weight.kv_weight_, + beta=1.0, + alpha=1.0, + out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_), + ) + return q, cache_kv + + def _context_attention_kernel( + self, q, kv, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight, out=None + ) -> torch.Tensor: + o_tensor = torch.empty_like(q) if out is None else out + context_attention_fwd( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + kv[:, 0 : self.tp_k_head_num_, :], + kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], + o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), + layer_weight.tp_alibi, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.max_len_in_batch, + ) return o_tensor - - def _token_attention_kernel(self, q, infer_state:InferStateInfo, layer_weight: BloomTransformerLayerWeight, out=None)->torch.Tensor: - o_tensor = torch.empty_like(q) if out is None else out - token_attention_fwd(q.view(-1, self.tp_q_head_num_, self.head_dim_), - infer_state.mem_manager.key_buffer[self.layer_num_], - infer_state.mem_manager.value_buffer[self.layer_num_], - o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), - layer_weight.tp_alibi, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - infer_state.total_token_num) + + def 
_token_attention_kernel( + self, q, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight, out=None + ) -> torch.Tensor: + o_tensor = torch.empty_like(q) if out is None else out + token_attention_fwd( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], + infer_state.mem_manager.kv_buffer[self.layer_num_][ + :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : + ], + o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), + layer_weight.tp_alibi, + infer_state.req_manager.req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.max_len_in_batch, + infer_state.total_token_num, + ) return o_tensor - - def _get_o(self, input, infer_state:InferStateInfo, layer_weight: BloomTransformerLayerWeight)->torch.Tensor: - o = torch.addmm(layer_weight.o_bias_, - input.view(-1, self.tp_q_head_num_ * self.head_dim_), - layer_weight.o_weight_, - beta=1.0 / self.world_size_) + + def _get_o(self, input, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight) -> torch.Tensor: + o = torch.addmm( + layer_weight.o_bias_, + input.view(-1, self.tp_q_head_num_ * self.head_dim_), + layer_weight.o_weight_, + beta=1.0 / self.world_size_, + ) return o - - def _ffn(self, input, infer_state:InferStateInfo, layer_weight: BloomTransformerLayerWeight)->torch.Tensor: + + def _ffn(self, input, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight) -> torch.Tensor: ffn1_out = torch.addmm(layer_weight.ffn_1_bias_, input.view(-1, self.embed_dim_), layer_weight.ffn_1_weight_) input = None - gelu_out = torch.nn.functional.gelu(ffn1_out, approximate='tanh') + gelu_out = torch.nn.functional.gelu(ffn1_out, approximate="tanh") ffn1_out = None - ffn2_out = torch.addmm(layer_weight.ffn_2_bias_, gelu_out, layer_weight.ffn_2_weight_, beta=1.0 / self.world_size_) + ffn2_out = torch.addmm( + layer_weight.ffn_2_bias_, gelu_out, layer_weight.ffn_2_weight_, beta=1.0 / self.world_size_ + ) gelu_out = None - return ffn2_out \ No newline at end of file + return ffn2_out diff --git a/lightllm/models/bloom/layer_weights/transformer_layer_weight.py b/lightllm/models/bloom/layer_weights/transformer_layer_weight.py index 537a3192b..a43b8a453 100644 --- a/lightllm/models/bloom/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/bloom/layer_weights/transformer_layer_weight.py @@ -15,38 +15,36 @@ def init_static_params(self): tp_head_num = head_num // self.world_size_ tmp_alibi = self._generate_alibi(head_num, dtype=torch.float32) assert head_num % self.world_size_ == 0 - self.tp_alibi = tmp_alibi[self.tp_rank_ * tp_head_num: (self.tp_rank_ + 1) * tp_head_num].contiguous().cuda() + self.tp_alibi = tmp_alibi[self.tp_rank_ * tp_head_num : (self.tp_rank_ + 1) * tp_head_num].contiguous().cuda() return - + def load_hf_weights(self, weights): self._load_qkvo_weights(weights) self._load_ffn_weights(weights) return - + def verify_load(self): errors = "weights load not ok" - weights = [self.att_norm_weight_, - self.att_norm_bias_, - self.q_weight_, - self.k_weight_, - self.v_weight_, - self.q_bias_, - self.k_bias_, - self.v_bias_, - self.o_weight_, - self.o_bias_, - - self.ffn_norm_weight_, - self.ffn_norm_bias_, - self.ffn_1_weight_, - self.ffn_1_bias_, - self.ffn_2_weight_, - self.ffn_2_bias_, - ] + weights = [ + self.att_norm_weight_, + self.att_norm_bias_, + self.q_weight_, + self.kv_weight_, + self.q_bias_, + self.kv_bias_, + self.o_weight_, + 
self.o_bias_, + self.ffn_norm_weight_, + self.ffn_norm_bias_, + self.ffn_1_weight_, + self.ffn_1_bias_, + self.ffn_2_weight_, + self.ffn_2_bias_, + ] for i in range(len(weights)): assert weights[i] is not None, "index:" + str(i) + " " + errors - return + return def _load_qkvo_weights(self, weights): # input layernorm params @@ -60,48 +58,59 @@ def _load_qkvo_weights(self, weights): n_embed = self.network_config_["n_embed"] split_n_embed = n_embed // self.world_size_ head_num = self.network_config_["num_attention_heads"] - att_qkv_dense_weight = weights[f"h.{self.layer_num_}.self_attention.query_key_value.weight"].reshape(head_num, 3, -1, n_embed) - self.q_weight_ = self._cuda(att_qkv_dense_weight[:, - 0, - :, - :].reshape(-1, - n_embed)[split_n_embed * self.tp_rank_: split_n_embed * (self.tp_rank_ + 1), - :].transpose(0, - 1)) - self.k_weight_ = self._cuda(att_qkv_dense_weight[:, - 1, - :, - :].reshape(-1, - n_embed)[split_n_embed * self.tp_rank_: split_n_embed * (self.tp_rank_ + 1), - :].transpose(0, - 1)) - self.v_weight_ = self._cuda(att_qkv_dense_weight[:, - 2, - :, - :].reshape(-1, - n_embed)[split_n_embed * self.tp_rank_: split_n_embed * (self.tp_rank_ + 1), - :].transpose(0, - 1)) + att_qkv_dense_weight = weights[f"h.{self.layer_num_}.self_attention.query_key_value.weight"].reshape( + head_num, 3, -1, n_embed + ) + self.q_weight_ = self._cuda( + att_qkv_dense_weight[:, 0, :, :] + .reshape(-1, n_embed)[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :] + .transpose(0, 1) + ) + self.k_weight_ = ( + att_qkv_dense_weight[:, 1, :, :] + .reshape(-1, n_embed)[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :] + .transpose(0, 1) + ) + self.v_weight_ = ( + att_qkv_dense_weight[:, 2, :, :] + .reshape(-1, n_embed)[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :] + .transpose(0, 1) + ) + + self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1) + if f"h.{self.layer_num_}.self_attention.query_key_value.bias" in weights: n_embed = self.network_config_["n_embed"] split_n_embed = n_embed // self.world_size_ head_num = self.network_config_["num_attention_heads"] - att_qkv_dense_bias = weights[f"h.{self.layer_num_}.self_attention.query_key_value.bias"].reshape(head_num, 3, -1) - self.q_bias_ = self._cuda(att_qkv_dense_bias[:, 0, :].reshape(-1)[split_n_embed * - self.tp_rank_: split_n_embed * (self.tp_rank_ + 1)]) - self.k_bias_ = self._cuda(att_qkv_dense_bias[:, 1, :].reshape(-1)[split_n_embed * - self.tp_rank_: split_n_embed * (self.tp_rank_ + 1)]) - self.v_bias_ = self._cuda(att_qkv_dense_bias[:, 2, :].reshape(-1)[split_n_embed * - self.tp_rank_: split_n_embed * (self.tp_rank_ + 1)]) + att_qkv_dense_bias = weights[f"h.{self.layer_num_}.self_attention.query_key_value.bias"].reshape( + head_num, 3, -1 + ) + self.q_bias_ = self._cuda( + att_qkv_dense_bias[:, 0, :].reshape(-1)[ + split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1) + ] + ) + self.k_bias_ = att_qkv_dense_bias[:, 1, :].reshape(-1)[ + split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1) + ] + self.v_bias_ = att_qkv_dense_bias[:, 2, :].reshape(-1)[ + split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1) + ] + + self._try_cat_to(["k_bias_", "v_bias_"], "kv_bias_", cat_dim=0) if f"h.{self.layer_num_}.self_attention.dense.weight" in weights: n_embed = self.network_config_["n_embed"] split_n_embed = n_embed // self.world_size_ - self.o_weight_ = self._cuda(weights[f"h.{self.layer_num_}.self_attention.dense.weight"][:, - 
split_n_embed * self.tp_rank_: split_n_embed * (self.tp_rank_ + 1)].transpose(0, 1)) + self.o_weight_ = self._cuda( + weights[f"h.{self.layer_num_}.self_attention.dense.weight"][ + :, split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1) + ].transpose(0, 1) + ) if f"h.{self.layer_num_}.self_attention.dense.bias" in weights: self.o_bias_ = self._cuda(weights[f"h.{self.layer_num_}.self_attention.dense.bias"]) - return + return def _load_ffn_weights(self, weights): if f"h.{self.layer_num_}.post_attention_layernorm.weight" in weights: @@ -114,24 +123,33 @@ def _load_ffn_weights(self, weights): n_embed = self.network_config_["n_embed"] * 4 split_n_embed = n_embed // self.world_size_ self.ffn_1_weight_ = weights[f"h.{self.layer_num_}.mlp.dense_h_to_4h.weight"] - self.ffn_1_weight_ = self._cuda(self.ffn_1_weight_[split_n_embed * self.tp_rank_: split_n_embed * - (self.tp_rank_ + 1), :].transpose(0, 1)) + self.ffn_1_weight_ = self._cuda( + self.ffn_1_weight_[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :].transpose( + 0, 1 + ) + ) if f"h.{self.layer_num_}.mlp.dense_h_to_4h.bias" in weights: n_embed = self.network_config_["n_embed"] * 4 split_n_embed = n_embed // self.world_size_ - self.ffn_1_bias_ = self._cuda(weights[f"h.{self.layer_num_}.mlp.dense_h_to_4h.bias"][split_n_embed * - self.tp_rank_: split_n_embed * (self.tp_rank_ + 1)]) + self.ffn_1_bias_ = self._cuda( + weights[f"h.{self.layer_num_}.mlp.dense_h_to_4h.bias"][ + split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1) + ] + ) if f"h.{self.layer_num_}.mlp.dense_4h_to_h.weight" in weights: n_embed = self.network_config_["n_embed"] * 4 split_n_embed = n_embed // self.world_size_ self.ffn_2_weight_ = weights[f"h.{self.layer_num_}.mlp.dense_4h_to_h.weight"] - self.ffn_2_weight_ = self._cuda(self.ffn_2_weight_[:, split_n_embed * - self.tp_rank_: split_n_embed * (self.tp_rank_ + 1)].transpose(0, 1)) + self.ffn_2_weight_ = self._cuda( + self.ffn_2_weight_[:, split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1)].transpose( + 0, 1 + ) + ) if f"h.{self.layer_num_}.mlp.dense_4h_to_h.bias" in weights: self.ffn_2_bias_ = self._cuda(weights[f"h.{self.layer_num_}.mlp.dense_4h_to_h.bias"]) - return + return def _generate_alibi(self, n_head, dtype=torch.float16): """ @@ -154,11 +172,12 @@ def _generate_alibi(self, n_head, dtype=torch.float16): See the License for the specific language governing permissions and limitations under the License. 
""" + def get_slopes(n): def get_slopes_power_of_2(n): start = 2 ** (-(2 ** -(math.log2(n) - 3))) ratio = start - return [start * ratio**i for i in range(n)] + return [start * ratio ** i for i in range(n)] if math.log2(n).is_integer(): return get_slopes_power_of_2(n) @@ -172,4 +191,3 @@ def get_slopes_power_of_2(n): slopes = torch.Tensor(get_slopes(n_head)) head_alibi = slopes.to(dtype) return head_alibi - \ No newline at end of file diff --git a/lightllm/models/chatglm2/layer_infer/transformer_layer_infer.py b/lightllm/models/chatglm2/layer_infer/transformer_layer_infer.py old mode 100644 new mode 100755 index 3a7dcb506..60620b953 --- a/lightllm/models/chatglm2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/chatglm2/layer_infer/transformer_layer_infer.py @@ -14,8 +14,8 @@ class ChatGLM2TransformerLayerInfer(LlamaTransformerLayerInfer): - """ - """ + """ """ + def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): super().__init__(layer_num, tp_rank, world_size, network_config, mode) return @@ -24,15 +24,27 @@ def swiglu(self, x): x = torch.chunk(x, 2, dim=-1) return torch.nn.functional.silu(x[0]) * x[1] - def _get_qkv(self, input_emb, cache_k, cache_v, infer_state: LlamaInferStateInfo, layer_weight:ChatGLM2TransformerLayerWeight): - q = torch.addmm(layer_weight.q_bias_, input_emb.view(-1, self.embed_dim_), layer_weight.q_weight_, beta=1.0, alpha=1.0) - rotary_emb_fwd(q.view(-1, self.tp_q_head_num_, self.head_dim_), infer_state.position_cos, infer_state.position_sin) - torch.addmm(layer_weight.k_bias_, input_emb.view(-1, self.embed_dim_), layer_weight.k_weight_, beta=1.0, alpha=1.0, - out=cache_k.view(-1, self.tp_k_head_num_ * self.head_dim_)) - rotary_emb_fwd(cache_k, infer_state.position_cos, infer_state.position_sin) - torch.addmm(layer_weight.v_bias_, input_emb.view(-1, self.embed_dim_), layer_weight.v_weight_, beta=1.0, alpha=1.0, - out=cache_v.view(-1, self.tp_v_head_num_ * self.head_dim_)) - return q, cache_k, cache_v + def _get_qkv( + self, input_emb, cache_kv, infer_state: LlamaInferStateInfo, layer_weight: ChatGLM2TransformerLayerWeight + ): + q = torch.addmm( + layer_weight.q_bias_, input_emb.view(-1, self.embed_dim_), layer_weight.q_weight_, beta=1.0, alpha=1.0 + ) + torch.addmm( + layer_weight.kv_bias_, + input_emb.view(-1, self.embed_dim_), + layer_weight.kv_weight_, + beta=1.0, + alpha=1.0, + out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_), + ) + rotary_emb_fwd( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + cache_kv[:, 0 : self.tp_k_head_num_, :], + infer_state.position_cos, + infer_state.position_sin, + ) + return q, cache_kv def _ffn(self, input, infer_state: LlamaInferStateInfo, layer_weight: ChatGLM2TransformerLayerWeight): @@ -40,4 +52,4 @@ def _ffn(self, input, infer_state: LlamaInferStateInfo, layer_weight: ChatGLM2Tr act_out = self.swiglu(ffn1_out) ffn1_out = None ffn2_out = torch.mm(act_out, layer_weight.down_proj) - return ffn2_out \ No newline at end of file + return ffn2_out diff --git a/lightllm/models/chatglm2/layer_weights/transformer_layer_weight.py b/lightllm/models/chatglm2/layer_weights/transformer_layer_weight.py old mode 100644 new mode 100755 index f37741016..19e41304f --- a/lightllm/models/chatglm2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/chatglm2/layer_weights/transformer_layer_weight.py @@ -11,21 +11,20 @@ def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mo def verify_load(self): errors = "weights load not ok" - weights = 
[self.att_norm_weight_, - self.q_weight_, - self.k_weight_, - self.v_weight_, - self.q_bias_, - self.k_bias_, - self.v_bias_, - self.o_weight_, - self.ffn_norm_weight_, - self.gate_up_proj, - self.down_proj, - ] + weights = [ + self.att_norm_weight_, + self.q_weight_, + self.kv_weight_, + self.q_bias_, + self.kv_bias_, + self.o_weight_, + self.ffn_norm_weight_, + self.gate_up_proj, + self.down_proj, + ] for i in range(len(weights)): assert weights[i] is not None, "index:" + str(i) + " " + errors - return + return def _load_qkvo_weights(self, weights): # input layernorm params @@ -40,53 +39,79 @@ def _load_qkvo_weights(self, weights): tp_kv_head_dim = multi_query_group_num // self.world_size_ * head_dim split_n_embed = n_embed // self.world_size_ if f"transformer.encoder.layers.{self.layer_num_}.self_attention.query_key_value.weight" in weights: - qkv_weight_ = weights[f"transformer.encoder.layers.{self.layer_num_}.self_attention.query_key_value.weight"].transpose(0, 1).contiguous().to(self.data_type_) - self.q_weight_ = qkv_weight_[:, :n_embed][:, split_n_embed * self.tp_rank_: split_n_embed * (self.tp_rank_ + 1)] + qkv_weight_ = ( + weights[f"transformer.encoder.layers.{self.layer_num_}.self_attention.query_key_value.weight"] + .transpose(0, 1) + .contiguous() + .to(self.data_type_) + ) + self.q_weight_ = qkv_weight_[:, :n_embed][ + :, split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1) + ] self.q_weight_ = self._cuda(self.q_weight_) - self.k_weight_ = qkv_weight_[:, n_embed:n_embed + head_dim * multi_query_group_num] - self.k_weight_ = self._cuda(self.k_weight_[:, tp_kv_head_dim * self.tp_rank_ : tp_kv_head_dim * (self.tp_rank_ + 1)]) + k_weight_ = qkv_weight_[:, n_embed : n_embed + head_dim * multi_query_group_num] + self.k_weight_ = k_weight_[:, tp_kv_head_dim * self.tp_rank_ : tp_kv_head_dim * (self.tp_rank_ + 1)] + + v_weight_ = qkv_weight_[ + :, n_embed + multi_query_group_num * head_dim : n_embed + 2 * multi_query_group_num * head_dim + ] + self.v_weight_ = v_weight_[:, tp_kv_head_dim * self.tp_rank_ : tp_kv_head_dim * (self.tp_rank_ + 1)] - self.v_weight_ = qkv_weight_[:, n_embed + multi_query_group_num * head_dim : n_embed + 2 * multi_query_group_num * head_dim] - self.v_weight_ = self._cuda(self.v_weight_[:, tp_kv_head_dim * self.tp_rank_ : tp_kv_head_dim * (self.tp_rank_ + 1)]) + self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1) if f"transformer.encoder.layers.{self.layer_num_}.self_attention.query_key_value.bias" in weights: - qkv_bias_ = weights[f"transformer.encoder.layers.{self.layer_num_}.self_attention.query_key_value.bias"].to(self.data_type_) - self.q_bias_ = qkv_bias_[:n_embed][split_n_embed * self.tp_rank_: split_n_embed * (self.tp_rank_ + 1)] + qkv_bias_ = weights[f"transformer.encoder.layers.{self.layer_num_}.self_attention.query_key_value.bias"].to( + self.data_type_ + ) + self.q_bias_ = qkv_bias_[:n_embed][split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1)] self.q_bias_ = self._cuda(self.q_bias_) - self.k_bias_ = qkv_bias_[n_embed : n_embed + head_dim * multi_query_group_num] - self.k_bias_ = self._cuda(self.k_bias_[tp_kv_head_dim * self.tp_rank_ : tp_kv_head_dim * (self.tp_rank_ + 1)]) - self.v_bias_ = qkv_bias_[n_embed + multi_query_group_num * head_dim : n_embed + 2 * multi_query_group_num * head_dim] - self.v_bias_ = self._cuda(self.v_bias_[tp_kv_head_dim * self.tp_rank_ : tp_kv_head_dim * (self.tp_rank_ + 1)]) + k_bias_ = qkv_bias_[n_embed : n_embed + head_dim * multi_query_group_num] + self.k_bias_ = 
k_bias_[tp_kv_head_dim * self.tp_rank_ : tp_kv_head_dim * (self.tp_rank_ + 1)] + v_bias_ = qkv_bias_[ + n_embed + multi_query_group_num * head_dim : n_embed + 2 * multi_query_group_num * head_dim + ] + self.v_bias_ = v_bias_[tp_kv_head_dim * self.tp_rank_ : tp_kv_head_dim * (self.tp_rank_ + 1)] + + self._try_cat_to(["k_bias_", "v_bias_"], "kv_bias_", cat_dim=0) # attention output dense params if f"transformer.encoder.layers.{self.layer_num_}.self_attention.dense.weight" in weights: self.o_weight_ = weights[f"transformer.encoder.layers.{self.layer_num_}.self_attention.dense.weight"] - self.o_weight_ = self.o_weight_[:,split_n_embed * self.tp_rank_: split_n_embed * (self.tp_rank_ + 1)] + self.o_weight_ = self.o_weight_[:, split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1)] self.o_weight_ = self.o_weight_.transpose(0, 1) self.o_weight_ = self._cuda(self.o_weight_) def _load_ffn_weights(self, weights): if f"transformer.encoder.layers.{self.layer_num_}.post_attention_layernorm.weight" in weights: - self.ffn_norm_weight_ = weights[f"transformer.encoder.layers.{self.layer_num_}.post_attention_layernorm.weight"] + self.ffn_norm_weight_ = weights[ + f"transformer.encoder.layers.{self.layer_num_}.post_attention_layernorm.weight" + ] self.ffn_norm_weight_ = self._cuda(self.ffn_norm_weight_) # ffn params - ffn_hidden_size = self.network_config_['ffn_hidden_size'] + ffn_hidden_size = self.network_config_["ffn_hidden_size"] split_inter_size = ffn_hidden_size // self.world_size_ if f"transformer.encoder.layers.{self.layer_num_}.mlp.dense_h_to_4h.weight" in weights: - tweights = weights[f"transformer.encoder.layers.{self.layer_num_}.mlp.dense_h_to_4h.weight"].to(self.data_type_) + tweights = weights[f"transformer.encoder.layers.{self.layer_num_}.mlp.dense_h_to_4h.weight"].to( + self.data_type_ + ) gate_proj = tweights[:ffn_hidden_size, :] - gate_proj = gate_proj[split_inter_size * self.tp_rank_: split_inter_size * (self.tp_rank_ + 1), :] - + gate_proj = gate_proj[split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), :] + self.gate_proj = gate_proj.transpose(0, 1) + up_proj = tweights[ffn_hidden_size : 2 * ffn_hidden_size, :] - up_proj = up_proj[split_inter_size * self.tp_rank_: split_inter_size * (self.tp_rank_ + 1), :] + up_proj = up_proj[split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), :] + self.up_proj = up_proj.transpose(0, 1) + + self._try_cat_to(["gate_proj", "up_proj"], "gate_up_proj", cat_dim=1) - gate_up_proj = torch.cat([gate_proj, up_proj], dim=0).transpose(0, 1) - self.gate_up_proj = self._cuda(gate_up_proj) - if f"transformer.encoder.layers.{self.layer_num_}.mlp.dense_4h_to_h.weight" in weights: - self.down_proj = weights[f"transformer.encoder.layers.{self.layer_num_}.mlp.dense_4h_to_h.weight"].to(self.data_type_) - self.down_proj = self.down_proj[:, split_inter_size * self.tp_rank_: split_inter_size * (self.tp_rank_ + 1)].transpose(0, 1) + self.down_proj = weights[f"transformer.encoder.layers.{self.layer_num_}.mlp.dense_4h_to_h.weight"].to( + self.data_type_ + ) + self.down_proj = self.down_proj[ + :, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1) + ].transpose(0, 1) self.down_proj = self._cuda(self.down_proj) return diff --git a/lightllm/models/chatglm2/triton_kernel/rotary_emb.py b/lightllm/models/chatglm2/triton_kernel/rotary_emb.py old mode 100644 new mode 100755 index 0d9286892..ad1d1c2cf --- a/lightllm/models/chatglm2/triton_kernel/rotary_emb.py +++ 
b/lightllm/models/chatglm2/triton_kernel/rotary_emb.py @@ -6,12 +6,23 @@ @triton.jit def _rotary_kernel( - Q, Cos, Sin, - stride_qbs, stride_qh, stride_qd, - stride_cosbs, stride_cosd, - stride_sinbs, stride_sind, + Q, + K, + Cos, + Sin, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_cosbs, + stride_cosd, + stride_sinbs, + stride_sind, max_total_len, - H, # N_CTX 代表要计算的上下文长度 + HEAD_Q, + HEAD_K, # N_CTX 代表要计算的上下文长度 BLOCK_HEAD: tl.constexpr, BLOCK_SEQ: tl.constexpr, BLOCK_DMODEL: tl.constexpr, @@ -24,14 +35,31 @@ def _rotary_kernel( dim_range0 = tl.arange(0, BLOCK_DMODEL // 2) * 2 dim_range1 = dim_range0 + 1 - off_q0 = cur_seq_range[:, None, None] * stride_qbs + cur_head_range[None, :, None] * stride_qh + dim_range0[None, None, :] * stride_qd - off_q1 = cur_seq_range[:, None, None] * stride_qbs + cur_head_range[None, :, None] * stride_qh + dim_range1[None, None, :] * stride_qd + + off_q0 = ( + cur_seq_range[:, None, None] * stride_qbs + + cur_head_range[None, :, None] * stride_qh + + dim_range0[None, None, :] * stride_qd + ) + off_q1 = ( + cur_seq_range[:, None, None] * stride_qbs + + cur_head_range[None, :, None] * stride_qh + + dim_range1[None, None, :] * stride_qd + ) cos_range = tl.arange(0, BLOCK_DMODEL // 2) off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + cos_range[None, None, :] * stride_cosd - q0 = tl.load(Q + off_q0, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H), other=0.0) - q1 = tl.load(Q + off_q1, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H), other=0.0) + q0 = tl.load( + Q + off_q0, + mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q), + other=0.0, + ) + q1 = tl.load( + Q + off_q1, + mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q), + other=0.0, + ) cos = tl.load(Cos + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) sin = tl.load(Sin + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) @@ -39,31 +67,90 @@ def _rotary_kernel( out0 = q0 * cos - q1 * sin out1 = q0 * sin + q1 * cos - tl.store(Q + off_q0, out0, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H)) - tl.store(Q + off_q1, out1, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H)) + tl.store( + Q + off_q0, out0, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q) + ) + tl.store( + Q + off_q1, out1, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q) + ) + + off_k0 = ( + cur_seq_range[:, None, None] * stride_kbs + + cur_head_range[None, :, None] * stride_kh + + dim_range0[None, None, :] * stride_kd + ) + off_k1 = ( + cur_seq_range[:, None, None] * stride_kbs + + cur_head_range[None, :, None] * stride_kh + + dim_range1[None, None, :] * stride_kd + ) + + off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + cos_range[None, None, :] * stride_cosd + k0 = tl.load( + K + off_k0, + mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), + other=0.0, + ) + k1 = tl.load( + K + off_k1, + mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), + other=0.0, + ) + + cos = tl.load(Cos + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) + sin = tl.load(Sin + off_dimcos_sin, 
mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) + + out_k0 = k0 * cos - k1 * sin + out_k1 = k0 * sin + k1 * cos + + tl.store( + K + off_k0, + out_k0, + mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), + ) + tl.store( + K + off_k1, + out_k1, + mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), + ) return @torch.no_grad() -def rotary_emb_fwd(q, cos, sin): +def rotary_emb_fwd(q, k, cos, sin): total_len = q.shape[0] - head_num = q.shape[1] + head_num_q, head_num_k = q.shape[1], k.shape[1] head_dim = q.shape[2] // 2 assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}" + assert k.shape[0] == cos.shape[0] and k.shape[0] == sin.shape[0], f"k shape {k.shape} cos shape {cos.shape}" + + BLOCK_SEQ = 16 BLOCK_HEAD = 4 - BLOCK_SEQ = 32 - grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ)) if head_dim >= 128: num_warps = 8 else: num_warps = 4 + + grid = (triton.cdiv(head_num_q, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ)) _rotary_kernel[grid]( - q, cos, sin, - q.stride(0), q.stride(1), q.stride(2), - cos.stride(0), cos.stride(1), - sin.stride(0), sin.stride(1), - total_len, head_num, + q, + k, + cos, + sin, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + cos.stride(0), + cos.stride(1), + sin.stride(0), + sin.stride(1), + total_len, + head_num_q, + head_num_k, BLOCK_HEAD=BLOCK_HEAD, BLOCK_SEQ=BLOCK_SEQ, BLOCK_DMODEL=head_dim, diff --git a/lightllm/models/internlm/layer_infer/transformer_layer_infer.py b/lightllm/models/internlm/layer_infer/transformer_layer_infer.py old mode 100644 new mode 100755 index 2500c2721..10dd0c4d9 --- a/lightllm/models/internlm/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/internlm/layer_infer/transformer_layer_infer.py @@ -10,21 +10,39 @@ class InternlmTransformerLayerInfer(LlamaTransformerLayerInfer): - def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): super().__init__(layer_num, tp_rank, world_size, network_config, mode) return - - def _get_qkv(self, input, cache_k, cache_v, infer_state:LlamaInferStateInfo, layer_weight:InternlmTransformerLayerWeight)->torch.Tensor: - q = torch.addmm(layer_weight.q_bias_, input.view(-1, self.embed_dim_), layer_weight.q_weight_, beta=1.0, alpha=1.0) - rotary_emb_fwd(q.view(-1, self.tp_q_head_num_, self.head_dim_), infer_state.position_cos, infer_state.position_sin) - torch.addmm(layer_weight.k_bias_, input.view(-1, self.embed_dim_), layer_weight.k_weight_, beta=1.0, - alpha=1.0, out=cache_k.view(-1, self.tp_k_head_num_ * self.head_dim_)) - rotary_emb_fwd(cache_k, infer_state.position_cos, infer_state.position_sin) - torch.addmm(layer_weight.v_bias_, input.view(-1, self.embed_dim_), layer_weight.v_weight_, beta=1.0, - alpha=1.0, out=cache_v.view(-1, self.tp_v_head_num_ * self.head_dim_)) - return q, cache_k, cache_v - def _get_o(self, input, infer_state:LlamaInferStateInfo, layer_weight:InternlmTransformerLayerWeight)->torch.Tensor: - o_tensor = torch.addmm(layer_weight.o_bias_, input.view(-1, self.tp_o_head_num_ * self.head_dim_), layer_weight.o_weight_, beta=1.0 / self.world_size_) - return o_tensor \ No newline at end of file + def _get_qkv( + self, input, cache_kv, infer_state: LlamaInferStateInfo, layer_weight: InternlmTransformerLayerWeight + ) -> torch.Tensor: + q = torch.addmm( + layer_weight.q_bias_, input.view(-1, self.embed_dim_), layer_weight.q_weight_, 
beta=1.0, alpha=1.0 + ) + torch.addmm( + layer_weight.kv_bias_, + input.view(-1, self.embed_dim_), + layer_weight.kv_weight_, + beta=1.0, + alpha=1.0, + out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_), + ) + rotary_emb_fwd( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + cache_kv[:, 0 : self.tp_k_head_num_, :], + infer_state.position_cos, + infer_state.position_sin, + ) + return q, cache_kv + + def _get_o( + self, input, infer_state: LlamaInferStateInfo, layer_weight: InternlmTransformerLayerWeight + ) -> torch.Tensor: + o_tensor = torch.addmm( + layer_weight.o_bias_, + input.view(-1, self.tp_o_head_num_ * self.head_dim_), + layer_weight.o_weight_, + beta=1.0 / self.world_size_, + ) + return o_tensor diff --git a/lightllm/models/internlm/layer_weights/transformer_layer_weight.py b/lightllm/models/internlm/layer_weights/transformer_layer_weight.py old mode 100644 new mode 100755 index 5169c99f1..ccb09ef11 --- a/lightllm/models/internlm/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/internlm/layer_weights/transformer_layer_weight.py @@ -5,40 +5,39 @@ from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight + class InternlmTransformerLayerWeight(LlamaTransformerLayerWeight): def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[]): super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode) return - + def verify_load(self): errors = "weights load not ok" - + # handle internlm 20b, which has no bias, so set q k v o bias to zero if not self.network_config_.get("bias", True): - for layer_type in ("q", "k", "v", "o"): + for layer_type in ("q", "kv", "o"): attr_name = f"{layer_type}_bias_" if hasattr(self, attr_name): continue setattr(self, attr_name, self._cuda(torch.zeros(1))) - weights = [self.att_norm_weight_, - self.q_weight_, - self.k_weight_, - self.v_weight_, - self.o_weight_, - self.q_bias_, - self.k_bias_, - self.v_bias_, - self.o_bias_, - self.ffn_norm_weight_, - self.up_proj, - self.gate_proj, - self.down_proj - ] + weights = [ + self.att_norm_weight_, + self.q_weight_, + self.kv_weight_, + self.o_weight_, + self.q_bias_, + self.kv_bias_, + self.o_bias_, + self.ffn_norm_weight_, + self.gate_up_proj, + self.down_proj, + ] for i in range(len(weights)): assert weights[i] is not None, "index:" + str(i) + " " + errors - return - + return + def _load_qkvo_weights(self, weights): # input layernorm params if f"model.layers.{self.layer_num_}.input_layernorm.weight" in weights: @@ -46,38 +45,49 @@ def _load_qkvo_weights(self, weights): n_embed = self.network_config_["hidden_size"] q_split_n_embed = n_embed // self.world_size_ - kv_split_n_embed = n_embed // self.network_config_["num_attention_heads"] * self.network_config_["num_key_value_heads"] // self.world_size_ + kv_split_n_embed = ( + n_embed + // self.network_config_["num_attention_heads"] + * self.network_config_["num_key_value_heads"] + // self.world_size_ + ) # q k v weights for llama if f"model.layers.{self.layer_num_}.self_attn.q_proj.weight" in weights: self.q_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.weight"] - self.q_weight_ = self.q_weight_[q_split_n_embed * self.tp_rank_: q_split_n_embed * (self.tp_rank_ + 1), :] + self.q_weight_ = self.q_weight_[q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1), :] self.q_weight_ = self._cuda(self.q_weight_.transpose(0, 1)) if f"model.layers.{self.layer_num_}.self_attn.q_proj.bias" in weights: - 
self.q_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.bias"][q_split_n_embed * - self.tp_rank_: q_split_n_embed * (self.tp_rank_ + 1)] + self.q_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.bias"][ + q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1) + ] self.q_bias_ = self._cuda(self.q_bias_) if f"model.layers.{self.layer_num_}.self_attn.k_proj.weight" in weights: - self.k_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.weight"] - self.k_weight_ = self.k_weight_[kv_split_n_embed * self.tp_rank_: kv_split_n_embed * (self.tp_rank_ + 1), :] - self.k_weight_ = self._cuda(self.k_weight_.transpose(0, 1)) + k_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.weight"] + k_weight_ = k_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] + self.k_weight_ = k_weight_.transpose(0, 1) if f"model.layers.{self.layer_num_}.self_attn.k_proj.bias" in weights: - self.k_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.bias"][kv_split_n_embed * - self.tp_rank_: kv_split_n_embed * (self.tp_rank_ + 1)] - self.k_bias_ = self._cuda(self.k_bias_) + self.k_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.bias"][ + kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1) + ] if f"model.layers.{self.layer_num_}.self_attn.v_proj.weight" in weights: - self.v_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.weight"] - self.v_weight_ = self.v_weight_[kv_split_n_embed * self.tp_rank_: kv_split_n_embed * (self.tp_rank_ + 1), :] - self.v_weight_ = self._cuda(self.v_weight_.transpose(0, 1)) + v_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.weight"] + v_weight_ = v_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] + self.v_weight_ = v_weight_.transpose(0, 1) if f"model.layers.{self.layer_num_}.self_attn.v_proj.bias" in weights: - self.v_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.bias"][kv_split_n_embed * - self.tp_rank_: kv_split_n_embed * (self.tp_rank_ + 1)] - self.v_bias_ = self._cuda(self.v_bias_) + self.v_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.bias"][ + kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1) + ] + + self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1) + + self._try_cat_to(["k_bias_", "v_bias_"], "kv_bias_", cat_dim=0) + # attention output dense params if f"model.layers.{self.layer_num_}.self_attn.o_proj.weight" in weights: self.o_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.weight"] - self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_: q_split_n_embed * (self.tp_rank_ + 1)] + self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1)] self.o_weight_ = self._cuda(self.o_weight_.transpose(0, 1)) if f"model.layers.{self.layer_num_}.self_attn.o_proj.bias" in weights: self.o_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.bias"] - self.o_bias_ = self._cuda(self.o_bias_) + self.o_bias_ = self._cuda(self.o_bias_) return diff --git a/lightllm/models/internlm2/layer_weights/transformer_layer_weight.py b/lightllm/models/internlm2/layer_weights/transformer_layer_weight.py old mode 100644 new mode 100755 index 6b17ccb16..ef98a1aec --- a/lightllm/models/internlm2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/internlm2/layer_weights/transformer_layer_weight.py 
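The KV-merge pattern above hinges on a _try_cat_to helper shared by the layer-weight classes: each per-rank K and V slice stays on the CPU until both halves of a pair have been loaded, and only then are they fused into a single kv_weight_ (cat_dim=1) or kv_bias_ (cat_dim=0) so that _get_qkv can fill the packed cache_kv with one matmul. A minimal sketch of the behavior this code appears to assume, with illustrative names only (the in-tree helper presumably also moves the merged tensor to the GPU, which the sketch omits):

import torch

def try_cat_to(obj, source_names, dest_name, cat_dim, handle_func=None):
    # Merge the listed attributes into one fused tensor once all of them have been loaded.
    if not all(hasattr(obj, name) for name in source_names):
        return  # some checkpoint shard has not provided its slice yet
    merged = torch.cat([getattr(obj, name) for name in source_names], dim=cat_dim)
    if handle_func is not None:
        merged = handle_func(merged)  # e.g. the wquant variant passes quantize_weight here
    setattr(obj, dest_name, merged)
    for name in source_names:
        delattr(obj, name)  # drop the separate halves; only the fused tensor is kept

With k_weight_ and v_weight_ stored as (hidden_size, kv_heads_per_rank * head_dim) after the transpose, cat_dim=1 produces the kv_weight_ consumed by the packed cache_kv view of width (tp_k_head_num_ + tp_v_head_num_) * head_dim_, and the 1-D bias halves are concatenated with cat_dim=0.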
@@ -5,40 +5,39 @@ from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight + class Internlm2TransformerLayerWeight(LlamaTransformerLayerWeight): def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[]): super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode) return - + def verify_load(self): errors = "weights load not ok" - + # handle internlm 20b, which has no bias, so set q k v o bias to zero if not self.network_config_.get("bias", True): - for layer_type in ("q", "k", "v", "o"): + for layer_type in ("q", "kv", "o"): attr_name = f"{layer_type}_bias_" if hasattr(self, attr_name): continue setattr(self, attr_name, self._cuda(torch.zeros(1))) - weights = [self.att_norm_weight_, - self.q_weight_, - self.k_weight_, - self.v_weight_, - self.o_weight_, - self.q_bias_, - self.k_bias_, - self.v_bias_, - self.o_bias_, - self.ffn_norm_weight_, - self.up_proj, - self.gate_proj, - self.down_proj - ] + weights = [ + self.att_norm_weight_, + self.q_weight_, + self.kv_weight_, + self.o_weight_, + self.q_bias_, + self.kv_bias_, + self.o_bias_, + self.ffn_norm_weight_, + self.gate_up_proj, + self.down_proj, + ] for i in range(len(weights)): assert weights[i] is not None, "index:" + str(i) + " " + errors - return - + return + def _load_qkvo_weights(self, weights): # input layernorm params if f"model.layers.{self.layer_num_}.attention_norm.weight" in weights: @@ -46,49 +45,68 @@ def _load_qkvo_weights(self, weights): n_embed = self.network_config_["hidden_size"] q_split_n_embed = n_embed // self.world_size_ - kv_split_n_embed = n_embed // self.network_config_["num_attention_heads"] * self.network_config_["num_key_value_heads"] // self.world_size_ + kv_split_n_embed = ( + n_embed + // self.network_config_["num_attention_heads"] + * self.network_config_["num_key_value_heads"] + // self.world_size_ + ) head_dim = n_embed // self.network_config_["num_attention_heads"] # q k v weights for llama if f"model.layers.{self.layer_num_}.attention.wqkv.weight" in weights: qkv_weight_ = weights[f"model.layers.{self.layer_num_}.attention.wqkv.weight"] q_groups = self.network_config_["num_attention_heads"] // self.network_config_["num_key_value_heads"] qkv_weight_ = qkv_weight_.reshape(self.network_config_["num_key_value_heads"], q_groups + 2, head_dim, -1) - q_weight_ = qkv_weight_[:, :q_groups, :, :].reshape(-1, qkv_weight_.shape[-1]) - self.q_weight_ = self._cuda(q_weight_[q_split_n_embed * self.tp_rank_: q_split_n_embed * (self.tp_rank_ + 1):].transpose(0, 1)) + q_weight_ = qkv_weight_[:, :q_groups, :, :].reshape(-1, qkv_weight_.shape[-1]) + self.q_weight_ = self._cuda( + q_weight_[q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1) :].transpose(0, 1) + ) + k_weight_ = qkv_weight_[:, -2, :, :].reshape(-1, qkv_weight_.shape[-1]) - self.k_weight_ = self._cuda(k_weight_[kv_split_n_embed * self.tp_rank_: kv_split_n_embed * (self.tp_rank_ + 1):].transpose(0, 1)) + self.k_weight_ = k_weight_[ + kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1) : + ].transpose(0, 1) v_weight_ = qkv_weight_[:, -1, :, :].reshape(-1, qkv_weight_.shape[-1]) - self.v_weight_ = self._cuda(v_weight_[kv_split_n_embed * self.tp_rank_: kv_split_n_embed * (self.tp_rank_ + 1):].transpose(0, 1)) + self.v_weight_ = v_weight_[ + kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1) : + ].transpose(0, 1) + + self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1) + # attention output 
dense params if f"model.layers.{self.layer_num_}.attention.wo.weight" in weights: self.o_weight_ = weights[f"model.layers.{self.layer_num_}.attention.wo.weight"] - self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_: q_split_n_embed * (self.tp_rank_ + 1)] + self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1)] self.o_weight_ = self._cuda(self.o_weight_.transpose(0, 1)) if f"model.layers.{self.layer_num_}.attention.wo.bias" in weights: self.o_bias_ = weights[f"model.layers.{self.layer_num_}.attention.wo.bias"] - self.o_bias_ = self._cuda(self.o_bias_) + self.o_bias_ = self._cuda(self.o_bias_) return def _load_ffn_weights(self, weights): if f"model.layers.{self.layer_num_}.ffn_norm.weight" in weights: self.ffn_norm_weight_ = self._cuda(weights[f"model.layers.{self.layer_num_}.ffn_norm.weight"]) - inter_size = self.network_config_['intermediate_size'] + inter_size = self.network_config_["intermediate_size"] split_inter_size = inter_size // self.world_size_ if f"model.layers.{self.layer_num_}.feed_forward.w3.weight" in weights: - self.up_proj = weights[f"model.layers.{self.layer_num_}.feed_forward.w3.weight"][split_inter_size * - self.tp_rank_: split_inter_size * (self.tp_rank_ + 1), :] - self.up_proj = self._cuda(self.up_proj.transpose(0, 1)) + up_proj = weights[f"model.layers.{self.layer_num_}.feed_forward.w3.weight"][ + split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : + ] + self.up_proj = up_proj.transpose(0, 1) if f"model.layers.{self.layer_num_}.feed_forward.w1.weight" in weights: - self.gate_proj = weights[f"model.layers.{self.layer_num_}.feed_forward.w1.weight"][split_inter_size * - self.tp_rank_: split_inter_size * (self.tp_rank_ + 1), :] - self.gate_proj = self._cuda(self.gate_proj.transpose(0, 1)) + gate_proj = weights[f"model.layers.{self.layer_num_}.feed_forward.w1.weight"][ + split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : + ] + self.gate_proj = gate_proj.transpose(0, 1) + + self._try_cat_to(["gate_proj", "up_proj"], "gate_up_proj", cat_dim=1) if f"model.layers.{self.layer_num_}.feed_forward.w2.weight" in weights: - self.down_proj = weights[f"model.layers.{self.layer_num_}.feed_forward.w2.weight"][:, - split_inter_size * self.tp_rank_: split_inter_size * (self.tp_rank_ + 1)] + self.down_proj = weights[f"model.layers.{self.layer_num_}.feed_forward.w2.weight"][ + :, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1) + ] self.down_proj = self._cuda(self.down_proj.transpose(0, 1)) return - diff --git a/lightllm/models/internlm_wquant/layer_infer/transformer_layer_infer.py b/lightllm/models/internlm_wquant/layer_infer/transformer_layer_infer.py old mode 100644 new mode 100755 index 940681249..0bac89ca8 --- a/lightllm/models/internlm_wquant/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/internlm_wquant/layer_infer/transformer_layer_infer.py @@ -5,35 +5,48 @@ import triton from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd -from lightllm.models.internlm_wquant.layer_weights.transformer_layer_weight import InternlmTransformerLayerWeightQuantized +from lightllm.models.internlm_wquant.layer_weights.transformer_layer_weight import ( + InternlmTransformerLayerWeightQuantized, +) from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.llama_wquant.layer_infer.transformer_layer_infer import LlamaTransformerLayerInferWquant -class 
InternlmTransformerLayerInferWquant(LlamaTransformerLayerInferWquant): +class InternlmTransformerLayerInferWquant(LlamaTransformerLayerInferWquant): def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): super().__init__(layer_num, tp_rank, world_size, network_config, mode) return - def _get_qkv(self, input, cache_k, cache_v, infer_state: LlamaInferStateInfo, layer_weight: InternlmTransformerLayerWeightQuantized): - qkv_output = self._wquant_matmul_for_qkv(input.view(-1, self.embed_dim_), - quant_weight_params=layer_weight.qkv_weight_, - infer_state=infer_state) - - tp_k_head_dim = self.tp_k_head_num_ * self.head_dim_ - q = qkv_output[:, : -2 * tp_k_head_dim].add_(layer_weight.q_bias_) - k = qkv_output[:, -2 * tp_k_head_dim: -tp_k_head_dim].add_(layer_weight.k_bias_) - v = qkv_output[:, -tp_k_head_dim :].add_(layer_weight.v_bias_) - - rotary_emb_fwd(q.view(-1, self.tp_q_head_num_, self.head_dim_), infer_state.position_cos, infer_state.position_sin) - cache_k_ = k.view(-1, self.tp_k_head_num_, self.head_dim_) - rotary_emb_fwd(cache_k_, infer_state.position_cos, infer_state.position_sin) - cache_v_ = v.view(-1, self.tp_v_head_num_, self.head_dim_) - return q, cache_k_, cache_v_ + def _get_qkv( + self, input, cache_kv, infer_state: LlamaInferStateInfo, layer_weight: InternlmTransformerLayerWeightQuantized + ): + q = self._wquant_matmul_for_qkv( + input.view(-1, self.embed_dim_), + quant_weight_params=layer_weight.q_weight_, + infer_state=infer_state, + bias=layer_weight.q_bias_, + ) + cache_kv = self._wquant_matmul_for_qkv( + input.view(-1, self.embed_dim_), + quant_weight_params=layer_weight.kv_weight_, + infer_state=infer_state, + bias=layer_weight.kv_bias_, + ).view(-1, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_) + rotary_emb_fwd( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + cache_kv[:, 0 : self.tp_k_head_num_, :], + infer_state.position_cos, + infer_state.position_sin, + ) + return q, cache_kv - def _get_o(self, input, infer_state: LlamaInferStateInfo, layer_weight: InternlmTransformerLayerWeightQuantized) -> torch.Tensor: - o_tensor = self._wquant_matmul_for_o(input, - quant_weight_params=layer_weight.o_weight_, - infer_state=infer_state, - bias=layer_weight.o_bias_ / self.world_size_) - return o_tensor \ No newline at end of file + def _get_o( + self, input, infer_state: LlamaInferStateInfo, layer_weight: InternlmTransformerLayerWeightQuantized + ) -> torch.Tensor: + o_tensor = self._wquant_matmul_for_o( + input, + quant_weight_params=layer_weight.o_weight_, + infer_state=infer_state, + bias=layer_weight.o_bias_ / self.world_size_, + ) + return o_tensor diff --git a/lightllm/models/internlm_wquant/layer_weights/transformer_layer_weight.py b/lightllm/models/internlm_wquant/layer_weights/transformer_layer_weight.py old mode 100644 new mode 100755 index 384aab2a8..26a50f6b8 --- a/lightllm/models/internlm_wquant/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/internlm_wquant/layer_weights/transformer_layer_weight.py @@ -5,6 +5,7 @@ from lightllm.models.llama_wquant.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeightQuantized + class InternlmTransformerLayerWeightQuantized(LlamaTransformerLayerWeightQuantized): def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[]): super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode) @@ -12,7 +13,7 @@ def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mo def verify_load(self): errors = 
"weights load not ok" - + # handle internlm 20b, which has no bias, so set q k v o bias to zero if not self.network_config_.get("bias", True): for layer_type in ("q", "k", "v", "o"): @@ -21,21 +22,22 @@ def verify_load(self): continue setattr(self, attr_name, self._cuda(torch.zeros(1))) - weights = [self.att_norm_weight_, - self.qkv_weight_, - self.o_weight_, - self.q_bias_, - self.k_bias_, - self.v_bias_, - self.o_bias_, - self.ffn_norm_weight_, - self.gate_up_proj, - self.down_proj - ] + weights = [ + self.att_norm_weight_, + self.q_weight_, + self.kv_weight_, + self.o_weight_, + self.q_bias_, + self.kv_bias_, + self.o_bias_, + self.ffn_norm_weight_, + self.gate_up_proj, + self.down_proj, + ] for i in range(len(weights)): assert weights[i] is not None, "index:" + str(i) + " " + errors - return - + return + def _load_qkvo_weights(self, weights): # input layernorm params if f"model.layers.{self.layer_num_}.input_layernorm.weight" in weights: @@ -44,59 +46,56 @@ def _load_qkvo_weights(self, weights): n_embed = self.network_config_["hidden_size"] q_split_n_embed = n_embed // self.world_size_ - kv_split_n_embed = n_embed // self.network_config_["num_attention_heads"] * self.network_config_["num_key_value_heads"] // self.world_size_ - - if getattr(self, "qkv_weight_", None) is None: - self.qkv_weight_ = torch.empty(n_embed, q_split_n_embed + 2 * kv_split_n_embed, dtype=self.data_type_, device='cpu') - self.qkv_step_ = 0 + kv_split_n_embed = ( + n_embed + // self.network_config_["num_attention_heads"] + * self.network_config_["num_key_value_heads"] + // self.world_size_ + ) # q k v weights for llama if f"model.layers.{self.layer_num_}.self_attn.q_proj.weight" in weights: q_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.weight"] - q_weight_ = q_weight_[q_split_n_embed * self.tp_rank_: q_split_n_embed * (self.tp_rank_ + 1), :] + q_weight_ = q_weight_[q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1), :] q_weight_ = q_weight_.transpose(0, 1).to(self.data_type_) - self.qkv_weight_[:, :q_split_n_embed] = q_weight_ - self.qkv_step_ += 1 + self.q_weight_ = self.quantize_weight(q_weight_) if f"model.layers.{self.layer_num_}.self_attn.q_proj.bias" in weights: - self.q_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.bias"][q_split_n_embed * - self.tp_rank_: q_split_n_embed * (self.tp_rank_ + 1)] - self.q_bias_ = self._cuda(self.q_bias_) + self.q_bias_ = self._cuda( + weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.bias"][ + q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1) + ] + ) if f"model.layers.{self.layer_num_}.self_attn.k_proj.weight" in weights: k_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.weight"] - k_weight_ = k_weight_[kv_split_n_embed * self.tp_rank_: kv_split_n_embed * (self.tp_rank_ + 1), :] - k_weight_ = k_weight_.transpose(0, 1).to(self.data_type_) - self.qkv_weight_[:, q_split_n_embed: (q_split_n_embed + kv_split_n_embed)] = k_weight_ - self.qkv_step_ += 1 + k_weight_ = k_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] + self.k_weight_ = k_weight_.transpose(0, 1).to(self.data_type_) if f"model.layers.{self.layer_num_}.self_attn.k_proj.bias" in weights: - self.k_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.bias"][kv_split_n_embed * - self.tp_rank_: kv_split_n_embed * (self.tp_rank_ + 1)] - self.k_bias_ = self._cuda(self.k_bias_) + self.k_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.bias"][ + 
kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1) + ] if f"model.layers.{self.layer_num_}.self_attn.v_proj.weight" in weights: v_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.weight"] - v_weight_ = v_weight_[kv_split_n_embed * self.tp_rank_: kv_split_n_embed * (self.tp_rank_ + 1), :] - v_weight_ = v_weight_.transpose(0, 1).to(self.data_type_) - self.qkv_weight_[:, (q_split_n_embed + kv_split_n_embed):(q_split_n_embed + 2 * kv_split_n_embed)] = v_weight_ - self.qkv_step_ += 1 + v_weight_ = v_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] + self.v_weight_ = v_weight_.transpose(0, 1).to(self.data_type_) if f"model.layers.{self.layer_num_}.self_attn.v_proj.bias" in weights: - self.v_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.bias"][kv_split_n_embed * - self.tp_rank_: kv_split_n_embed * (self.tp_rank_ + 1)] - self.v_bias_ = self._cuda(self.v_bias_) + self.v_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.bias"][ + kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1) + ] - if self.qkv_step_ == 3: - self.qkv_step_ = 0 - self.qkv_weight_ = self.quantize_weight(self.qkv_weight_) + self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1, handle_func=self.quantize_weight) + self._try_cat_to(["k_bias_", "v_bias_"], "kv_bias_", cat_dim=0) # attention output dense params if f"model.layers.{self.layer_num_}.self_attn.o_proj.weight" in weights: self.o_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.weight"] - self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_: q_split_n_embed * (self.tp_rank_ + 1)] + self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1)] self.o_weight_ = self.quantize_weight(self.o_weight_.transpose(0, 1)) if f"model.layers.{self.layer_num_}.self_attn.o_proj.bias" in weights: self.o_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.bias"] - self.o_bias_ = self._cuda(self.o_bias_) - return \ No newline at end of file + self.o_bias_ = self._cuda(self.o_bias_) + return diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py old mode 100644 new mode 100755 index 93304a0a6..3875433ba --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -13,16 +13,20 @@ from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2, token_att_fwd2_int8v from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd +from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.llama.splitfuse_infer_struct import SplitFuseInferStateInfo from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv, destindex_copy_quantize_kv from lightllm.common.basemodel import TransformerLayerInferTpl -from lightllm.models.llama.triton_kernel.splitfuse_context_flashattention_nopad import splitfuse_context_attention_fwd, splitfuse_context_attention_fwd_int8kv +from lightllm.models.llama.triton_kernel.splitfuse_context_flashattention_nopad import ( + splitfuse_context_attention_fwd, + splitfuse_context_attention_fwd_int8kv, +) + class 
LlamaTransformerLayerInfer(TransformerLayerInferTpl): - """ - """ + """ """ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): super().__init__(layer_num, tp_rank, world_size, network_config, mode) @@ -35,17 +39,17 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): self.embed_dim_ = network_config["hidden_size"] self._bind_func() return - + def _bind_func(self): self._bind_norm() self._bind_attention() return - + def _bind_norm(self): self._att_norm = partial(LlamaTransformerLayerInfer._att_norm, self) self._ffn_norm = partial(LlamaTransformerLayerInfer._ffn_norm, self) return - + def _bind_attention(self): self._context_attention_kernel = partial(LlamaTransformerLayerInfer._context_attention_kernel, self) if "ppl_int8kv" in self.mode: @@ -55,221 +59,283 @@ def _bind_attention(self): self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_ppl_fp16, self) self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) elif "ppl_fp16_flashdecoding" in self.mode: - self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_ppl_fp16_flashdecoding, self) + self._token_attention_kernel = partial( + LlamaTransformerLayerInfer._token_decode_attention_ppl_fp16_flashdecoding, self + ) self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) elif "triton_int8kv" in self.mode: self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_int8kv, self) self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_int8kv, self) elif "triton_flashdecoding" in self.mode: - self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_flashdecoding, self) + self._token_attention_kernel = partial( + LlamaTransformerLayerInfer._token_decode_attention_flashdecoding, self + ) self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) elif "triton_gqa_attention" in self.mode: self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_gqa_attention_normal, self) self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) elif "triton_gqa_flashdecoding" in self.mode: - self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_gqa_flashdecoding, self) + self._token_attention_kernel = partial( + LlamaTransformerLayerInfer._token_decode_attention_gqa_flashdecoding, self + ) self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) else: self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_normal, self) self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - + # bind splitfuse attention if "triton_int8kv" in self.mode: - self._splitfuse_attention_kernel = partial(LlamaTransformerLayerInfer._splitfuse_attention_kernel_int8kv, self) + self._splitfuse_attention_kernel = partial( + LlamaTransformerLayerInfer._splitfuse_attention_kernel_int8kv, self + ) else: self._splitfuse_attention_kernel = partial(LlamaTransformerLayerInfer._splitfuse_attention_kernel, self) return - def _att_norm(self, input, infer_state:LlamaInferStateInfo, layer_weight:LlamaTransformerLayerWeight)->torch.Tensor: + def _att_norm( + self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight + 
) -> torch.Tensor: return rmsnorm_forward(input, weight=layer_weight.att_norm_weight_, eps=self.eps_) - - def _ffn_norm(self, input, infer_state:LlamaInferStateInfo, layer_weight:LlamaTransformerLayerWeight)->torch.Tensor: + + def _ffn_norm( + self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight + ) -> torch.Tensor: return rmsnorm_forward(input, weight=layer_weight.ffn_norm_weight_, eps=self.eps_) - def _get_qkv(self, input, cache_k, cache_v, infer_state:LlamaInferStateInfo, layer_weight:LlamaTransformerLayerWeight)->torch.Tensor: + def _get_qkv( + self, input, cache_kv, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight + ) -> torch.Tensor: q = torch.mm(input.view(-1, self.embed_dim_), layer_weight.q_weight_) - rotary_emb_fwd(q.view(-1, self.tp_q_head_num_, self.head_dim_), infer_state.position_cos, infer_state.position_sin) - torch.mm(input.view(-1, self.embed_dim_), layer_weight.k_weight_, - out=cache_k.view(-1, self.tp_k_head_num_ * self.head_dim_)) - rotary_emb_fwd(cache_k, infer_state.position_cos, infer_state.position_sin) - torch.mm(input.view(-1, self.embed_dim_), layer_weight.v_weight_, - out=cache_v.view(-1, self.tp_v_head_num_ * self.head_dim_)) - return q, cache_k, cache_v - - def _context_attention_kernel(self, q, k, v, infer_state:LlamaInferStateInfo, layer_weight, out=None)->torch.Tensor: + torch.mm( + input.view(-1, self.embed_dim_), + layer_weight.kv_weight_, + out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_), + ) + rotary_emb_fwd( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + cache_kv[:, 0 : self.tp_k_head_num_, :], + infer_state.position_cos, + infer_state.position_sin, + ) + return q, cache_kv + + def _context_attention_kernel( + self, q, kv, infer_state: LlamaInferStateInfo, layer_weight, out=None + ) -> torch.Tensor: o_tensor = torch.empty_like(q) if out is None else out - context_attention_fwd(q.view(-1, self.tp_q_head_num_, self.head_dim_), - k.view(-1, self.tp_k_head_num_, self.head_dim_), - v.view(-1, self.tp_v_head_num_, self.head_dim_), - o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.max_len_in_batch) + context_attention_fwd( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + kv[:, 0 : self.tp_k_head_num_, :], + kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], + o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.max_len_in_batch, + ) return o_tensor - - def _splitfuse_attention_kernel(self, q, infer_state: SplitFuseInferStateInfo, layer_weight, out=None) -> torch.Tensor: + + def _splitfuse_attention_kernel( + self, q, infer_state: SplitFuseInferStateInfo, layer_weight, out=None + ) -> torch.Tensor: o_tensor = torch.empty_like(q) if out is None else out infer_state.start_event.record(torch.cuda.default_stream()) if infer_state.decode_req_num > 0: - self._token_attention_kernel(q[0 : infer_state.decode_req_num, :], - infer_state.inner_decode_infer_status, - layer_weight, - out=o_tensor[0 : infer_state.decode_req_num, :]) + self._token_attention_kernel( + q[0 : infer_state.decode_req_num, :], + infer_state.inner_decode_infer_status, + layer_weight, + out=o_tensor[0 : infer_state.decode_req_num, :], + ) calcu_shape1 = (-1, self.tp_q_head_num_, self.head_dim_) if infer_state.prefill_req_num > 0: infer_state.parrall_stream.wait_event(infer_state.start_event) # 
infer_state.start_event.wait(infer_state.parrall_stream) with torch.cuda.stream(infer_state.parrall_stream): # assert torch.cuda.current_stream().cuda_stream == infer_state.parrall_stream.cuda_stream - splitfuse_context_attention_fwd(q[infer_state.decode_req_num:, :].view(calcu_shape1), - infer_state.mem_manager.key_buffer[self.layer_num_], - infer_state.mem_manager.value_buffer[self.layer_num_], - o_tensor[infer_state.decode_req_num:, :].view(calcu_shape1), - infer_state.prefill_req_num, - infer_state.req_manager.req_to_token_indexs, - infer_state.prefill_b_req_idx, - infer_state.prefill_b_split_start_loc, - infer_state.prefill_b_split_seq_len, - infer_state.prefill_b_seq_len, - infer_state.prefill_max_split_seq_len_in_batch) + splitfuse_context_attention_fwd( + q[infer_state.decode_req_num :, :].view(calcu_shape1), + infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], + infer_state.mem_manager.kv_buffer[self.layer_num_][ + :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : + ], + o_tensor[infer_state.decode_req_num :, :].view(calcu_shape1), + infer_state.prefill_req_num, + infer_state.req_manager.req_to_token_indexs, + infer_state.prefill_b_req_idx, + infer_state.prefill_b_split_start_loc, + infer_state.prefill_b_split_seq_len, + infer_state.prefill_b_seq_len, + infer_state.prefill_max_split_seq_len_in_batch, + ) infer_state.end_event.record(infer_state.parrall_stream) torch.cuda.default_stream().wait_event(infer_state.end_event) # infer_state.event.wait(torch.cuda.default_stream()) # assert torch.cuda.current_stream().cuda_stream == torch.cuda.default_stream().cuda_stream - # assert torch.cuda.default_stream().cuda_stream != infer_state.parrall_stream.cuda_stream + # assert torch.cuda.default_stream().cuda_stream != infer_state.parrall_stream.cuda_stream return o_tensor - def _splitfuse_attention_kernel_int8kv(self, q, infer_state: SplitFuseInferStateInfo, layer_weight, out=None) -> torch.Tensor: + def _splitfuse_attention_kernel_int8kv( + self, q, infer_state: SplitFuseInferStateInfo, layer_weight, out=None + ) -> torch.Tensor: o_tensor = torch.empty_like(q) if out is None else out infer_state.start_event.record(torch.cuda.default_stream()) if infer_state.decode_req_num > 0: - self._token_attention_kernel(q[0 : infer_state.decode_req_num, :], - infer_state.inner_decode_infer_status, - layer_weight, - out=o_tensor[0 : infer_state.decode_req_num, :]) + self._token_attention_kernel( + q[0 : infer_state.decode_req_num, :], + infer_state.inner_decode_infer_status, + layer_weight, + out=o_tensor[0 : infer_state.decode_req_num, :], + ) calcu_shape1 = (-1, self.tp_q_head_num_, self.head_dim_) if infer_state.prefill_req_num > 0: infer_state.parrall_stream.wait_event(infer_state.start_event) with torch.cuda.stream(infer_state.parrall_stream): - splitfuse_context_attention_fwd_int8kv(q[infer_state.decode_req_num:, :].view(calcu_shape1), - infer_state.mem_manager.key_buffer[self.layer_num_], - infer_state.mem_manager.key_scale_buffer[self.layer_num_], - infer_state.mem_manager.value_buffer[self.layer_num_], - infer_state.mem_manager.value_scale_buffer[self.layer_num_], - o_tensor[infer_state.decode_req_num:, :].view(calcu_shape1), - infer_state.prefill_req_num, - infer_state.req_manager.req_to_token_indexs, - infer_state.prefill_b_req_idx, - infer_state.prefill_b_split_start_loc, - infer_state.prefill_b_split_seq_len, - infer_state.prefill_b_seq_len, - infer_state.prefill_max_split_seq_len_in_batch) + splitfuse_context_attention_fwd_int8kv( + 
q[infer_state.decode_req_num :, :].view(calcu_shape1), + infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], + infer_state.mem_manager.scale_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], + infer_state.mem_manager.kv_buffer[self.layer_num_][ + :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : + ], + infer_state.mem_manager.scale_buffer[self.layer_num_][ + :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : + ], + o_tensor[infer_state.decode_req_num :, :].view(calcu_shape1), + infer_state.prefill_req_num, + infer_state.req_manager.req_to_token_indexs, + infer_state.prefill_b_req_idx, + infer_state.prefill_b_split_start_loc, + infer_state.prefill_b_split_seq_len, + infer_state.prefill_b_seq_len, + infer_state.prefill_max_split_seq_len_in_batch, + ) infer_state.end_event.record(infer_state.parrall_stream) torch.cuda.default_stream().wait_event(infer_state.end_event) return o_tensor - - def _get_o(self, input, infer_state:LlamaInferStateInfo, layer_weight:LlamaTransformerLayerWeight)->torch.Tensor: + + def _get_o( + self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight + ) -> torch.Tensor: o_tensor = torch.mm(input.view(-1, self.tp_o_head_num_ * self.head_dim_), layer_weight.o_weight_) return o_tensor - def _ffn(self, input, infer_state:LlamaInferStateInfo, layer_weight:LlamaTransformerLayerWeight)->torch.Tensor: - gate_out = torch.mm(input.view(-1, self.embed_dim_), layer_weight.gate_proj) - torch.nn.functional.silu(gate_out, inplace=True) - up_out = torch.mm(input.view(-1, self.embed_dim_), layer_weight.up_proj) + def _ffn(self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight) -> torch.Tensor: + up_gate_out = torch.mm(input.view(-1, self.embed_dim_), layer_weight.gate_up_proj) + ffn1_out = silu_and_mul_fwd(up_gate_out) input = None - ffn1_out = gate_out * up_out - gate_out, up_out = None, None + up_gate_out = None ffn2_out = torch.mm(ffn1_out, layer_weight.down_proj) ffn1_out = None return ffn2_out - - def _copy_kv_to_mem_cache_normal(self, key_buffer, value_buffer, mem_index, mem_manager): - destindex_copy_kv(key_buffer, mem_index, mem_manager.key_buffer[self.layer_num_]) - destindex_copy_kv(value_buffer, mem_index, mem_manager.value_buffer[self.layer_num_]) + + # # keep code + # def _ffn(self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight)->torch.Tensor: + # gate_up_out = torch.mm(input.view(-1, self.embed_dim_), layer_weight.gate_up_proj) + # size = gate_up_out.shape[1] + # gate_out, up_out = gate_up_out[:, 0: size // 2], gate_up_out[:, size // 2:] + # torch.nn.functional.silu(gate_out, inplace=True) + # gate_out.mul_(up_out) + # input = None + # ffn2_out = torch.mm(gate_out, layer_weight.down_proj) + # gate_out, up_out = None, None + # return ffn2_out + + def _copy_kv_to_mem_cache_normal(self, buffer, mem_index, mem_manager): + destindex_copy_kv(buffer, mem_index, mem_manager.kv_buffer[self.layer_num_]) return - - def _copy_kv_to_mem_cache_int8kv(self, key_buffer, value_buffer, mem_index, mem_manager): - destindex_copy_quantize_kv(key_buffer, - mem_index, - mem_manager.key_buffer[self.layer_num_], - mem_manager.key_scale_buffer[self.layer_num_]) - destindex_copy_quantize_kv(value_buffer, - mem_index, - mem_manager.value_buffer[self.layer_num_], - mem_manager.value_scale_buffer[self.layer_num_]) + + def _copy_kv_to_mem_cache_int8kv(self, buffer, mem_index, mem_manager): + destindex_copy_quantize_kv( + 
buffer, mem_index, mem_manager.kv_buffer[self.layer_num_], mem_manager.scale_buffer[self.layer_num_] + ) return - - def _copy_kv_to_mem_cache_ppl_int8kv(self, key_buffer, value_buffer, mem_index, mem_manager): + + def _copy_kv_to_mem_cache_ppl_int8kv(self, buffer, mem_index, mem_manager): from lightllm.models.llama.triton_kernel.ppl_quant_copy_kv import destindex_copy_quantize_kv - destindex_copy_quantize_kv(key_buffer, - mem_index, - mem_manager.key_buffer[self.layer_num_], - mem_manager.key_scale_buffer[self.layer_num_]) - destindex_copy_quantize_kv(value_buffer, - mem_index, - mem_manager.value_buffer[self.layer_num_], - mem_manager.value_scale_buffer[self.layer_num_]) + + destindex_copy_quantize_kv( + buffer, mem_index, mem_manager.kv_buffer[self.layer_num_], mem_manager.scale_buffer[self.layer_num_] + ) return - + def _token_decode_attention_normal(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): total_token_num = infer_state.total_token_num batch_size = infer_state.batch_size calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) - + att_m_tensor = torch.empty((self.tp_q_head_num_, total_token_num), dtype=q.dtype, device="cuda") - token_att_fwd(q.view(calcu_shape1), - infer_state.mem_manager.key_buffer[self.layer_num_], - att_m_tensor, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.max_len_in_batch) - + token_att_fwd( + q.view(calcu_shape1), + infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], + att_m_tensor, + infer_state.req_manager.req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.max_len_in_batch, + ) + o_tensor = torch.empty_like(q) if out is None else out - + if triton.__version__ == "2.0.0": prob = torch.empty_like(att_m_tensor) - token_softmax_fwd(att_m_tensor, infer_state.b_start_loc, infer_state.b_seq_len, prob, infer_state.max_len_in_batch) + token_softmax_fwd( + att_m_tensor, infer_state.b_start_loc, infer_state.b_seq_len, prob, infer_state.max_len_in_batch + ) att_m_tensor = None - token_att_fwd2(prob, - infer_state.mem_manager.value_buffer[self.layer_num_], - o_tensor.view(calcu_shape1), - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len) + token_att_fwd2( + prob, + infer_state.mem_manager.kv_buffer[self.layer_num_][ + :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : + ], + o_tensor.view(calcu_shape1), + infer_state.req_manager.req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + ) prob = None return o_tensor elif triton.__version__ >= "2.1.0": - from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import token_softmax_reducev_fwd - token_softmax_reducev_fwd(att_m_tensor, - infer_state.mem_manager.value_buffer[self.layer_num_], - o_tensor.view(calcu_shape1), - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.other_kv_index) + from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import ( + token_softmax_reducev_fwd, + ) + + token_softmax_reducev_fwd( + att_m_tensor, + infer_state.mem_manager.kv_buffer[self.layer_num_][ + :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : + ], + o_tensor.view(calcu_shape1), + infer_state.req_manager.req_to_token_indexs, + 
infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.other_kv_index, + ) return o_tensor else: raise Exception("not support triton version") - + def _token_decode_gqa_attention_normal(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): batch_size = infer_state.batch_size calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) # 对 gqa模型进行推理优化的代码 from ..triton_kernel.gqa_decode_flashattention_nopad import gqa_decode_attention_fwd + o_tensor = torch.empty_like(q) if out is None else out gqa_decode_attention_fwd( - q.view(calcu_shape1), - infer_state.mem_manager.key_buffer[self.layer_num_], - infer_state.mem_manager.value_buffer[self.layer_num_], - o_tensor.view(calcu_shape1), - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_seq_len) + q.view(calcu_shape1), + infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], + infer_state.mem_manager.kv_buffer[self.layer_num_][ + :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : + ], + o_tensor.view(calcu_shape1), + infer_state.req_manager.req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_seq_len, + ) return o_tensor def _token_decode_attention_int8kv(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): @@ -277,86 +343,127 @@ def _token_decode_attention_int8kv(self, q, infer_state: LlamaInferStateInfo, la batch_size = infer_state.batch_size calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) att_m_tensor = torch.empty((self.tp_q_head_num_, total_token_num), dtype=q.dtype, device="cuda") - token_att_fwd_int8k(q.view(calcu_shape1), - infer_state.mem_manager.key_buffer[self.layer_num_], - infer_state.mem_manager.key_scale_buffer[self.layer_num_], - att_m_tensor, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.max_len_in_batch) + token_att_fwd_int8k( + q.view(calcu_shape1), + infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], + infer_state.mem_manager.scale_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], + att_m_tensor, + infer_state.req_manager.req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.max_len_in_batch, + ) prob = torch.empty_like(att_m_tensor) - token_softmax_fwd(att_m_tensor, infer_state.b_start_loc, infer_state.b_seq_len, prob, infer_state.max_len_in_batch) + token_softmax_fwd( + att_m_tensor, infer_state.b_start_loc, infer_state.b_seq_len, prob, infer_state.max_len_in_batch + ) att_m_tensor = None o_tensor = torch.empty_like(q) if out is None else out - token_att_fwd2_int8v(prob, - infer_state.mem_manager.value_buffer[self.layer_num_], - infer_state.mem_manager.value_scale_buffer[self.layer_num_], - o_tensor.view(calcu_shape1), - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.max_len_in_batch) + token_att_fwd2_int8v( + prob, + infer_state.mem_manager.kv_buffer[self.layer_num_][ + :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : + ], + infer_state.mem_manager.scale_buffer[self.layer_num_][ + :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : + ], + o_tensor.view(calcu_shape1), + infer_state.req_manager.req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.max_len_in_batch, + ) prob = None 
return o_tensor - + def _token_decode_attention_flashdecoding(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): from lightllm.models.llama.triton_kernel.flash_decoding import token_decode_attention_flash_decoding - cache_k = infer_state.mem_manager.key_buffer[self.layer_num_] - cache_v = infer_state.mem_manager.value_buffer[self.layer_num_] - return token_decode_attention_flash_decoding(q, infer_state, self.tp_q_head_num_, self.head_dim_, cache_k, cache_v, out=out) - + + cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] + cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ + :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : + ] + return token_decode_attention_flash_decoding( + q, infer_state, self.tp_q_head_num_, self.head_dim_, cache_k, cache_v, out=out + ) + def _token_decode_attention_gqa_flashdecoding(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): # 对 gqa 模型进行推理优化的代码 from ..triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding - cache_k = infer_state.mem_manager.key_buffer[self.layer_num_] - cache_v = infer_state.mem_manager.value_buffer[self.layer_num_] - return gqa_token_decode_attention_flash_decoding(q, infer_state, self.tp_q_head_num_, self.head_dim_, cache_k, cache_v, out=out) - + + cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] + cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ + :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : + ] + return gqa_token_decode_attention_flash_decoding( + q, infer_state, self.tp_q_head_num_, self.head_dim_, cache_k, cache_v, out=out + ) + def _token_decode_attention_ppl_int8kv(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): batch_size = infer_state.batch_size calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) o_tensor = torch.empty_like(q) if out is None else out from lightllm_ppl_kernel import group8_int8kv_decode_attention - # group_int8kv_decode_attention(at::Tensor o, at::Tensor q, at::Tensor k, at::Tensor k_s, at::Tensor v, at::Tensor v_s, at::Tensor b_loc, at::Tensor b_seq_len, int max_len_in_batch) - group8_int8kv_decode_attention(o_tensor.view(calcu_shape1), - q.view(calcu_shape1), - infer_state.mem_manager.key_buffer[self.layer_num_], - infer_state.mem_manager.key_scale_buffer[self.layer_num_], - infer_state.mem_manager.value_buffer[self.layer_num_], - infer_state.mem_manager.value_scale_buffer[self.layer_num_], - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_seq_len, - infer_state.max_len_in_batch) - + + # group_int8kv_decode_attention(at::Tensor o, at::Tensor q, at::Tensor k, at::Tensor k_s, at::Tensor v, + # at::Tensor v_s, at::Tensor b_loc, at::Tensor b_seq_len, int max_len_in_batch) + group8_int8kv_decode_attention( + o_tensor.view(calcu_shape1), + q.view(calcu_shape1), + infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], + infer_state.mem_manager.scale_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], + infer_state.mem_manager.kv_buffer[self.layer_num_][ + :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : + ], + infer_state.mem_manager.scale_buffer[self.layer_num_][ + :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : + ], + infer_state.req_manager.req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_seq_len, + infer_state.max_len_in_batch, + ) + return o_tensor - 
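All of the decode paths above read from one packed cache per layer rather than separate key and value buffers: kv_buffer[layer_num] is laid out as (token_num, tp_k_head_num + tp_v_head_num, head_dim) with the K heads first and the V heads after them, and the int8 modes keep a scale_buffer sliced over the same head ranges. A small sketch of that layout, using made-up sizes purely for illustration:

import torch

token_num, tp_k_head_num, tp_v_head_num, head_dim = 1024, 8, 8, 128  # illustrative sizes only
kv_buffer = torch.empty(token_num, tp_k_head_num + tp_v_head_num, head_dim, dtype=torch.float16)

# K occupies the leading heads and V the trailing heads; both slices are views into the
# same storage, so the kernels get separate K/V tensors without copying anything.
cache_k = kv_buffer[:, 0:tp_k_head_num, :]
cache_v = kv_buffer[:, tp_k_head_num : tp_k_head_num + tp_v_head_num, :]
assert cache_k.data_ptr() == kv_buffer.data_ptr()
assert cache_v.data_ptr() > cache_k.data_ptr()

This is also why the _copy_kv_to_mem_cache_* methods now take a single buffer argument: one destindex copy writes both K and V for a token, and the attention kernels recover the two halves by head-range slicing as shown.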
+ def _token_decode_attention_ppl_fp16(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): batch_size = infer_state.batch_size calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) o_tensor = torch.empty_like(q) if out is None else out from lightllm_ppl_fp16_kernel import fp16_decode_attention - # group_int8kv_decode_attention(at::Tensor o, at::Tensor q, at::Tensor k, at::Tensor k_s, at::Tensor v, at::Tensor v_s, at::Tensor b_loc, at::Tensor b_seq_len, int max_len_in_batch) - fp16_decode_attention(o_tensor.view(calcu_shape1), - 1.0 / (self.head_dim_**0.5), - q.view(calcu_shape1), - infer_state.mem_manager.key_buffer[self.layer_num_], - infer_state.mem_manager.value_buffer[self.layer_num_], - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_seq_len, - infer_state.max_len_in_batch) - + + # group_int8kv_decode_attention(at::Tensor o, at::Tensor q, at::Tensor k, at::Tensor k_s, + # at::Tensor v, at::Tensor v_s, at::Tensor b_loc, at::Tensor b_seq_len, int max_len_in_batch) + fp16_decode_attention( + o_tensor.view(calcu_shape1), + 1.0 / (self.head_dim_ ** 0.5), + q.view(calcu_shape1), + infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], + infer_state.mem_manager.kv_buffer[self.layer_num_][ + :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : + ], + infer_state.req_manager.req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_seq_len, + infer_state.max_len_in_batch, + ) + return o_tensor - - def _token_decode_attention_ppl_fp16_flashdecoding(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): + + def _token_decode_attention_ppl_fp16_flashdecoding( + self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None + ): from lightllm.models.llama.triton_kernel.ppl_fp16_flash_decoding import token_decode_attention_flash_decoding - cache_k = infer_state.mem_manager.key_buffer[self.layer_num_] - cache_v = infer_state.mem_manager.value_buffer[self.layer_num_] - return token_decode_attention_flash_decoding(q, infer_state, self.tp_q_head_num_, self.head_dim_, cache_k, cache_v, out=out) \ No newline at end of file + + cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] + cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ + :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : + ] + return token_decode_attention_flash_decoding( + q, infer_state, self.tp_q_head_num_, self.head_dim_, cache_k, cache_v, out=out + ) diff --git a/lightllm/models/llama/layer_weights/transformer_layer_weight.py b/lightllm/models/llama/layer_weights/transformer_layer_weight.py index ecdec616f..9035aa4a2 100644 --- a/lightllm/models/llama/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/llama/layer_weights/transformer_layer_weight.py @@ -13,23 +13,22 @@ def load_hf_weights(self, weights): self._load_qkvo_weights(weights) self._load_ffn_weights(weights) return - + def verify_load(self): errors = "weights load not ok" - weights = [self.att_norm_weight_, - self.q_weight_, - self.k_weight_, - self.v_weight_, - self.o_weight_, - self.ffn_norm_weight_, - self.up_proj, - self.gate_proj, - self.down_proj - ] + weights = [ + self.att_norm_weight_, + self.q_weight_, + self.kv_weight_, + self.o_weight_, + self.ffn_norm_weight_, + self.gate_up_proj, + self.down_proj, + ] for i in range(len(weights)): assert weights[i] is not None, "index:" + str(i) + " " + errors - return - + return + def _load_qkvo_weights(self, weights): # 
input layernorm params if f"model.layers.{self.layer_num_}.input_layernorm.weight" in weights: @@ -37,49 +36,64 @@ def _load_qkvo_weights(self, weights): n_embed = self.network_config_["hidden_size"] q_split_n_embed = n_embed // self.world_size_ - kv_split_n_embed = n_embed // self.network_config_["num_attention_heads"] * self.network_config_["num_key_value_heads"] // self.world_size_ + kv_split_n_embed = ( + n_embed + // self.network_config_["num_attention_heads"] + * self.network_config_["num_key_value_heads"] + // self.world_size_ + ) # q k v weights for llama if f"model.layers.{self.layer_num_}.self_attn.q_proj.weight" in weights: self.q_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.weight"] - self.q_weight_ = self.q_weight_[q_split_n_embed * self.tp_rank_: q_split_n_embed * (self.tp_rank_ + 1), :] + self.q_weight_ = self.q_weight_[q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1), :] self.q_weight_ = self._cuda(self.q_weight_.transpose(0, 1)) if f"model.layers.{self.layer_num_}.self_attn.k_proj.weight" in weights: - self.k_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.weight"] - self.k_weight_ = self.k_weight_[kv_split_n_embed * self.tp_rank_: kv_split_n_embed * (self.tp_rank_ + 1), :] - self.k_weight_ = self._cuda(self.k_weight_.transpose(0, 1)) + k_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.weight"] + k_weight_ = k_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] + self.k_weight_ = k_weight_.transpose(0, 1) if f"model.layers.{self.layer_num_}.self_attn.v_proj.weight" in weights: - self.v_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.weight"] - self.v_weight_ = self.v_weight_[kv_split_n_embed * self.tp_rank_: kv_split_n_embed * (self.tp_rank_ + 1), :] - self.v_weight_ = self._cuda(self.v_weight_.transpose(0, 1)) - + v_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.weight"] + v_weight_ = v_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] + self.v_weight_ = v_weight_.transpose(0, 1) + # attention output dense params if f"model.layers.{self.layer_num_}.self_attn.o_proj.weight" in weights: self.o_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.weight"] - self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_: q_split_n_embed * (self.tp_rank_ + 1)] + self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1)] self.o_weight_ = self._cuda(self.o_weight_.transpose(0, 1)) + + self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1) + return - + def _load_ffn_weights(self, weights): if f"model.layers.{self.layer_num_}.post_attention_layernorm.weight" in weights: - self.ffn_norm_weight_ = self._cuda(weights[f"model.layers.{self.layer_num_}.post_attention_layernorm.weight"]) - - inter_size = self.network_config_['intermediate_size'] + self.ffn_norm_weight_ = self._cuda( + weights[f"model.layers.{self.layer_num_}.post_attention_layernorm.weight"] + ) + + inter_size = self.network_config_["intermediate_size"] split_inter_size = inter_size // self.world_size_ if f"model.layers.{self.layer_num_}.mlp.up_proj.weight" in weights: - self.up_proj = weights[f"model.layers.{self.layer_num_}.mlp.up_proj.weight"][split_inter_size * - self.tp_rank_: split_inter_size * (self.tp_rank_ + 1), :] - self.up_proj = self._cuda(self.up_proj.transpose(0, 1)) + up_proj = 
weights[f"model.layers.{self.layer_num_}.mlp.up_proj.weight"][ + split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : + ] + self.up_proj = up_proj.transpose(0, 1) if f"model.layers.{self.layer_num_}.mlp.gate_proj.weight" in weights: - self.gate_proj = weights[f"model.layers.{self.layer_num_}.mlp.gate_proj.weight"][split_inter_size * - self.tp_rank_: split_inter_size * (self.tp_rank_ + 1), :] - self.gate_proj = self._cuda(self.gate_proj.transpose(0, 1)) + gate_proj = weights[f"model.layers.{self.layer_num_}.mlp.gate_proj.weight"][ + split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : + ] + self.gate_proj = gate_proj.transpose(0, 1) + + self._try_cat_to(["gate_proj", "up_proj"], "gate_up_proj", cat_dim=1) if f"model.layers.{self.layer_num_}.mlp.down_proj.weight" in weights: - self.down_proj = weights[f"model.layers.{self.layer_num_}.mlp.down_proj.weight"][:, - split_inter_size * self.tp_rank_: split_inter_size * (self.tp_rank_ + 1)] + self.down_proj = weights[f"model.layers.{self.layer_num_}.mlp.down_proj.weight"][ + :, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1) + ] self.down_proj = self._cuda(self.down_proj.transpose(0, 1)) - return \ No newline at end of file + return diff --git a/lightllm/models/llama/triton_kernel/rotary_emb.py b/lightllm/models/llama/triton_kernel/rotary_emb.py old mode 100644 new mode 100755 index 2eb09fca4..3f4d3d548 --- a/lightllm/models/llama/triton_kernel/rotary_emb.py +++ b/lightllm/models/llama/triton_kernel/rotary_emb.py @@ -6,12 +6,23 @@ @triton.jit def _rotary_kernel( - Q, Cos, Sin, - stride_qbs, stride_qh, stride_qd, - stride_cosbs, stride_cosd, - stride_sinbs, stride_sind, + Q, + K, + Cos, + Sin, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_cosbs, + stride_cosd, + stride_sinbs, + stride_sind, max_total_len, - H, # N_CTX 代表要计算的上下文长度 + HEAD_Q, + HEAD_K, # N_CTX 代表要计算的上下文长度 BLOCK_HEAD: tl.constexpr, BLOCK_SEQ: tl.constexpr, BLOCK_DMODEL: tl.constexpr, @@ -25,13 +36,29 @@ def _rotary_kernel( dim_range0 = tl.arange(0, BLOCK_DMODEL // 2) dim_range1 = tl.arange(BLOCK_DMODEL // 2, BLOCK_DMODEL) - off_q0 = cur_seq_range[:, None, None] * stride_qbs + cur_head_range[None, :, None] * stride_qh + dim_range0[None, None, :] * stride_qd - off_q1 = cur_seq_range[:, None, None] * stride_qbs + cur_head_range[None, :, None] * stride_qh + dim_range1[None, None, :] * stride_qd + off_q0 = ( + cur_seq_range[:, None, None] * stride_qbs + + cur_head_range[None, :, None] * stride_qh + + dim_range0[None, None, :] * stride_qd + ) + off_q1 = ( + cur_seq_range[:, None, None] * stride_qbs + + cur_head_range[None, :, None] * stride_qh + + dim_range1[None, None, :] * stride_qd + ) off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + dim_range0[None, None, :] * stride_cosd - q0 = tl.load(Q + off_q0, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H), other=0.0) - q1 = tl.load(Q + off_q1, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H), other=0.0) + q0 = tl.load( + Q + off_q0, + mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q), + other=0.0, + ) + q1 = tl.load( + Q + off_q1, + mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q), + other=0.0, + ) cos = tl.load(Cos + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) sin = tl.load(Sin + off_dimcos_sin, 
mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) @@ -39,32 +66,89 @@ def _rotary_kernel( out0 = q0 * cos - q1 * sin out1 = q0 * sin + q1 * cos - tl.store(Q + off_q0, out0, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H)) - tl.store(Q + off_q1, out1, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H)) + tl.store( + Q + off_q0, out0, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q) + ) + tl.store( + Q + off_q1, out1, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q) + ) + + off_k0 = ( + cur_seq_range[:, None, None] * stride_kbs + + cur_head_range[None, :, None] * stride_kh + + dim_range0[None, None, :] * stride_kd + ) + off_k1 = ( + cur_seq_range[:, None, None] * stride_kbs + + cur_head_range[None, :, None] * stride_kh + + dim_range1[None, None, :] * stride_kd + ) + + off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + dim_range0[None, None, :] * stride_cosd + + k0 = tl.load( + K + off_k0, + mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), + other=0.0, + ) + k1 = tl.load( + K + off_k1, + mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), + other=0.0, + ) + cos = tl.load(Cos + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) + sin = tl.load(Sin + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) + + out_k0 = k0 * cos - k1 * sin + out_k1 = k0 * sin + k1 * cos + tl.store( + K + off_k0, + out_k0, + mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), + ) + tl.store( + K + off_k1, + out_k1, + mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), + ) return @torch.no_grad() -def rotary_emb_fwd(q, cos, sin): +def rotary_emb_fwd(q, k, cos, sin): total_len = q.shape[0] - head_num = q.shape[1] + head_num_q, head_num_k = q.shape[1], k.shape[1] head_dim = q.shape[2] assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}" + assert k.shape[0] == cos.shape[0] and k.shape[0] == sin.shape[0], f"k shape {k.shape} cos shape {cos.shape}" + + BLOCK_SEQ = 16 BLOCK_HEAD = 4 - BLOCK_SEQ = 32 - grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ)) if head_dim >= 128: num_warps = 8 else: num_warps = 4 + grid = (triton.cdiv(head_num_q, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ)) _rotary_kernel[grid]( - q, cos, sin, - q.stride(0), q.stride(1), q.stride(2), - cos.stride(0), cos.stride(1), - sin.stride(0), sin.stride(1), - total_len, head_num, + q, + k, + cos, + sin, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + cos.stride(0), + cos.stride(1), + sin.stride(0), + sin.stride(1), + total_len, + head_num_q, + head_num_k, BLOCK_HEAD=BLOCK_HEAD, BLOCK_SEQ=BLOCK_SEQ, BLOCK_DMODEL=head_dim, @@ -76,8 +160,8 @@ def rotary_emb_fwd(q, cos, sin): def torch_rotary_emb(x, cos, sin): seq_len, h, dim = x.shape - x0 = x[:, :, 0: dim // 2] - x1 = x[:, :, dim // 2: dim] + x0 = x[:, :, 0 : dim // 2] + x1 = x[:, :, dim // 2 : dim] cos = cos.view((seq_len, 1, dim // 2)) sin = sin.view((seq_len, 1, dim // 2)) o0 = x0 * cos - x1 * sin @@ -85,13 +169,13 @@ def torch_rotary_emb(x, cos, sin): return torch.cat((o0, o1), dim=-1) -def test_rotary_emb(SEQ_LEN, H, D, dtype, eps=1e-5, device='cuda'): +def 
test_rotary_emb(SEQ_LEN, H, D, dtype, eps=1e-5, device="cuda"): # create data x_shape = (SEQ_LEN, H, D) - x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda') + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") cos_shape = (SEQ_LEN, D // 2) - cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device='cuda') - sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device='cuda') + cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") + sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") # forward pass y_tri = torch_rotary_emb(x, cos, sin) rotary_emb_fwd(x, cos, sin) diff --git a/lightllm/models/llama/triton_kernel/silu_and_mul.py b/lightllm/models/llama/triton_kernel/silu_and_mul.py new file mode 100644 index 000000000..6ae26a9d5 --- /dev/null +++ b/lightllm/models/llama/triton_kernel/silu_and_mul.py @@ -0,0 +1,98 @@ +import torch + +import triton +import triton.language as tl + + +@triton.jit +def _silu_and_mul_kernel( + input_ptr, + stride_input_m, + stride_input_n, + stride_output_m, + stride_output_n, + size_m, + size_n, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + tid = tl.program_id(0) + input_m_offsets = tid * BLOCK_M + tl.arange(0, BLOCK_M) + output_m_offsets = tid * BLOCK_M + tl.arange(0, BLOCK_M) + + pid = tl.program_id(1) + input_n_offsets = pid * BLOCK_N + tl.arange(0, BLOCK_N) + output_n_offsets = pid * BLOCK_N + tl.arange(0, BLOCK_N) + + up_offsets = input_m_offsets[:, None] * stride_input_m + (input_n_offsets[None, :] + size_n) * stride_input_n + gate_offsets = input_m_offsets[:, None] * stride_input_m + input_n_offsets[None, :] * stride_input_n + res_offsets = output_m_offsets[:, None] * stride_output_m + output_n_offsets[None, :] * stride_output_n + + up = tl.load( + input_ptr + up_offsets, + mask=(input_n_offsets < size_n)[None, :] * (input_m_offsets < size_m)[:, None], + other=0.0, + ) + gate = tl.load( + input_ptr + gate_offsets, + mask=(input_n_offsets < size_n)[None, :] * (input_m_offsets < size_m)[:, None], + other=0.0, + ).to(tl.float32) + + gate = gate / (1 + tl.exp(-gate)) + gate = gate.to(tl.float16) + + tl.store( + input_ptr + res_offsets, + up * gate, + mask=(output_n_offsets < size_n)[None, :] * (output_m_offsets < size_m)[:, None], + ) + + +def silu_and_mul_fwd(input): + stride_input_m = input.stride(0) + stride_input_n = input.stride(1) + stride_output_m = input.stride(0) + stride_output_n = input.stride(1) + size_m = input.shape[0] + size_n = input.shape[-1] // 2 + BLOCK_M = 128 + BLOCK_N = 128 + grid = ( + triton.cdiv(size_m, BLOCK_M), + triton.cdiv(size_n, BLOCK_N), + ) + _silu_and_mul_kernel[grid]( + input, + stride_input_m, + stride_input_n, + stride_output_m, + stride_output_n, + size_m, + size_n, + BLOCK_M, + BLOCK_N, + ) + return input[:, 0 : (input.shape[-1] // 2)] + + +def torch_silu_and_mul(input: torch.Tensor): + return torch.nn.functional.silu(input[:, 0 : (input.shape[-1] // 2)]) * input[:, (input.shape[-1] // 2) :] + + +def test_silu_and_mul(M, N, dtype, device="cuda"): + # create data + X = torch.randn((M, N), dtype=dtype, device=device) + + # run + y_tri = silu_and_mul_fwd(X) + y_ref = torch_silu_and_mul(X) + + # compare + print("type:", y_tri.dtype, y_ref.dtype) + print("max delta:", torch.max(torch.abs(y_tri - y_ref))) + assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0) + return + + +# test_silu_and_mul(16, 4096, torch.float16, device='cuda') diff --git a/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py 
b/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py index 946b8d365..19a135b33 100644 --- a/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py +++ b/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py @@ -7,20 +7,32 @@ @triton.jit def _fwd_kernel_token_att1( - Q, K, sm_scale, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, + Q, + K, + sm_scale, + Req_to_tokens, + B_req_idx, + B_Start_Loc, + B_Seqlen, Att_Out, - stride_req_to_tokens_b, stride_req_to_tokens_s, - stride_qbs, stride_qh, stride_qd, - stride_kbs, stride_kh, stride_kd, - att_stride_h, att_stride_bs, + stride_req_to_tokens_b, + stride_req_to_tokens_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + att_stride_h, + att_stride_bs, kv_group_num, BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr + BLOCK_N: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) start_n = tl.program_id(2) - + cur_kv_head = cur_head // kv_group_num offs_d = tl.arange(0, BLOCK_DMODEL) @@ -41,8 +53,11 @@ def _fwd_kernel_token_att1( for start_mark in range(0, block_mask, 1): q = tl.load(Q + off_q + start_mark) offs_n_new = cur_batch_start_index + offs_n - k_loc = tl.load(Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * offs_n_new, - mask=offs_n_new < cur_batch_end_index, other=0) + k_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * offs_n_new, + mask=offs_n_new < cur_batch_end_index, + other=0, + ) off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) att_value = tl.sum(q[None, :] * k, 1) @@ -65,19 +80,31 @@ def token_att_fwd(q, k, att_out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK)) kv_group_num = q.shape[1] // k.shape[1] - + if kv_group_num == 1: num_warps = 4 else: num_warps = 2 _fwd_kernel_token_att1[grid]( - q, k, sm_scale, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, + q, + k, + sm_scale, + Req_to_tokens, + B_req_idx, + B_Start_Loc, + B_Seqlen, att_out, - Req_to_tokens.stride(0), Req_to_tokens.stride(1), - q.stride(0), q.stride(1), q.stride(2), - k.stride(0), k.stride(1), k.stride(2), - att_out.stride(0), att_out.stride(1), + Req_to_tokens.stride(0), + Req_to_tokens.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + att_out.stride(0), + att_out.stride(1), kv_group_num=kv_group_num, BLOCK_DMODEL=Lk, BLOCK_N=BLOCK, @@ -89,20 +116,38 @@ def token_att_fwd(q, k, att_out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen @triton.jit def _fwd_kernel_token_att1_int8( - Q, K, K_scale, sm_scale, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, + Q, + K, + K_scale, + sm_scale, + Req_to_tokens, + B_req_idx, + B_Start_Loc, + B_Seqlen, Att_Out, - stride_req_to_tokens_b, stride_req_to_tokens_s, - stride_qbs, stride_qh, stride_qd, - stride_kbs, stride_kh, stride_kd, - stride_ksbs, stride_ksh, stride_ksd, - att_stride_h, att_stride_bs, + stride_req_to_tokens_b, + stride_req_to_tokens_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_ksbs, + stride_ksh, + stride_ksd, + att_stride_h, + att_stride_bs, + kv_group_num, BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr + BLOCK_N: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) start_n = tl.program_id(2) + cur_kv_head = cur_head // 
kv_group_num + offs_d = tl.arange(0, BLOCK_DMODEL) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) @@ -121,10 +166,14 @@ def _fwd_kernel_token_att1_int8( for start_mark in range(0, block_mask, 1): q = tl.load(Q + off_q + start_mark) offs_n_new = cur_batch_start_index + offs_n - k_loc = tl.load(Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * offs_n_new, mask=offs_n_new < cur_batch_end_index, other=0) - off_k = k_loc[:, None] * stride_kbs + cur_head * stride_kh + offs_d[None, :] * stride_kd + k_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * offs_n_new, + mask=offs_n_new < cur_batch_end_index, + other=0, + ) + off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) - off_ks = k_loc[:, None] * stride_ksbs + cur_head * stride_ksh + off_ks = k_loc[:, None] * stride_ksbs + cur_kv_head * stride_ksh k_scale = tl.load(K_scale + off_ks, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) att_value = tl.sum(q[None, :] * k * k_scale, 1) att_value *= sm_scale @@ -146,20 +195,39 @@ def token_att_fwd_int8k(q, k, k_scale, att_out, Req_to_tokens, B_req_idx, B_Star grid = (batch, head_num, triton.cdiv(max_input_len, BLOCK)) - num_warps = 4 if Lk <= 64 else 8 - num_warps = 2 + kv_group_num = q.shape[1] // k.shape[1] + if kv_group_num == 1: + num_warps = 4 + else: + num_warps = 2 _fwd_kernel_token_att1_int8[grid]( - q, k, k_scale, sm_scale, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, + q, + k, + k_scale, + sm_scale, + Req_to_tokens, + B_req_idx, + B_Start_Loc, + B_Seqlen, att_out, - Req_to_tokens.stride(0), Req_to_tokens.stride(1), - q.stride(0), q.stride(1), q.stride(2), - k.stride(0), k.stride(1), k.stride(2), - k_scale.stride(0), k_scale.stride(1), k_scale.stride(2), - att_out.stride(0), att_out.stride(1), + Req_to_tokens.stride(0), + Req_to_tokens.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + k_scale.stride(0), + k_scale.stride(1), + k_scale.stride(2), + att_out.stride(0), + att_out.stride(1), + kv_group_num=kv_group_num, BLOCK_DMODEL=Lk, BLOCK_N=BLOCK, num_warps=num_warps, num_stages=1, ) - return \ No newline at end of file + return diff --git a/lightllm/models/llama/triton_kernel/token_attention_nopad_reduceV.py b/lightllm/models/llama/triton_kernel/token_attention_nopad_reduceV.py index 5682c4ea6..dcf3e60b7 100644 --- a/lightllm/models/llama/triton_kernel/token_attention_nopad_reduceV.py +++ b/lightllm/models/llama/triton_kernel/token_attention_nopad_reduceV.py @@ -3,27 +3,39 @@ import triton import triton.language as tl + @triton.jit def _fwd_kernel_token_att2( - Prob, V, Out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, - stride_req_to_tokens_b, stride_req_to_tokens_s, - stride_ph, stride_pbs, - stride_vbs, stride_vh, stride_vd, - stride_obs, stride_oh, stride_od, + Prob, + V, + Out, + Req_to_tokens, + B_req_idx, + B_Start_Loc, + B_Seqlen, + stride_req_to_tokens_b, + stride_req_to_tokens_s, + stride_ph, + stride_pbs, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, kv_group_num, BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr + BLOCK_N: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) - + cur_kv_head = cur_head // kv_group_num offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) 
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) cur_batch_start_index = 0 - cur_batch_end_index = cur_batch_seq_len cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) cur_batch_req_idx = tl.load(B_req_idx + cur_batch) @@ -35,8 +47,14 @@ def _fwd_kernel_token_att2( for start_n in range(0, cur_batch_seq_len, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) p_value = tl.load(Prob + p_offs + start_n, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0) - v_loc = tl.load(Req_to_tokens + v_loc_off + start_n * stride_req_to_tokens_s, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0) - v_value = tl.load(V + v_offs + v_loc[:, None] * stride_vbs, mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0) + v_loc = tl.load( + Req_to_tokens + v_loc_off + start_n * stride_req_to_tokens_s, + mask=(start_n + offs_n) < cur_batch_seq_len, + other=0.0, + ) + v_value = tl.load( + V + v_offs + v_loc[:, None] * stride_vbs, mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0 + ) acc += tl.sum(p_value[:, None] * v_value, 0) acc = acc.to(tl.float16) @@ -45,6 +63,7 @@ def _fwd_kernel_token_att2( tl.store(out_ptrs, acc) return + @torch.no_grad() def token_att_fwd2(prob, v, out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen): if triton.__version__ >= "2.1.0": @@ -55,15 +74,27 @@ def token_att_fwd2(prob, v, out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen grid = (batch, head) num_warps = 4 dim = v.shape[-1] - + kv_group_num = prob.shape[0] // v.shape[1] _fwd_kernel_token_att2[grid]( - prob, v, out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, - Req_to_tokens.stride(0), Req_to_tokens.stride(1), - prob.stride(0), prob.stride(1), - v.stride(0), v.stride(1), v.stride(2), - out.stride(0), out.stride(1), out.stride(2), + prob, + v, + out, + Req_to_tokens, + B_req_idx, + B_Start_Loc, + B_Seqlen, + Req_to_tokens.stride(0), + Req_to_tokens.stride(1), + prob.stride(0), + prob.stride(1), + v.stride(0), + v.stride(1), + v.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), kv_group_num=kv_group_num, BLOCK_DMODEL=dim, BLOCK_N=BLOCK, @@ -72,40 +103,68 @@ def token_att_fwd2(prob, v, out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen ) return + @triton.jit def _fwd_kernel_token_att2_int8v( - Prob, V, V_scale, Out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, # B_Start_Loc stores the cumulative start offsets, as if the inputs were stored contiguously - stride_req_to_tokens_b, stride_req_to_tokens_s, - stride_ph, stride_pbs, - stride_vbs, stride_vh, stride_vd, - stride_vsbs, stride_vsh, stride_vsd, - stride_obs, stride_oh, stride_od, + Prob, + V, + V_scale, + Out, + Req_to_tokens, + B_req_idx, + B_Start_Loc, + B_Seqlen, # B_Start_Loc stores the cumulative start offsets, as if the inputs were stored contiguously + stride_req_to_tokens_b, + stride_req_to_tokens_s, + stride_ph, + stride_pbs, + stride_vbs, + stride_vh, + stride_vd, + stride_vsbs, + stride_vsh, + stride_vsd, + stride_obs, + stride_oh, + stride_od, + kv_group_num, BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr + BLOCK_N: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) + cur_kv_head = cur_head // kv_group_num + offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) cur_batch_start_index = 0 - cur_batch_end_index = cur_batch_seq_len cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) cur_batch_req_idx = tl.load(B_req_idx + cur_batch) v_loc_off = cur_batch_req_idx * stride_req_to_tokens_b + (cur_batch_start_index + offs_n) * stride_req_to_tokens_s p_offs = cur_head * stride_ph + (cur_batch_in_all_start_index + offs_n) *
stride_pbs - v_offs = cur_head * stride_vh + offs_d[None, :] * stride_vd - vs_offs = cur_head * stride_vsh + v_offs = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd + vs_offs = cur_kv_head * stride_vsh acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) for start_n in range(0, cur_batch_seq_len, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) p_value = tl.load(Prob + p_offs + start_n, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0) - v_loc = tl.load(Req_to_tokens + v_loc_off + start_n * stride_req_to_tokens_s, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0) - v_value = tl.load(V + v_offs + v_loc[:, None] * stride_vbs, mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0) - vs_value = tl.load(V_scale + vs_offs + v_loc[:, None] * stride_vsbs, mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0) + v_loc = tl.load( + Req_to_tokens + v_loc_off + start_n * stride_req_to_tokens_s, + mask=(start_n + offs_n) < cur_batch_seq_len, + other=0.0, + ) + v_value = tl.load( + V + v_offs + v_loc[:, None] * stride_vbs, mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0 + ) + vs_value = tl.load( + V_scale + vs_offs + v_loc[:, None] * stride_vsbs, + mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, + other=0.0, + ) acc += tl.sum(p_value[:, None] * v_value * vs_value, 0) acc = acc.to(tl.float16) @@ -121,18 +180,35 @@ def token_att_fwd2_int8v(prob, v, v_scale, out, Req_to_tokens, B_req_idx, B_Star BLOCK = triton.next_power_of_2(max_len_in_batch) else: BLOCK = 512 - batch, head = B_req_idx.shape[0], v.shape[1] + batch, head = B_req_idx.shape[0], prob.shape[0] grid = (batch, head) num_warps = 4 dim = v.shape[-1] + kv_group_num = prob.shape[0] // v.shape[1] _fwd_kernel_token_att2_int8v[grid]( - prob, v, v_scale, out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, - Req_to_tokens.stride(0), Req_to_tokens.stride(1), - prob.stride(0), prob.stride(1), - v.stride(0), v.stride(1), v.stride(2), - v_scale.stride(0), v_scale.stride(1), v_scale.stride(2), - out.stride(0), out.stride(1), out.stride(2), + prob, + v, + v_scale, + out, + Req_to_tokens, + B_req_idx, + B_Start_Loc, + B_Seqlen, + Req_to_tokens.stride(0), + Req_to_tokens.stride(1), + prob.stride(0), + prob.stride(1), + v.stride(0), + v.stride(1), + v.stride(2), + v_scale.stride(0), + v_scale.stride(1), + v_scale.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), + kv_group_num=kv_group_num, BLOCK_DMODEL=dim, BLOCK_N=BLOCK, num_warps=num_warps, diff --git a/lightllm/models/llama_awquant/layer_infer/transformer_layer_infer.py b/lightllm/models/llama_awquant/layer_infer/transformer_layer_infer.py old mode 100644 new mode 100755 index 0110c76ac..c27c3a532 --- a/lightllm/models/llama_awquant/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama_awquant/layer_infer/transformer_layer_infer.py @@ -7,18 +7,24 @@ import triton from functools import partial -from lightllm.models.llama_awquant.layer_weights.transformer_layer_weight import LlamaTransformerLayerActivationWeightQuantized +from lightllm.models.llama_awquant.layer_weights.transformer_layer_weight import ( + LlamaTransformerLayerActivationWeightQuantized, +) from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.common.basemodel import TransformerLayerInferActivationWeightQuantTpl -from 
lightllm.common.basemodel.cuda_kernel.ppl_awquant import matmul_i8_i32_ppl, skiprmsnorm_ppl, channel_token_dequant_i32_fp16_ppl +from lightllm.common.basemodel.cuda_kernel.ppl_awquant import ( + matmul_i8_i32_ppl, + skiprmsnorm_ppl, + channel_token_dequant_i32_fp16_ppl, +) from lightllm.common.basemodel.cuda_kernel.ppl_awquant import dynamic_channelwise_quant_fp16_i8_ppl, gatesilu_i32_i8_ppl from lightllm.utils.infer_utils import mark_cost_time - + + class LlamaTransformerLayerInferAWquant(TransformerLayerInferActivationWeightQuantTpl): - """ - """ + """ """ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): super().__init__(layer_num, tp_rank, world_size, network_config, mode) @@ -29,18 +35,18 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): self.tp_o_head_num_ = self.tp_q_head_num_ self.head_dim_ = network_config["hidden_size"] // network_config["num_attention_heads"] self.embed_dim_ = network_config["hidden_size"] - - self.inter_dim_ = network_config['intermediate_size'] + + self.inter_dim_ = network_config["intermediate_size"] self._bind_func() return - + def _bind_func(self): self._bind_norm() - self._bind_matmul() - self._bind_silu() + self._bind_matmul() + self._bind_silu() LlamaTransformerLayerInfer._bind_attention(self) return - + def _bind_norm(self): if "ppl_int8_activation_weight" in self.mode: self._awquant_att_norm = partial(LlamaTransformerLayerInferAWquant._awquant_att_norm_ppl_int8, self) @@ -48,13 +54,21 @@ def _bind_norm(self): else: raise Exception(f"error mode {self.mode}") return - + def _bind_matmul(self): if "ppl_int8_activation_weight" in self.mode: - self._awquant_matmul_for_qkv = partial(LlamaTransformerLayerInferAWquant._awquant_matmul_ppl_int8_quant_dequant, self) - self._awquant_matmul_for_o = partial(LlamaTransformerLayerInferAWquant._awquant_matmul_ppl_int8_quant_dequant, self) - self._awquant_matmul_for_ffn_up = partial(LlamaTransformerLayerInferAWquant._awquant_matmul_ppl_int8_quant, self) - self._awquant_matmul_for_ffn_down = partial(LlamaTransformerLayerInferAWquant._awquant_matmul_ppl_int8_quant_dequant, self) + self._awquant_matmul_for_qkv = partial( + LlamaTransformerLayerInferAWquant._awquant_matmul_ppl_int8_quant_dequant, self + ) + self._awquant_matmul_for_o = partial( + LlamaTransformerLayerInferAWquant._awquant_matmul_ppl_int8_quant_dequant, self + ) + self._awquant_matmul_for_ffn_up = partial( + LlamaTransformerLayerInferAWquant._awquant_matmul_ppl_int8_quant, self + ) + self._awquant_matmul_for_ffn_down = partial( + LlamaTransformerLayerInferAWquant._awquant_matmul_ppl_int8_quant_dequant, self + ) if self.tp_rank_ == 0 and self.layer_num_ == 0: print("model use ppl_int8_activation_weight kernel") else: @@ -69,58 +83,88 @@ def _bind_silu(self): raise Exception(f"error mode {self.mode}") return - def _get_qkv(self, input, cache_k, cache_v, token_scale, infer_state:LlamaInferStateInfo, layer_weight:LlamaTransformerLayerActivationWeightQuantized)->torch.Tensor: - q = self._awquant_matmul_for_qkv(input.view(-1, self.embed_dim_), - quant_weight_params=layer_weight.q_weight_, - is_prefill=infer_state.is_prefill, - token_scale=token_scale) - rotary_emb_fwd(q.view(-1, self.tp_q_head_num_, self.head_dim_), infer_state.position_cos, infer_state.position_sin) - - out = self._awquant_matmul_for_qkv(input.view(-1, self.embed_dim_), - quant_weight_params=layer_weight.k_weight_, - is_prefill=infer_state.is_prefill, - token_scale=token_scale) - cache_k_ = out.view(-1, self.tp_k_head_num_, self.head_dim_) 
- rotary_emb_fwd(cache_k_, infer_state.position_cos, infer_state.position_sin) - out = self._awquant_matmul_for_qkv(input.view(-1, self.embed_dim_), - quant_weight_params=layer_weight.v_weight_, - is_prefill=infer_state.is_prefill, - token_scale=token_scale) - cache_v_ = out.view(-1, self.tp_v_head_num_, self.head_dim_) - return q, cache_k_, cache_v_ - - def _get_o(self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerActivationWeightQuantized) -> torch.Tensor: + def _get_qkv( + self, + input, + cache_kv, + token_scale, + infer_state: LlamaInferStateInfo, + layer_weight: LlamaTransformerLayerActivationWeightQuantized, + ) -> torch.Tensor: + q = self._awquant_matmul_for_qkv( + input.view(-1, self.embed_dim_), + quant_weight_params=layer_weight.q_weight_, + is_prefill=infer_state.is_prefill, + token_scale=token_scale, + ) + cache_kv = self._awquant_matmul_for_qkv( + input.view(-1, self.embed_dim_), + quant_weight_params=layer_weight.kv_weight_, + is_prefill=infer_state.is_prefill, + token_scale=token_scale, + ).view(-1, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_) + rotary_emb_fwd( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + cache_kv[:, 0 : self.tp_k_head_num_, :], + infer_state.position_cos, + infer_state.position_sin, + ) + out = self._awquant_matmul_for_qkv( + input.view(-1, self.embed_dim_), + quant_weight_params=layer_weight.v_weight_, + is_prefill=infer_state.is_prefill, + token_scale=token_scale, + ) + cache_kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :] = out.view( + -1, self.tp_v_head_num_, self.head_dim_ + ) + return q, cache_kv + + def _get_o( + self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerActivationWeightQuantized + ) -> torch.Tensor: o_tensor = torch.mm(input.view(-1, self.tp_o_head_num_ * self.head_dim_), layer_weight.o_weight_) return o_tensor - def _ffn(self, input, token_scale, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerActivationWeightQuantized) -> torch.Tensor: - gate_out = self._awquant_matmul_for_ffn_up(input.view(-1, self.embed_dim_), - layer_weight.gate_proj, - is_prefill=infer_state.is_prefill,) - up_out = self._awquant_matmul_for_ffn_up(input.view(-1, self.embed_dim_), - layer_weight.up_proj, - is_prefill=infer_state.is_prefill,) + def _ffn( + self, + input, + token_scale, + infer_state: LlamaInferStateInfo, + layer_weight: LlamaTransformerLayerActivationWeightQuantized, + ) -> torch.Tensor: + gate_out = self._awquant_matmul_for_ffn_up( + input.view(-1, self.embed_dim_), + layer_weight.gate_proj, + is_prefill=infer_state.is_prefill, + ) + up_out = self._awquant_matmul_for_ffn_up( + input.view(-1, self.embed_dim_), + layer_weight.up_proj, + is_prefill=infer_state.is_prefill, + ) input = None _, gate_proj_scale = layer_weight.gate_proj _, up_proj_scale = layer_weight.up_proj - ffn1_out, ffn1_out_scale = self._awquant_silu(gate_out, up_out, - gate_proj_scale, up_proj_scale, token_scale) + ffn1_out, ffn1_out_scale = self._awquant_silu(gate_out, up_out, gate_proj_scale, up_proj_scale, token_scale) gate_out, up_out = None, None - ffn2_out = self._awquant_matmul_for_ffn_down(ffn1_out, layer_weight.down_proj, - is_prefill=infer_state.is_prefill, - token_scale=ffn1_out_scale) + ffn2_out = self._awquant_matmul_for_ffn_down( + ffn1_out, layer_weight.down_proj, is_prefill=infer_state.is_prefill, token_scale=ffn1_out_scale + ) ffn1_out = None return ffn2_out - @mark_cost_time("trans context flash forward time cost") # dont to remove this, 
will make performence down, did not know why + @mark_cost_time( + "trans context flash forward time cost" + ) # don't remove this; removing it degrades performance, reason unknown def _context_attention(self, input_embding, infer_state: LlamaInferStateInfo, layer_weight): input1, token_scale, skip_out = self._awquant_att_norm(input_embding, infer_state, layer_weight) - cache_k, cache_v = self._pre_cache_kv(infer_state, layer_weight) - q, cache_k, cache_v = self._get_qkv(input1, cache_k, cache_v, token_scale, infer_state, layer_weight) + cache_kv = self._pre_cache_kv(infer_state, layer_weight) + q, cache_kv = self._get_qkv(input1, cache_kv, token_scale, infer_state, layer_weight) input1 = None - self._post_cache_kv(cache_k, cache_v, infer_state, layer_weight) - o = self._context_attention_kernel(q, cache_k, cache_v, infer_state, layer_weight) + self._post_cache_kv(cache_kv, infer_state, layer_weight) + o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight) q = None o = self._get_o(o, infer_state, layer_weight) if self.world_size_ > 1: @@ -128,7 +172,9 @@ def _context_attention(self, input_embding, infer_state: LlamaInferStateInfo, la input_embding.add_(o.view(-1, self.embed_dim_)) return - @mark_cost_time("trans context ffn forward time cost") # dont to remove this, will make performence down, did not know why + @mark_cost_time( + "trans context ffn forward time cost" + ) # don't remove this; removing it degrades performance, reason unknown def _context_ffn(self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight): input1, token_scale, skip_out = self._awquant_ffn_norm(input_embdings, infer_state, layer_weight) ffn_out = self._ffn(input1, token_scale, infer_state, layer_weight) @@ -141,10 +187,10 @@ def _context_ffn(self, input_embdings, infer_state: LlamaInferStateInfo, layer_w # this impl dont to use @mark_cost_time def _token_attention(self, input_embding, infer_state: LlamaInferStateInfo, layer_weight): input1, token_scale, skip_out = self._awquant_att_norm(input_embding, infer_state, layer_weight) - cache_k, cache_v = self._pre_cache_kv(infer_state, layer_weight) - q, cache_k, cache_v = self._get_qkv(input1, cache_k, cache_v, token_scale, infer_state, layer_weight) + cache_kv = self._pre_cache_kv(infer_state, layer_weight) + q, cache_kv = self._get_qkv(input1, cache_kv, token_scale, infer_state, layer_weight) input1 = None - self._post_cache_kv(cache_k, cache_v, infer_state, layer_weight) + self._post_cache_kv(cache_kv, infer_state, layer_weight) o = self._token_attention_kernel(q, infer_state, layer_weight) q = None o = self._get_o(o, infer_state, layer_weight) @@ -163,10 +209,12 @@ def _token_ffn(self, input_embdings, infer_state: LlamaInferStateInfo, layer_wei input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) return - def _awquant_matmul_ppl_int8_quant_dequant(self, input, quant_weight_params, is_prefill, token_scale=None, out=None, bias=None, has_act=False): + def _awquant_matmul_ppl_int8_quant_dequant( + self, input, quant_weight_params, is_prefill, token_scale=None, out=None, bias=None, has_act=False + ): if input.dtype == torch.float16: input, token_scale = dynamic_channelwise_quant_fp16_i8_ppl(input.transpose(0, 1)) - assert has_act == False + assert has_act is False if is_prefill: qweight, qscale = quant_weight_params out = matmul_i8_i32_ppl(input, qweight) @@ -180,8 +228,10 @@ def _awquant_matmul_ppl_int8_quant_dequant(self, input, quant_weight_params, is_ out.add_(bias) return out - def _awquant_matmul_ppl_int8_quant(self, input,
quant_weight_params, is_prefill, out=None, bias=None, has_act=False): - assert has_act == False + def _awquant_matmul_ppl_int8_quant( + self, input, quant_weight_params, is_prefill, out=None, bias=None, has_act=False + ): + assert has_act is False if is_prefill: qweight, qscale = quant_weight_params out = matmul_i8_i32_ppl(input, qweight) @@ -194,13 +244,13 @@ def _awquant_matmul_ppl_int8_quant(self, input, quant_weight_params, is_prefill, out.add_(bias) return out - def _awquant_att_norm_ppl_int8(self, input, infer_state:LlamaInferStateInfo, layer_weight): - if getattr(infer_state, 'skip', None) is None: + def _awquant_att_norm_ppl_int8(self, input, infer_state: LlamaInferStateInfo, layer_weight): + if getattr(infer_state, "skip", None) is None: infer_state.skip = torch.zeros_like(input) return skiprmsnorm_ppl(input, layer_weight.att_norm_weight_, skip=infer_state.skip) - def _awquant_ffn_norm_ppl_int8(self, input, infer_state:LlamaInferStateInfo, layer_weight): + def _awquant_ffn_norm_ppl_int8(self, input, infer_state: LlamaInferStateInfo, layer_weight): return skiprmsnorm_ppl(input, layer_weight.ffn_norm_weight_, skip=infer_state.skip) def _awquant_silu_ppl_int8(self, x, y, x_scale, y_scale, token_scale): - return gatesilu_i32_i8_ppl(x, y, x_scale, y_scale, token_scale) \ No newline at end of file + return gatesilu_i32_i8_ppl(x, y, x_scale, y_scale, token_scale) diff --git a/lightllm/models/llama_awquant/layer_weights/transformer_layer_weight.py b/lightllm/models/llama_awquant/layer_weights/transformer_layer_weight.py old mode 100644 new mode 100755 index c78ecea9f..33fe93f3c --- a/lightllm/models/llama_awquant/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/llama_awquant/layer_weights/transformer_layer_weight.py @@ -12,7 +12,7 @@ class LlamaTransformerLayerActivationWeightQuantized(TransformerLayerWeight): def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[]): super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode) self.init_quant_mode() - + def init_quant_mode(self): if "ppl_int8_activation_weight" in self.mode: self.quantize_weight = partial(dynamic_channelwise_quant_fp16_i8_ppl, tp_rank=self.tp_rank_) @@ -26,14 +26,13 @@ def verify_load(self): weights = [ self.att_norm_weight_, self.q_weight_, - self.k_weight_, - self.v_weight_, + self.kv_weight_, self.o_weight_, self.ffn_norm_weight_, - self.up_proj, self.gate_proj, - self.down_proj - ] + self.up_proj, + self.down_proj, + ] for i in range(len(weights)): assert weights[i] is not None, "index:" + str(i) + " " + errors return @@ -45,55 +44,63 @@ def _load_qkvo_weights(self, weights): n_embed = self.network_config_["hidden_size"] q_split_n_embed = n_embed // self.world_size_ - kv_split_n_embed = n_embed // self.network_config_["num_attention_heads"] * self.network_config_["num_key_value_heads"] // self.world_size_ + kv_split_n_embed = ( + n_embed + // self.network_config_["num_attention_heads"] + * self.network_config_["num_key_value_heads"] + // self.world_size_ + ) # q k v weights for llama if f"model.layers.{self.layer_num_}.self_attn.q_proj.weight" in weights: q_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.weight"] - q_weight_ = q_weight_[q_split_n_embed * self.tp_rank_: q_split_n_embed * (self.tp_rank_ + 1), :] + q_weight_ = q_weight_[q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1), :] q_weight_ = q_weight_.transpose(0, 1).to(self.data_type_) self.q_weight_ = self.quantize_weight(q_weight_) if 
f"model.layers.{self.layer_num_}.self_attn.k_proj.weight" in weights: k_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.weight"] - k_weight_ = k_weight_[kv_split_n_embed * self.tp_rank_: kv_split_n_embed * (self.tp_rank_ + 1), :] - k_weight_ = k_weight_.transpose(0, 1).to(self.data_type_) - self.k_weight_ = self.quantize_weight(k_weight_) + k_weight_ = k_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] + self.k_weight_ = k_weight_.transpose(0, 1).to(self.data_type_) if f"model.layers.{self.layer_num_}.self_attn.v_proj.weight" in weights: v_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.weight"] - v_weight_ = v_weight_[kv_split_n_embed * self.tp_rank_: kv_split_n_embed * (self.tp_rank_ + 1), :] - v_weight_ = v_weight_.transpose(0, 1).to(self.data_type_) - self.v_weight_ = self.quantize_weight(v_weight_) + v_weight_ = v_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] + self.v_weight_ = v_weight_.transpose(0, 1).to(self.data_type_) + + self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1, handle_func=self.quantize_weight) # attention output dense params if f"model.layers.{self.layer_num_}.self_attn.o_proj.weight" in weights: o_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.weight"] - o_weight_ = o_weight_[:, q_split_n_embed * self.tp_rank_: q_split_n_embed * (self.tp_rank_ + 1)] + o_weight_ = o_weight_[:, q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1)] self.o_weight_ = self._cuda(o_weight_.transpose(0, 1)) return - + def _load_ffn_weights(self, weights): if f"model.layers.{self.layer_num_}.post_attention_layernorm.weight" in weights: - self.ffn_norm_weight_ = self._cuda(weights[f"model.layers.{self.layer_num_}.post_attention_layernorm.weight"]) - - inter_size = self.network_config_['intermediate_size'] + self.ffn_norm_weight_ = self._cuda( + weights[f"model.layers.{self.layer_num_}.post_attention_layernorm.weight"] + ) + + inter_size = self.network_config_["intermediate_size"] split_inter_size = inter_size // self.world_size_ if f"model.layers.{self.layer_num_}.mlp.up_proj.weight" in weights: - up_proj = weights[f"model.layers.{self.layer_num_}.mlp.up_proj.weight"][split_inter_size * - self.tp_rank_: split_inter_size * (self.tp_rank_ + 1), :] - up_proj = up_proj.transpose(0, 1).to(self.data_type_) - self.up_proj = self.quantize_weight(up_proj) + up_proj = weights[f"model.layers.{self.layer_num_}.mlp.up_proj.weight"][ + split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : + ] + self.up_proj = self.quantize_weight(up_proj.transpose(0, 1).to(self.data_type_)) if f"model.layers.{self.layer_num_}.mlp.gate_proj.weight" in weights: - gate_proj = weights[f"model.layers.{self.layer_num_}.mlp.gate_proj.weight"][split_inter_size * - self.tp_rank_: split_inter_size * (self.tp_rank_ + 1), :] - gate_proj = gate_proj.transpose(0, 1).to(self.data_type_) - self.gate_proj = self.quantize_weight(gate_proj) + gate_proj = weights[f"model.layers.{self.layer_num_}.mlp.gate_proj.weight"][ + split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : + ] + self.gate_proj = self.quantize_weight(gate_proj.transpose(0, 1).to(self.data_type_)) if f"model.layers.{self.layer_num_}.mlp.down_proj.weight" in weights: - down_proj = weights[f"model.layers.{self.layer_num_}.mlp.down_proj.weight"][:, - split_inter_size * self.tp_rank_: split_inter_size * (self.tp_rank_ + 1)] + down_proj = 
weights[f"model.layers.{self.layer_num_}.mlp.down_proj.weight"][ + :, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1) + ] self.down_proj = self.quantize_weight(down_proj.transpose(0, 1)) return diff --git a/lightllm/models/llama_wquant/layer_infer/transformer_layer_infer.py b/lightllm/models/llama_wquant/layer_infer/transformer_layer_infer.py old mode 100644 new mode 100755 index a78d41a52..9db935f89 --- a/lightllm/models/llama_wquant/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama_wquant/layer_infer/transformer_layer_infer.py @@ -14,17 +14,21 @@ from lightllm.common.basemodel import TransformerLayerInferWeightQuantTpl from lightllm.common.basemodel.triton_kernel.quantize_gemm_int8 import matmul_quantize_int8 from lightllm.common.basemodel.triton_kernel.dequantize_gemm_int8 import matmul_dequantize_int8 -from lightllm.common.basemodel.triton_kernel.dequantize_gemm_int4 import matmul_dequantize_int4_s1, matmul_dequantize_int4_s2, matmul_dequantize_int4_gptq +from lightllm.common.basemodel.triton_kernel.dequantize_gemm_int4 import ( + matmul_dequantize_int4_s1, + matmul_dequantize_int4_s2, + matmul_dequantize_int4_gptq, +) from lightllm.common.basemodel.cuda_kernel.lmdeploy_wquant import matmul_dequantize_int4_lmdeploy from lightllm.common.basemodel.cuda_kernel.ppl_wquant import matmul_dequantize_int4_ppl from lightllm.utils.infer_utils import mark_cost_time from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) - + + class LlamaTransformerLayerInferWquant(TransformerLayerInferWeightQuantTpl): - """ - """ + """ """ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): super().__init__(layer_num, tp_rank, world_size, network_config, mode) @@ -35,17 +39,17 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): self.tp_o_head_num_ = self.tp_q_head_num_ self.head_dim_ = network_config["hidden_size"] // network_config["num_attention_heads"] self.embed_dim_ = network_config["hidden_size"] - - self.inter_dim_ = network_config['intermediate_size'] + + self.inter_dim_ = network_config["intermediate_size"] self._bind_func() return - + def _bind_func(self): self._bind_matmul() LlamaTransformerLayerInfer._bind_norm(self) LlamaTransformerLayerInfer._bind_attention(self) return - + def _bind_matmul(self): if "triton_int8weight" in self.mode: func = partial(LlamaTransformerLayerInferWquant._wquant_matmul_triton_int8weight_only_quant, self) @@ -83,47 +87,52 @@ def _bind_matmul(self): raise Exception(f"error mode {self.mode}") return - def _get_qkv(self, input, cache_k, cache_v, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeightQuantized): - qkv_output = self._wquant_matmul_for_qkv(input.view(-1, self.embed_dim_), - quant_weight_params=layer_weight.qkv_weight_, - infer_state=infer_state) - - tp_k_head_dim = self.tp_k_head_num_ * self.head_dim_ - q = qkv_output[:, : -2 * tp_k_head_dim] - k = qkv_output[:, -2 * tp_k_head_dim: -tp_k_head_dim] - v = qkv_output[:, -tp_k_head_dim :] - - rotary_emb_fwd(q.view(-1, self.tp_q_head_num_, self.head_dim_), infer_state.position_cos, infer_state.position_sin) - cache_k_ = k.view(-1, self.tp_k_head_num_, self.head_dim_) - rotary_emb_fwd(cache_k_, infer_state.position_cos, infer_state.position_sin) - cache_v_ = v.view(-1, self.tp_v_head_num_, self.head_dim_) - return q, cache_k_, cache_v_ - - def _get_o(self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeightQuantized) -> torch.Tensor: - o_tensor = 
self._wquant_matmul_for_o(input, - quant_weight_params=layer_weight.o_weight_, - infer_state=infer_state) + def _get_qkv( + self, input, cache_kv, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeightQuantized + ): + q = self._wquant_matmul_for_qkv( + input.view(-1, self.embed_dim_), quant_weight_params=layer_weight.q_weight_, infer_state=infer_state + ) + cache_kv = self._wquant_matmul_for_qkv( + input.view(-1, self.embed_dim_), quant_weight_params=layer_weight.kv_weight_, infer_state=infer_state + ).view(-1, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_) + rotary_emb_fwd( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + cache_kv[:, 0 : self.tp_k_head_num_, :], + infer_state.position_cos, + infer_state.position_sin, + ) + return q, cache_kv + + def _get_o( + self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeightQuantized + ) -> torch.Tensor: + o_tensor = self._wquant_matmul_for_o(input, quant_weight_params=layer_weight.o_weight_, infer_state=infer_state) return o_tensor - def _ffn(self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeightQuantized) -> torch.Tensor: - gate_up_output = self._wquant_matmul_for_ffn_up(input.view(-1, self.embed_dim_), - quant_weight_params=layer_weight.gate_up_proj, - infer_state=infer_state) + def _ffn( + self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeightQuantized + ) -> torch.Tensor: + gate_up_output = self._wquant_matmul_for_ffn_up( + input.view(-1, self.embed_dim_), quant_weight_params=layer_weight.gate_up_proj, infer_state=infer_state + ) input = None tp_inter_dim = self.inter_dim_ // self.world_size_ gate_up_output = gate_up_output.view(-1, 2, tp_inter_dim) torch.nn.functional.silu(gate_up_output[:, 0], inplace=True) ffn1_out = gate_up_output[:, 0] * gate_up_output[:, 1] gate_up_output = None - ffn2_out = self._wquant_matmul_for_ffn_down(ffn1_out, - quant_weight_params=layer_weight.down_proj, - infer_state=infer_state) + ffn2_out = self._wquant_matmul_for_ffn_down( + ffn1_out, quant_weight_params=layer_weight.down_proj, infer_state=infer_state + ) ffn1_out = None return ffn2_out - - def _wquant_matmul_triton_int8weight_only_quant(self, input, quant_weight_params, infer_state: LlamaInferStateInfo, out=None, bias=None, has_act=False): - assert has_act == False - if not infer_state.is_splitfuse and infer_state.is_prefill: + + def _wquant_matmul_triton_int8weight_only_quant( + self, input, quant_weight_params, infer_state: LlamaInferStateInfo, out=None, bias=None, has_act=False + ): + assert has_act is False + if infer_state.is_splitfuse is False and infer_state.is_prefill: qweight, scale = quant_weight_params out = matmul_dequantize_int8(input, qweight, scale, out=out) else: @@ -134,10 +143,12 @@ def _wquant_matmul_triton_int8weight_only_quant(self, input, quant_weight_params else: out.add_(bias) return out - - def _wquant_matmul_triton_int4weight_only_quant(self, input, quant_weight_params, infer_state: LlamaInferStateInfo, out=None, bias=None, has_act=False): - assert has_act == False - if not infer_state.is_splitfuse and infer_state.is_prefill: + + def _wquant_matmul_triton_int4weight_only_quant( + self, input, quant_weight_params, infer_state: LlamaInferStateInfo, out=None, bias=None, has_act=False + ): + assert has_act is False + if infer_state.is_splitfuse is False and infer_state.is_prefill: qweight, scale, zeros, int4_q_group_size = quant_weight_params out = matmul_dequantize_int4_s1(input, qweight, scale, 
zeros, int4_q_group_size, out=out) else: @@ -148,23 +159,27 @@ def _wquant_matmul_triton_int4weight_only_quant(self, input, quant_weight_params else: out.add_(bias) return out - - def _wquant_matmul_lmdeploy_int4weight_only_quant(self, input, quant_weight_params, infer_state: LlamaInferStateInfo, out=None, bias=None, has_act=False): - assert has_act == False + + def _wquant_matmul_lmdeploy_int4weight_only_quant( + self, input, quant_weight_params, infer_state: LlamaInferStateInfo, out=None, bias=None, has_act=False + ): + assert has_act is False qweight, scale_zeros, int4_q_group_size = quant_weight_params - out = matmul_dequantize_int4_lmdeploy(input, qweight, scale_zeros, int4_q_group_size) + out = matmul_dequantize_int4_lmdeploy(input, qweight, scale_zeros, int4_q_group_size) if bias is None: return out else: out.add_(bias) return out - def _wquant_matmul_ppl_int4weight_only_quant(self, input, quant_weight_params, infer_state: LlamaInferStateInfo, out=None, bias=None, has_act=False): - assert has_act == False + def _wquant_matmul_ppl_int4weight_only_quant( + self, input, quant_weight_params, infer_state: LlamaInferStateInfo, out=None, bias=None, has_act=False + ): + assert has_act is False qweight, qscale = quant_weight_params out = matmul_dequantize_int4_ppl(input, qweight, qscale) if bias is None: return out else: out.add_(bias) - return out \ No newline at end of file + return out diff --git a/lightllm/models/llama_wquant/layer_weights/transformer_layer_weight.py b/lightllm/models/llama_wquant/layer_weights/transformer_layer_weight.py index b4ff95666..e24ff8834 100644 --- a/lightllm/models/llama_wquant/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/llama_wquant/layer_weights/transformer_layer_weight.py @@ -15,26 +15,28 @@ class LlamaTransformerLayerWeightQuantized(TransformerLayerWeight): def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[]): super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode) self.init_quant_mode() - + def init_quant_mode(self): if "triton_int8weight" in self.mode: self.quantize_weight = partial(quantize_int8, tp_rank=self.tp_rank_) if "triton_int4weight" in self.mode: self.int4_q_group_size = 128 for _mode in self.mode: - if _mode.startswith('g'): + if _mode.startswith("g"): self.int4_q_group_size = int(_mode[1:]) self.quantize_weight = partial(quantize_int4, group_size=self.int4_q_group_size, tp_rank=self.tp_rank_) if "lmdeploy_int4weight" in self.mode: self.int4_q_group_size = 128 for _mode in self.mode: - if _mode.startswith('g'): + if _mode.startswith("g"): self.int4_q_group_size = int(_mode[1:]) - self.quantize_weight = partial(quantize_int4_lmdeploy, group_size=self.int4_q_group_size, tp_rank=self.tp_rank_) + self.quantize_weight = partial( + quantize_int4_lmdeploy, group_size=self.int4_q_group_size, tp_rank=self.tp_rank_ + ) if "ppl_int4weight" in self.mode: self.int4_q_group_size = 128 for _mode in self.mode: - if _mode.startswith('g'): + if _mode.startswith("g"): self.int4_q_group_size = int(_mode[1:]) self.quantize_weight = partial(quantize_int4_ppl, group_size=self.int4_q_group_size, tp_rank=self.tp_rank_) @@ -46,12 +48,13 @@ def verify_load(self): errors = "weights load not ok" weights = [ self.att_norm_weight_, - self.qkv_weight_, + self.q_weight_, + self.kv_weight_, self.o_weight_, self.ffn_norm_weight_, self.gate_up_proj, - self.down_proj - ] + self.down_proj, + ] for i in range(len(weights)): assert weights[i] is not None, "index:" + str(i) + " " + errors @@ -62,79 
+65,67 @@ def _load_qkvo_weights(self, weights): n_embed = self.network_config_["hidden_size"] q_split_n_embed = n_embed // self.world_size_ - kv_split_n_embed = n_embed // self.network_config_["num_attention_heads"] * self.network_config_["num_key_value_heads"] // self.world_size_ - - if getattr(self, "qkv_weight_", None) is None: - self.qkv_weight_ = torch.empty(n_embed, q_split_n_embed + 2 * kv_split_n_embed, dtype=self.data_type_, device='cpu') - self.qkv_step_ = 0 - + kv_split_n_embed = ( + n_embed + // self.network_config_["num_attention_heads"] + * self.network_config_["num_key_value_heads"] + // self.world_size_ + ) + # q k v weights for llama if f"model.layers.{self.layer_num_}.self_attn.q_proj.weight" in weights: q_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.weight"] - q_weight_ = q_weight_[q_split_n_embed * self.tp_rank_: q_split_n_embed * (self.tp_rank_ + 1), :] + q_weight_ = q_weight_[q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1), :] q_weight_ = q_weight_.transpose(0, 1).to(self.data_type_) - self.qkv_weight_[:, :q_split_n_embed] = q_weight_ - self.qkv_step_ += 1 - + self.q_weight_ = self.quantize_weight(q_weight_) if f"model.layers.{self.layer_num_}.self_attn.k_proj.weight" in weights: k_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.weight"] - k_weight_ = k_weight_[kv_split_n_embed * self.tp_rank_: kv_split_n_embed * (self.tp_rank_ + 1), :] - k_weight_ = k_weight_.transpose(0, 1).to(self.data_type_) - self.qkv_weight_[:, q_split_n_embed: (q_split_n_embed + kv_split_n_embed)] = k_weight_ - self.qkv_step_ += 1 + k_weight_ = k_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] + self.k_weight_ = k_weight_.transpose(0, 1).to(self.data_type_) if f"model.layers.{self.layer_num_}.self_attn.v_proj.weight" in weights: v_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.weight"] - v_weight_ = v_weight_[kv_split_n_embed * self.tp_rank_: kv_split_n_embed * (self.tp_rank_ + 1), :] - v_weight_ = v_weight_.transpose(0, 1).to(self.data_type_) - self.qkv_weight_[:, (q_split_n_embed + kv_split_n_embed):(q_split_n_embed + 2 * kv_split_n_embed)] = v_weight_ - self.qkv_step_ += 1 - - if self.qkv_step_ == 3: - self.qkv_step_ = 0 - self.qkv_weight_ = self.quantize_weight(self.qkv_weight_) + v_weight_ = v_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] + self.v_weight_ = v_weight_.transpose(0, 1).to(self.data_type_) + + self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1, handle_func=self.quantize_weight) # attention output dense params if f"model.layers.{self.layer_num_}.self_attn.o_proj.weight" in weights: self.o_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.weight"] - self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_: q_split_n_embed * (self.tp_rank_ + 1)] + self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1)] self.o_weight_ = self.quantize_weight(self.o_weight_.transpose(0, 1)) - + return - + def _load_ffn_weights(self, weights): if f"model.layers.{self.layer_num_}.post_attention_layernorm.weight" in weights: - self.ffn_norm_weight_ = self._cuda(weights[f"model.layers.{self.layer_num_}.post_attention_layernorm.weight"]) - - n_embed = self.network_config_["hidden_size"] - inter_size = self.network_config_['intermediate_size'] - split_inter_size = inter_size // self.world_size_ + self.ffn_norm_weight_ = self._cuda( + 
weights[f"model.layers.{self.layer_num_}.post_attention_layernorm.weight"] + ) - if getattr(self, "gate_up_proj", None) is None: - self.gate_up_proj = torch.empty(n_embed, split_inter_size * 2, dtype=self.data_type_, device='cpu') - self.gate_up_step = 0 + # n_embed = self.network_config_["hidden_size"] + inter_size = self.network_config_["intermediate_size"] + split_inter_size = inter_size // self.world_size_ if f"model.layers.{self.layer_num_}.mlp.gate_proj.weight" in weights: - gate_proj = weights[f"model.layers.{self.layer_num_}.mlp.gate_proj.weight"][split_inter_size * - self.tp_rank_: split_inter_size * (self.tp_rank_ + 1), :] - gate_proj = gate_proj.transpose(0, 1).to(self.data_type_) - self.gate_up_proj[:, : split_inter_size] = gate_proj - self.gate_up_step += 1 + gate_proj = weights[f"model.layers.{self.layer_num_}.mlp.gate_proj.weight"][ + split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : + ] + self.gate_proj = gate_proj.transpose(0, 1).to(self.data_type_) if f"model.layers.{self.layer_num_}.mlp.up_proj.weight" in weights: - up_proj = weights[f"model.layers.{self.layer_num_}.mlp.up_proj.weight"][split_inter_size * - self.tp_rank_: split_inter_size * (self.tp_rank_ + 1), :] - up_proj = up_proj.transpose(0, 1).to(self.data_type_) - self.gate_up_proj[:, split_inter_size : split_inter_size * 2] = up_proj - self.gate_up_step += 1 + up_proj = weights[f"model.layers.{self.layer_num_}.mlp.up_proj.weight"][ + split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : + ] + self.up_proj = up_proj.transpose(0, 1).to(self.data_type_) - if self.gate_up_step == 2: - self.gate_up_step = 0 - self.gate_up_proj = self.quantize_weight(self.gate_up_proj) + self._try_cat_to(["gate_proj", "up_proj"], "gate_up_proj", cat_dim=1, handle_func=self.quantize_weight) if f"model.layers.{self.layer_num_}.mlp.down_proj.weight" in weights: - self.down_proj = weights[f"model.layers.{self.layer_num_}.mlp.down_proj.weight"][:, - split_inter_size * self.tp_rank_: split_inter_size * (self.tp_rank_ + 1)] + self.down_proj = weights[f"model.layers.{self.layer_num_}.mlp.down_proj.weight"][ + :, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1) + ] self.down_proj = self.quantize_weight(self.down_proj.transpose(0, 1)) return diff --git a/lightllm/models/mistral/layer_infer/transformer_layer_infer.py b/lightllm/models/mistral/layer_infer/transformer_layer_infer.py old mode 100644 new mode 100755 index 9d9997c34..4eb2ffa7b --- a/lightllm/models/mistral/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/mistral/layer_infer/transformer_layer_infer.py @@ -29,11 +29,11 @@ def _bind_func(self): self._copy_kv_to_mem_cache = self._copy_kv_to_mem_cache_normal return - def _context_attention_kernel(self, q, k, v, infer_state:MistralInferStateInfo, layer_weight, out=None)->torch.Tensor: + def _context_attention_kernel(self, q, kv, infer_state:MistralInferStateInfo, layer_weight, out=None)->torch.Tensor: o_tensor = torch.empty_like(q) if out is None else out context_attention_fwd(q.view(-1, self.tp_q_head_num_, self.head_dim_), - k.view(-1, self.tp_k_head_num_, self.head_dim_), - v.view(-1, self.tp_v_head_num_, self.head_dim_), + kv[:, 0: self.tp_k_head_num_, :], + kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), infer_state.b_start_loc, infer_state.b_seq_len, @@ -49,7 +49,7 @@ def _token_decode_attention_normal(self, q, infer_state: MistralInferStateInfo, att_m_tensor = 
torch.empty((self.tp_q_head_num_, total_token_num), dtype=q.dtype, device="cuda") token_att_fwd(q.view(calcu_shape1), - infer_state.mem_manager.key_buffer[self.layer_num_], + infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0: self.tp_k_head_num_, :], att_m_tensor, infer_state.req_manager.req_to_token_indexs, infer_state.b_req_idx, @@ -67,7 +67,7 @@ def _token_decode_attention_normal(self, q, infer_state: MistralInferStateInfo, token_softmax_fwd(att_m_tensor, infer_state.b_att_start_loc, infer_state.b_att_seq_len, prob, infer_state.sliding_window) att_m_tensor = None token_att_fwd2(prob, - infer_state.mem_manager.value_buffer[self.layer_num_], + infer_state.mem_manager.kv_buffer[self.layer_num_][:, self.tp_k_head_num_: self.tp_k_head_num_+ self.tp_v_head_num_, :], o_tensor.view(calcu_shape1), infer_state.req_manager.req_to_token_indexs, infer_state.b_req_idx, @@ -81,7 +81,7 @@ def _token_decode_attention_normal(self, q, infer_state: MistralInferStateInfo, elif triton.__version__ >= "2.1.0": from lightllm.models.mistral.triton_kernel.token_attention_softmax_and_reducev import token_softmax_reducev_fwd token_softmax_reducev_fwd(att_m_tensor, - infer_state.mem_manager.value_buffer[self.layer_num_], + infer_state.mem_manager.kv_buffer[self.layer_num_][:, self.tp_k_head_num_: self.tp_k_head_num_+ self.tp_v_head_num_, :], o_tensor.view(calcu_shape1), infer_state.req_manager.req_to_token_indexs, infer_state.b_req_idx, diff --git a/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py b/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py index aa6711d0a..14ebd6e98 100644 --- a/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py @@ -6,28 +6,41 @@ logger = init_logger(__name__) + class MixtralTransformerLayerWeight(TransformerLayerWeight): def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[]): - super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode) - self.experts = [{"w1": None, "w2": None, "w3": None} for _ in range(self.network_config_['num_local_experts'])] - return + super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode) + self.experts = [{"w1": None, "w2": None, "w3": None} for _ in range(self.network_config_["num_local_experts"])] + return def load_hf_weights(self, weights): self._load_qkvo_weights(weights) self._load_ffn_weights(weights) return - + def verify_load(self): errors = "weights load not ok" - weights = [self.att_norm_weight_, self.q_weight_, self.k_weight_, self.v_weight_, - self.o_weight_, self.ffn_norm_weight_, self.gate] + weights = [ + self.att_norm_weight_, + self.q_weight_, + self.kv_weight_, + self.o_weight_, + self.ffn_norm_weight_, + self.gate, + ] for i in range(len(weights)): assert weights[i] is not None, str(i) + " " + str(self.layer_num_) + " " + errors for i in range(self.network_config_["num_local_experts"]): - assert self.experts[i]["w1"] is not None, "layer " + str(self.layer_num_) + " expert " + str(i) + " w1 " + errors - assert self.experts[i]["w2"] is not None, "layer " + str(self.layer_num_) + " expert " + str(i) + " w2 " + errors - assert self.experts[i]["w3"] is not None, "layer " + str(self.layer_num_) + " expert " + str(i) + " w3 " + errors - return + assert self.experts[i]["w1"] is not None, ( + "layer " + str(self.layer_num_) + " expert " + str(i) + " w1 " + errors + ) + assert self.experts[i]["w2"] is not None, ( + "layer " + str(self.layer_num_) + " 
expert " + str(i) + " w2 " + errors + ) + assert self.experts[i]["w3"] is not None, ( + "layer " + str(self.layer_num_) + " expert " + str(i) + " w3 " + errors + ) + return def _load_qkvo_weights(self, weights): # input layernorm params @@ -36,53 +49,65 @@ def _load_qkvo_weights(self, weights): n_embed = self.network_config_["hidden_size"] q_split_n_embed = n_embed // self.world_size_ - kv_split_n_embed = n_embed // self.network_config_["num_attention_heads"] * self.network_config_["num_key_value_heads"] // self.world_size_ + kv_split_n_embed = ( + n_embed + // self.network_config_["num_attention_heads"] + * self.network_config_["num_key_value_heads"] + // self.world_size_ + ) # q k v weights for llama if f"model.layers.{self.layer_num_}.self_attn.q_proj.weight" in weights: self.q_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.weight"] - self.q_weight_ = self.q_weight_[q_split_n_embed * self.tp_rank_: q_split_n_embed * (self.tp_rank_ + 1), :] + self.q_weight_ = self.q_weight_[q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1), :] self.q_weight_ = self._cuda(self.q_weight_.transpose(0, 1)) if f"model.layers.{self.layer_num_}.self_attn.k_proj.weight" in weights: - self.k_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.weight"] - self.k_weight_ = self.k_weight_[kv_split_n_embed * self.tp_rank_: kv_split_n_embed * (self.tp_rank_ + 1), :] - self.k_weight_ = self._cuda(self.k_weight_.transpose(0, 1)) + k_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.weight"] + k_weight_ = k_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] + self.k_weight_ = k_weight_.transpose(0, 1) if f"model.layers.{self.layer_num_}.self_attn.v_proj.weight" in weights: - self.v_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.weight"] - self.v_weight_ = self.v_weight_[kv_split_n_embed * self.tp_rank_: kv_split_n_embed * (self.tp_rank_ + 1), :] - self.v_weight_ = self._cuda(self.v_weight_.transpose(0, 1)) - + v_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.weight"] + v_weight_ = v_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] + self.v_weight_ = v_weight_.transpose(0, 1) + + self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1) + # attention output dense params if f"model.layers.{self.layer_num_}.self_attn.o_proj.weight" in weights: self.o_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.weight"] - self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_: q_split_n_embed * (self.tp_rank_ + 1)] + self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1)] self.o_weight_ = self._cuda(self.o_weight_.transpose(0, 1)) return - + def _load_ffn_weights(self, weights): if f"model.layers.{self.layer_num_}.post_attention_layernorm.weight" in weights: - self.ffn_norm_weight_ = self._cuda(weights[f"model.layers.{self.layer_num_}.post_attention_layernorm.weight"]) - - inter_size = self.network_config_['intermediate_size'] + self.ffn_norm_weight_ = self._cuda( + weights[f"model.layers.{self.layer_num_}.post_attention_layernorm.weight"] + ) + + inter_size = self.network_config_["intermediate_size"] split_inter_size = inter_size // self.world_size_ if f"model.layers.{self.layer_num_}.block_sparse_moe.gate.weight" in weights: self.gate = weights[f"model.layers.{self.layer_num_}.block_sparse_moe.gate.weight"] self.gate = self._cuda(self.gate.transpose(0, 
1)) - for expert_idx in range(self.network_config_['num_local_experts']): + for expert_idx in range(self.network_config_["num_local_experts"]): if f"model.layers.{self.layer_num_}.block_sparse_moe.experts.{expert_idx}.w1.weight" in weights: - self.experts[expert_idx]["w1"] = weights[f"model.layers.{self.layer_num_}.block_sparse_moe.experts.{expert_idx}.w1.weight"][split_inter_size * - self.tp_rank_: split_inter_size * (self.tp_rank_ + 1), :] + self.experts[expert_idx]["w1"] = weights[ + f"model.layers.{self.layer_num_}.block_sparse_moe.experts.{expert_idx}.w1.weight" + ][split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), :] self.experts[expert_idx]["w1"] = self._cuda(self.experts[expert_idx]["w1"].transpose(0, 1)) if f"model.layers.{self.layer_num_}.block_sparse_moe.experts.{expert_idx}.w2.weight" in weights: - self.experts[expert_idx]["w2"] = weights[f"model.layers.{self.layer_num_}.block_sparse_moe.experts.{expert_idx}.w2.weight"][:, split_inter_size * - self.tp_rank_: split_inter_size * (self.tp_rank_ + 1)] + self.experts[expert_idx]["w2"] = weights[ + f"model.layers.{self.layer_num_}.block_sparse_moe.experts.{expert_idx}.w2.weight" + ][:, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1)] self.experts[expert_idx]["w2"] = self._cuda(self.experts[expert_idx]["w2"].transpose(0, 1)) - + if f"model.layers.{self.layer_num_}.block_sparse_moe.experts.{expert_idx}.w3.weight" in weights: - self.experts[expert_idx]["w3"] = weights[f"model.layers.{self.layer_num_}.block_sparse_moe.experts.{expert_idx}.w3.weight"][split_inter_size * - self.tp_rank_: split_inter_size * (self.tp_rank_ + 1), :] + self.experts[expert_idx]["w3"] = weights[ + f"model.layers.{self.layer_num_}.block_sparse_moe.experts.{expert_idx}.w3.weight" + ][split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), :] self.experts[expert_idx]["w3"] = self._cuda(self.experts[expert_idx]["w3"].transpose(0, 1)) return diff --git a/lightllm/models/qwen/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen/layer_infer/transformer_layer_infer.py old mode 100644 new mode 100755 index 566bef291..a4c5313b9 --- a/lightllm/models/qwen/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen/layer_infer/transformer_layer_infer.py @@ -5,26 +5,35 @@ from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd -from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight +from lightllm.models.qwen.layer_weights.transformer_layer_weight import QwenTransformerLayerWeight from lightllm.models.qwen.infer_struct import QwenInferStateInfo + class QwenTransformerLayerInfer(LlamaTransformerLayerInfer): - """ - """ + """ """ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): super().__init__(layer_num, tp_rank, world_size, network_config, mode) return - - def _get_qkv(self, input_emb, cache_k, cache_v, infer_state: QwenInferStateInfo, layer_weight:LlamaTransformerLayerWeight): - q = torch.addmm(layer_weight.q_bias_, input_emb.view(-1, self.embed_dim_), layer_weight.q_weight_, beta=1.0, alpha=1.0) - rotary_emb_fwd(q.view(-1, self.tp_q_head_num_, self.head_dim_), infer_state.position_cos, infer_state.position_sin) + + def _get_qkv(self, input_emb, cache_kv, infer_state: QwenInferStateInfo, layer_weight: QwenTransformerLayerWeight): + q = torch.addmm( + layer_weight.q_bias_, input_emb.view(-1, 
self.embed_dim_), layer_weight.q_weight_, beta=1.0, alpha=1.0 + ) + torch.addmm( + layer_weight.kv_bias_, + input_emb.view(-1, self.embed_dim_), + layer_weight.kv_weight_, + beta=1.0, + alpha=1.0, + out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_), + ) + rotary_emb_fwd( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + cache_kv[:, 0 : self.tp_k_head_num_, :], + infer_state.position_cos, + infer_state.position_sin, + ) if infer_state.logn_values is not None: q.mul_(infer_state.logn_values.view(-1, 1)) - torch.addmm(layer_weight.k_bias_, input_emb.view(-1, self.embed_dim_), layer_weight.k_weight_, beta=1.0, alpha=1.0, - out=cache_k.view(-1, self.tp_k_head_num_ * self.head_dim_)) - rotary_emb_fwd(cache_k, infer_state.position_cos, infer_state.position_sin) - torch.addmm(layer_weight.v_bias_, input_emb.view(-1, self.embed_dim_), layer_weight.v_weight_, beta=1.0, alpha=1.0, - out=cache_v.view(-1, self.tp_v_head_num_ * self.head_dim_)) - return q, cache_k, cache_v - + return q, cache_kv diff --git a/lightllm/models/qwen/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen/layer_weights/transformer_layer_weight.py old mode 100644 new mode 100755 index bd285fec8..039d15160 --- a/lightllm/models/qwen/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen/layer_weights/transformer_layer_weight.py @@ -21,65 +21,73 @@ def load_hf_weights(self, weights): split_size = qkv_weights.shape[0] // 3 q_weights, k_weights, v_weights = torch.split(qkv_weights, split_size, dim=0) - self.q_weight_ = q_weights[split_n_embed * self.tp_rank_: split_n_embed * (self.tp_rank_ + 1), :] + self.q_weight_ = q_weights[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :] self.q_weight_ = self._cuda(self.q_weight_.transpose(0, 1)) - self.k_weight_ = k_weights[split_n_embed * self.tp_rank_: split_n_embed * (self.tp_rank_ + 1), :] - self.k_weight_ = self._cuda(self.k_weight_.transpose(0, 1)) - self.v_weight_ = v_weights[split_n_embed * self.tp_rank_: split_n_embed * (self.tp_rank_ + 1), :] - self.v_weight_ = self._cuda(self.v_weight_.transpose(0, 1)) - + k_weight_ = k_weights[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :] + self.k_weight_ = k_weight_.transpose(0, 1) + v_weight_ = v_weights[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :] + self.v_weight_ = v_weight_.transpose(0, 1) + + self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1) + if f"transformer.h.{self.layer_num_}.attn.c_attn.bias" in weights: qkv_bias = weights[f"transformer.h.{self.layer_num_}.attn.c_attn.bias"] split_size = qkv_bias.shape[0] // 3 q_bias, k_bias, v_bias = torch.split(qkv_bias, split_size, dim=0) - self.q_bias_ = self._cuda(q_bias[split_n_embed * self.tp_rank_: split_n_embed * (self.tp_rank_ + 1)]) - self.k_bias_ = self._cuda(k_bias[split_n_embed * self.tp_rank_: split_n_embed * (self.tp_rank_ + 1)]) - self.v_bias_ = self._cuda(v_bias[split_n_embed * self.tp_rank_: split_n_embed * (self.tp_rank_ + 1)]) + self.q_bias_ = self._cuda(q_bias[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1)]) + self.k_bias_ = k_bias[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1)] + self.v_bias_ = v_bias[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1)] + + self._try_cat_to(["k_bias_", "v_bias_"], "kv_bias_", cat_dim=0) # attention output dense params if f"transformer.h.{self.layer_num_}.attn.c_proj.weight" in weights: - self.o_weight_ = 
weights[f"transformer.h.{self.layer_num_}.attn.c_proj.weight"][:, - split_n_embed * self.tp_rank_: split_n_embed * (self.tp_rank_ + 1)] + self.o_weight_ = weights[f"transformer.h.{self.layer_num_}.attn.c_proj.weight"][ + :, split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1) + ] self.o_weight_ = self._cuda(self.o_weight_.transpose(0, 1)) if f"transformer.h.{self.layer_num_}.ln_2.weight" in weights: self.ffn_norm_weight_ = self._cuda(weights[f"transformer.h.{self.layer_num_}.ln_2.weight"]) # ffn params - inter_size = self.network_config_['intermediate_size'] // 2 + inter_size = self.network_config_["intermediate_size"] // 2 split_inter_size = inter_size // self.world_size_ if f"transformer.h.{self.layer_num_}.mlp.w1.weight" in weights: - self.up_proj = weights[f"transformer.h.{self.layer_num_}.mlp.w1.weight"][split_inter_size * - self.tp_rank_: split_inter_size * (self.tp_rank_ + 1), :] - self.up_proj = self._cuda(self.up_proj.transpose(0, 1)) + up_proj = weights[f"transformer.h.{self.layer_num_}.mlp.w1.weight"][ + split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : + ] + self.up_proj = up_proj.transpose(0, 1) if f"transformer.h.{self.layer_num_}.mlp.w2.weight" in weights: - self.gate_proj = weights[f"transformer.h.{self.layer_num_}.mlp.w2.weight"][split_inter_size * - self.tp_rank_: split_inter_size * (self.tp_rank_ + 1), :] - self.gate_proj = self._cuda(self.gate_proj.transpose(0, 1)) + gate_proj = weights[f"transformer.h.{self.layer_num_}.mlp.w2.weight"][ + split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : + ] + self.gate_proj = gate_proj.transpose(0, 1) if f"transformer.h.{self.layer_num_}.mlp.c_proj.weight" in weights: - self.down_proj = weights[f"transformer.h.{self.layer_num_}.mlp.c_proj.weight"][:, - split_inter_size * self.tp_rank_: split_inter_size * (self.tp_rank_ + 1)] + self.down_proj = weights[f"transformer.h.{self.layer_num_}.mlp.c_proj.weight"][ + :, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1) + ] self.down_proj = self._cuda(self.down_proj.transpose(0, 1)) - + + self._try_cat_to(["gate_proj", "up_proj"], "gate_up_proj", cat_dim=1) + return - + def verify_load(self): errors = "weights load not ok" - weights = [self.att_norm_weight_, - self.q_weight_, - self.k_weight_, - self.v_weight_, - self.q_bias_, - self.k_bias_, - self.v_bias_, - self.o_weight_, - self.ffn_norm_weight_, - self.up_proj, - self.gate_proj, - self.down_proj - ] + weights = [ + self.att_norm_weight_, + self.q_weight_, + self.kv_weight_, + self.q_bias_, + self.kv_bias_, + self.o_weight_, + self.ffn_norm_weight_, + self.gate_up_proj, + self.down_proj, + ] for i in range(len(weights)): assert weights[i] is not None, "index:" + str(i) + " " + errors - return + return diff --git a/lightllm/models/qwen_wquant/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen_wquant/layer_infer/transformer_layer_infer.py old mode 100644 new mode 100755 index 083da8889..4060d1b2f --- a/lightllm/models/qwen_wquant/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen_wquant/layer_infer/transformer_layer_infer.py @@ -8,32 +8,36 @@ from lightllm.models.qwen_wquant.layer_weights.transformer_layer_weight import QwenTransformerLayerWeightQuantized from lightllm.models.qwen.infer_struct import QwenInferStateInfo + class QwenTransformerLayerInferWQuant(LlamaTransformerLayerInferWquant): - """ - """ + """ """ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): super().__init__(layer_num, tp_rank, 
world_size, network_config, mode) - self.inter_dim_ = network_config['intermediate_size'] // 2 # qwen 的 inter_dim 要 // 2 + self.inter_dim_ = network_config["intermediate_size"] // 2 # qwen 的 inter_dim 要 // 2 return - - def _get_qkv(self, input, cache_k, cache_v, infer_state: QwenInferStateInfo, layer_weight: QwenTransformerLayerWeightQuantized): - qkv_output = self._wquant_matmul_for_qkv(input.view(-1, self.embed_dim_), - quant_weight_params=layer_weight.qkv_weight_, - infer_state=infer_state, - bias=layer_weight.qkv_bias_) - - tp_k_head_dim = self.tp_k_head_num_ * self.head_dim_ - q = qkv_output[:, : -2 * tp_k_head_dim] - k = qkv_output[:, -2 * tp_k_head_dim: -tp_k_head_dim] - v = qkv_output[:, -tp_k_head_dim :] + def _get_qkv( + self, input, cache_kv, infer_state: QwenInferStateInfo, layer_weight: QwenTransformerLayerWeightQuantized + ): + q = self._wquant_matmul_for_qkv( + input.view(-1, self.embed_dim_), + quant_weight_params=layer_weight.q_weight_, + infer_state=infer_state, + bias=layer_weight.q_bias_, + ) + cache_kv = self._wquant_matmul_for_qkv( + input.view(-1, self.embed_dim_), + quant_weight_params=layer_weight.kv_weight_, + infer_state=infer_state, + bias=layer_weight.kv_bias_, + ).view(-1, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_) + rotary_emb_fwd( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + cache_kv[:, 0 : self.tp_k_head_num_, :], + infer_state.position_cos, + infer_state.position_sin, + ) if infer_state.logn_values is not None: q.mul_(infer_state.logn_values.view(-1, 1)) - - rotary_emb_fwd(q.view(-1, self.tp_q_head_num_, self.head_dim_), infer_state.position_cos, infer_state.position_sin) - cache_k_ = k.view(-1, self.tp_k_head_num_, self.head_dim_) - rotary_emb_fwd(cache_k_, infer_state.position_cos, infer_state.position_sin) - cache_v_ = v.view(-1, self.tp_v_head_num_, self.head_dim_) - return q, cache_k_, cache_v_ - + return q, cache_kv diff --git a/lightllm/models/qwen_wquant/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen_wquant/layer_weights/transformer_layer_weight.py index b43a48b42..9a7cf3c2c 100644 --- a/lightllm/models/qwen_wquant/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen_wquant/layer_weights/transformer_layer_weight.py @@ -5,12 +5,13 @@ from lightllm.common.basemodel import TransformerLayerWeight from lightllm.models.llama_wquant.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeightQuantized + class QwenTransformerLayerWeightQuantized(TransformerLayerWeight): def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[]): super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode) LlamaTransformerLayerWeightQuantized.init_quant_mode(self) return - + def load_hf_weights(self, weights): # input layernorm params if f"transformer.h.{self.layer_num_}.ln_1.weight" in weights: @@ -24,72 +25,77 @@ def load_hf_weights(self, weights): split_size = qkv_weights.shape[0] // 3 q_weights, k_weights, v_weights = torch.split(qkv_weights, split_size, dim=0) - q_weight_ = q_weights[split_n_embed * self.tp_rank_: split_n_embed * (self.tp_rank_ + 1), :] - k_weight_ = k_weights[split_n_embed * self.tp_rank_: split_n_embed * (self.tp_rank_ + 1), :] - v_weight_ = v_weights[split_n_embed * self.tp_rank_: split_n_embed * (self.tp_rank_ + 1), :] + self.q_weight_ = q_weights[ + split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), : + ].transpose(0, 1) + self.k_weight_ = k_weights[ + split_n_embed * self.tp_rank_ : split_n_embed * 
(self.tp_rank_ + 1), : + ].transpose(0, 1) + self.v_weight_ = v_weights[ + split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), : + ].transpose(0, 1) + + self.q_weight_ = self.quantize_weight(self.q_weight_) + + self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1, handle_func=self.quantize_weight) - qkv_weight_ = torch.cat([q_weight_, k_weight_, v_weight_], dim=0) - self.qkv_weight_ = self.quantize_weight(qkv_weight_.transpose(0, 1)) - if f"transformer.h.{self.layer_num_}.attn.c_attn.bias" in weights: qkv_bias = weights[f"transformer.h.{self.layer_num_}.attn.c_attn.bias"] split_size = qkv_bias.shape[0] // 3 q_bias, k_bias, v_bias = torch.split(qkv_bias, split_size, dim=0) - q_bias_ = q_bias[split_n_embed * self.tp_rank_: split_n_embed * (self.tp_rank_ + 1)] - k_bias_ = k_bias[split_n_embed * self.tp_rank_: split_n_embed * (self.tp_rank_ + 1)] - v_bias_ = v_bias[split_n_embed * self.tp_rank_: split_n_embed * (self.tp_rank_ + 1)] - self.qkv_bias_ = self._cuda(torch.cat([q_bias_, k_bias_, v_bias_], dim=0).view(-1)) + self.q_bias_ = self._cuda(q_bias[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1)]) + self.k_bias_ = k_bias[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1)] + self.v_bias_ = v_bias[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1)] + + self._try_cat_to(["k_bias_", "v_bias_"], "kv_bias_", cat_dim=0) # attention output dense params if f"transformer.h.{self.layer_num_}.attn.c_proj.weight" in weights: - self.o_weight_ = weights[f"transformer.h.{self.layer_num_}.attn.c_proj.weight"][:, split_n_embed * self.tp_rank_: split_n_embed * (self.tp_rank_ + 1)] + self.o_weight_ = weights[f"transformer.h.{self.layer_num_}.attn.c_proj.weight"][ + :, split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1) + ] self.o_weight_ = self.quantize_weight(self.o_weight_.transpose(0, 1)) if f"transformer.h.{self.layer_num_}.ln_2.weight" in weights: self.ffn_norm_weight_ = self._cuda(weights[f"transformer.h.{self.layer_num_}.ln_2.weight"]) # ffn params - inter_size = self.network_config_['intermediate_size'] // 2 + inter_size = self.network_config_["intermediate_size"] // 2 split_inter_size = inter_size // self.world_size_ - if getattr(self, "gate_up_proj", None) is None: - self.gate_up_proj = torch.empty(n_embed, split_inter_size * 2, dtype=self.data_type_, device='cpu') - self.gate_up_step = 0 - if f"transformer.h.{self.layer_num_}.mlp.w2.weight" in weights: gate_proj = weights[f"transformer.h.{self.layer_num_}.mlp.w2.weight"] - gate_proj = gate_proj[split_inter_size * self.tp_rank_: split_inter_size * (self.tp_rank_ + 1), :] - gate_proj = gate_proj.transpose(0, 1).to(self.data_type_) - self.gate_up_proj[:, : split_inter_size] = gate_proj - self.gate_up_step += 1 + gate_proj = gate_proj[split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), :] + self.gate_proj = gate_proj.transpose(0, 1).to(self.data_type_) if f"transformer.h.{self.layer_num_}.mlp.w1.weight" in weights: up_proj = weights[f"transformer.h.{self.layer_num_}.mlp.w1.weight"] - up_proj = up_proj[split_inter_size * self.tp_rank_: split_inter_size * (self.tp_rank_ + 1), :] - up_proj = up_proj.transpose(0, 1).to(self.data_type_) - self.gate_up_proj[:, split_inter_size : split_inter_size * 2] = up_proj - self.gate_up_step += 1 - - if self.gate_up_step == 2: - self.gate_up_step = 0 - self.gate_up_proj = self.quantize_weight(self.gate_up_proj) + up_proj = up_proj[split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ 
+ 1), :] + self.up_proj = up_proj.transpose(0, 1).to(self.data_type_) + + self._try_cat_to(["gate_proj", "up_proj"], "gate_up_proj", cat_dim=1, handle_func=self.quantize_weight) if f"transformer.h.{self.layer_num_}.mlp.c_proj.weight" in weights: self.down_proj = weights[f"transformer.h.{self.layer_num_}.mlp.c_proj.weight"] - self.down_proj = self.down_proj[:,split_inter_size * self.tp_rank_: split_inter_size * (self.tp_rank_ + 1)] + self.down_proj = self.down_proj[ + :, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1) + ] self.down_proj = self.quantize_weight(self.down_proj.transpose(0, 1)) - + return - + def verify_load(self): errors = "weights load not ok" - weights = [self.att_norm_weight_, - self.qkv_weight_, - self.qkv_bias_, - self.o_weight_, - self.ffn_norm_weight_, - self.gate_up_proj, - self.down_proj - ] + weights = [ + self.att_norm_weight_, + self.q_weight_, + self.kv_weight_, + self.q_bias_, + self.kv_bias_, + self.o_weight_, + self.ffn_norm_weight_, + self.gate_up_proj, + self.down_proj, + ] for i in range(len(weights)): assert weights[i] is not None, "index:" + str(i) + " " + errors - return + return diff --git a/lightllm/models/starcoder/layer_weights/transformer_layer_weight.py b/lightllm/models/starcoder/layer_weights/transformer_layer_weight.py index 3a23db639..913e3e6cb 100644 --- a/lightllm/models/starcoder/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/starcoder/layer_weights/transformer_layer_weight.py @@ -8,7 +8,7 @@ class StarcoderTransformerLayerWeight(BloomTransformerLayerWeight): def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[]): super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode) assert network_config["num_attention_heads"] % self.world_size_ == 0 - + def init_static_params(self): pass @@ -25,27 +25,38 @@ def _load_qkvo_weights(self, weights): head_dim = self.network_config_["hidden_size"] // self.network_config_["num_attention_heads"] split_n_embed = n_embed // self.world_size_ if f"transformer.h.{self.layer_num_}.attn.c_attn.weight" in weights: - self.qkv_weight_ = weights[f"transformer.h.{self.layer_num_}.attn.c_attn.weight"].transpose(0, 1).contiguous().to(self.data_type_) - self.q_weight_ = self.qkv_weight_[:, :n_embed][:, split_n_embed * self.tp_rank_: split_n_embed * (self.tp_rank_ + 1)] + qkv_weight_ = ( + weights[f"transformer.h.{self.layer_num_}.attn.c_attn.weight"] + .transpose(0, 1) + .contiguous() + .to(self.data_type_) + ) + self.q_weight_ = qkv_weight_[:, :n_embed][ + :, split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1) + ] self.q_weight_ = self.q_weight_.cuda() - self.k_weight_ = self.qkv_weight_[:, n_embed:n_embed + head_dim] - self.k_weight_ = self.k_weight_.cuda() + self.k_weight_ = qkv_weight_[:, n_embed : n_embed + head_dim] + self.v_weight_ = qkv_weight_[:, n_embed + head_dim : n_embed + 2 * head_dim] - self.v_weight_ = self.qkv_weight_[:, n_embed + head_dim:n_embed + 2 * head_dim] - self.v_weight_ = self.v_weight_.cuda() + self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1) if f"transformer.h.{self.layer_num_}.attn.c_attn.bias" in weights: - self.qkv_bias_ = weights[f"transformer.h.{self.layer_num_}.attn.c_attn.bias"].to(self.data_type_) - self.q_bias_ = self.qkv_bias_[:n_embed].cuda()[split_n_embed * self.tp_rank_: split_n_embed * (self.tp_rank_ + 1)] - self.k_bias_ = self.qkv_bias_[n_embed : n_embed + head_dim].cuda() - self.v_bias_ = self.qkv_bias_[n_embed + head_dim : n_embed + 2 * 
head_dim].cuda() + qkv_bias_ = weights[f"transformer.h.{self.layer_num_}.attn.c_attn.bias"].to(self.data_type_) + self.q_bias_ = self._cuda( + qkv_bias_[:n_embed][split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1)] + ) + self.k_bias_ = qkv_bias_[n_embed : n_embed + head_dim] + self.v_bias_ = qkv_bias_[n_embed + head_dim : n_embed + 2 * head_dim] + + self._try_cat_to(["k_bias_", "v_bias_"], "kv_bias_", cat_dim=0) # attention output dense params if f"transformer.h.{self.layer_num_}.attn.c_proj.weight" in weights: - self.o_weight_ = weights[f"transformer.h.{self.layer_num_}.attn.c_proj.weight"][:, - split_n_embed * self.tp_rank_: split_n_embed * (self.tp_rank_ + 1)] + self.o_weight_ = weights[f"transformer.h.{self.layer_num_}.attn.c_proj.weight"][ + :, split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1) + ] self.o_weight_ = self.o_weight_.transpose(0, 1).contiguous().to(self.data_type_) self.o_weight_ = self.o_weight_.cuda() @@ -55,11 +66,9 @@ def _load_qkvo_weights(self, weights): def _load_ffn_weights(self, weights): if f"transformer.h.{self.layer_num_}.ln_2.weight" in weights: - self.ffn_norm_weight_ = weights[f"transformer.h.{self.layer_num_}.ln_2.weight"].to( - self.data_type_).cuda() + self.ffn_norm_weight_ = weights[f"transformer.h.{self.layer_num_}.ln_2.weight"].to(self.data_type_).cuda() if f"transformer.h.{self.layer_num_}.ln_2.bias" in weights: - self.ffn_norm_bias_ = weights[f"transformer.h.{self.layer_num_}.ln_2.bias"].to( - self.data_type_).cuda() + self.ffn_norm_bias_ = weights[f"transformer.h.{self.layer_num_}.ln_2.bias"].to(self.data_type_).cuda() # ffn params n_embed = self.network_config_["hidden_size"] @@ -67,19 +76,35 @@ def _load_ffn_weights(self, weights): split_inter_size = intermediate_size // self.world_size_ if f"transformer.h.{self.layer_num_}.mlp.c_fc.weight" in weights: self.ffn_1_weight_ = weights[f"transformer.h.{self.layer_num_}.mlp.c_fc.weight"].to(self.data_type_) - self.ffn_1_weight_ = self.ffn_1_weight_[split_inter_size * self.tp_rank_: split_inter_size * - (self.tp_rank_ + 1), :].transpose(0, 1).contiguous().cuda() + self.ffn_1_weight_ = ( + self.ffn_1_weight_[split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), :] + .transpose(0, 1) + .contiguous() + .cuda() + ) if f"transformer.h.{self.layer_num_}.mlp.c_fc.bias" in weights: - self.ffn_1_bias_ = weights[f"transformer.h.{self.layer_num_}.mlp.c_fc.bias"][split_inter_size * - self.tp_rank_: split_inter_size * (self.tp_rank_ + 1)].to(self.data_type_).contiguous().cuda() + self.ffn_1_bias_ = ( + weights[f"transformer.h.{self.layer_num_}.mlp.c_fc.bias"][ + split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1) + ] + .to(self.data_type_) + .contiguous() + .cuda() + ) if f"transformer.h.{self.layer_num_}.mlp.c_proj.weight" in weights: self.ffn_2_weight_ = weights[f"transformer.h.{self.layer_num_}.mlp.c_proj.weight"].to(self.data_type_) - self.ffn_2_weight_ = self.ffn_2_weight_[:, split_inter_size * - self.tp_rank_: split_inter_size * (self.tp_rank_ + 1)].transpose(0, 1).contiguous().cuda() + self.ffn_2_weight_ = ( + self.ffn_2_weight_[:, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1)] + .transpose(0, 1) + .contiguous() + .cuda() + ) if f"transformer.h.{self.layer_num_}.mlp.c_proj.bias" in weights: - self.ffn_2_bias_ = weights[f"transformer.h.{self.layer_num_}.mlp.c_proj.bias"].to(self.data_type_).contiguous().cuda() + self.ffn_2_bias_ = ( + 
weights[f"transformer.h.{self.layer_num_}.mlp.c_proj.bias"].to(self.data_type_).contiguous().cuda() + ) return diff --git a/lightllm/models/starcoder_wquant/layer_infer/transformer_layer_infer.py b/lightllm/models/starcoder_wquant/layer_infer/transformer_layer_infer.py old mode 100644 new mode 100755 index c7fe06a17..745875f79 --- a/lightllm/models/starcoder_wquant/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/starcoder_wquant/layer_infer/transformer_layer_infer.py @@ -8,9 +8,14 @@ from lightllm.common.basemodel.triton_kernel.quantize_gemm_int8 import matmul_quantize_int8 from lightllm.common.basemodel.triton_kernel.dequantize_gemm_int8 import matmul_dequantize_int8 -from lightllm.common.basemodel.triton_kernel.dequantize_gemm_int4 import matmul_dequantize_int4_s1, matmul_dequantize_int4_s2, matmul_dequantize_int4_gptq -from lightllm.models.starcoder_wquant.layer_weights.transformer_layer_weight import \ - StarcoderTransformerLayerWeightQuantized +from lightllm.common.basemodel.triton_kernel.dequantize_gemm_int4 import ( + matmul_dequantize_int4_s1, + matmul_dequantize_int4_s2, + matmul_dequantize_int4_gptq, +) +from lightllm.models.starcoder_wquant.layer_weights.transformer_layer_weight import ( + StarcoderTransformerLayerWeightQuantized, +) from lightllm.utils.infer_utils import mark_cost_time from lightllm.models.starcoder.infer_struct import StarcoderInferStateInfo from lightllm.common.basemodel import TransformerLayerInferWeightQuantTpl @@ -20,8 +25,8 @@ class StarcoderTransformerLayerInferWQuant(TransformerLayerInferWeightQuantTpl): - """ - """ + """ """ + def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): super().__init__(layer_num, tp_rank, world_size, network_config, mode) self.eps_ = network_config["layer_norm_epsilon"] @@ -33,7 +38,7 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): self.embed_dim_ = network_config["n_embed"] self._bind_func() return - + def _bind_func(self): self._att_norm = partial(BloomTransformerLayerInfer._att_norm, self) self._ffn_norm = partial(BloomTransformerLayerInfer._ffn_norm, self) @@ -42,38 +47,46 @@ def _bind_func(self): LlamaTransformerLayerInferWquant._bind_attention(self) return - def _get_qkv(self, input, cache_k, cache_v, infer_state: StarcoderInferStateInfo, layer_weight: StarcoderTransformerLayerWeightQuantized) -> torch.Tensor: - qkv_output = self._wquant_matmul_for_qkv(input.view(-1, self.embed_dim_), - layer_weight.qkv_weight_, - infer_state=infer_state, - bias=layer_weight.qkv_bias_) - tp_k_head_dim = self.tp_k_head_num_ * self.head_dim_ - q = qkv_output[:, : -2 * tp_k_head_dim] - k = qkv_output[:, -2 * tp_k_head_dim: -tp_k_head_dim] - v = qkv_output[:, -tp_k_head_dim :] - - cache_k = k.view(-1, self.tp_k_head_num_, self.head_dim_) - cache_v = v.view(-1, self.tp_v_head_num_, self.head_dim_) - return q, cache_k, cache_v - - def _get_o(self, input, infer_state: StarcoderInferStateInfo, layer_weight: StarcoderTransformerLayerWeightQuantized) -> torch.Tensor: - o_output =self._wquant_matmul_for_o(input.view(-1, self.embed_dim_), - layer_weight.o_weight_, - infer_state=infer_state, - bias=layer_weight.o_bias_) + def _get_qkv( + self, + input, + cache_kv, + infer_state: StarcoderInferStateInfo, + layer_weight: StarcoderTransformerLayerWeightQuantized, + ) -> torch.Tensor: + q = self._wquant_matmul_for_qkv( + input.view(-1, self.embed_dim_), layer_weight.q_weight_, infer_state=infer_state, bias=layer_weight.q_bias_ + ) + cache_kv = self._wquant_matmul_for_qkv( + 
input.view(-1, self.embed_dim_), + layer_weight.kv_weight_, + infer_state=infer_state, + bias=layer_weight.kv_bias_, + ).view(-1, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_) + return q, cache_kv + + def _get_o( + self, input, infer_state: StarcoderInferStateInfo, layer_weight: StarcoderTransformerLayerWeightQuantized + ) -> torch.Tensor: + o_output = self._wquant_matmul_for_o( + input.view(-1, self.embed_dim_), layer_weight.o_weight_, infer_state=infer_state, bias=layer_weight.o_bias_ + ) return o_output - def _ffn(self, input, infer_state:StarcoderInferStateInfo, layer_weight: StarcoderTransformerLayerWeightQuantized)->torch.Tensor: - ffn1_out = self._wquant_matmul_for_ffn_up(input.view(-1, self.embed_dim_), - layer_weight.ffn_1_weight_, - infer_state=infer_state, - bias=layer_weight.ffn_1_bias_) + def _ffn( + self, input, infer_state: StarcoderInferStateInfo, layer_weight: StarcoderTransformerLayerWeightQuantized + ) -> torch.Tensor: + ffn1_out = self._wquant_matmul_for_ffn_up( + input.view(-1, self.embed_dim_), + layer_weight.ffn_1_weight_, + infer_state=infer_state, + bias=layer_weight.ffn_1_bias_, + ) input = None - gelu_out = torch.nn.functional.gelu(ffn1_out, approximate='tanh') + gelu_out = torch.nn.functional.gelu(ffn1_out, approximate="tanh") ffn1_out = None - ffn2_out = self._wquant_matmul_for_ffn_down(gelu_out, - layer_weight.ffn_2_weight_, - infer_state=infer_state, - bias=layer_weight.ffn_2_bias_) + ffn2_out = self._wquant_matmul_for_ffn_down( + gelu_out, layer_weight.ffn_2_weight_, infer_state=infer_state, bias=layer_weight.ffn_2_bias_ + ) gelu_out = None - return ffn2_out \ No newline at end of file + return ffn2_out diff --git a/lightllm/models/starcoder_wquant/layer_weights/transformer_layer_weight.py b/lightllm/models/starcoder_wquant/layer_weights/transformer_layer_weight.py index 3fffbf378..32aa5e46e 100644 --- a/lightllm/models/starcoder_wquant/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/starcoder_wquant/layer_weights/transformer_layer_weight.py @@ -7,33 +7,37 @@ from lightllm.common.basemodel.triton_kernel.dequantize_gemm_int4 import quantize_int4 from lightllm.common.basemodel import TransformerLayerWeight from lightllm.models.llama_wquant.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeightQuantized + + class StarcoderTransformerLayerWeightQuantized(TransformerLayerWeight): def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[]): super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode) assert network_config["num_attention_heads"] % self.world_size_ == 0 LlamaTransformerLayerWeightQuantized.init_quant_mode(self) return - + def load_hf_weights(self, weights): self._load_qkvo_weights(weights) self._load_ffn_weights(weights) def verify_load(self): errors = "weights load not ok" - weights = [self.att_norm_weight_, - self.att_norm_bias_, - self.qkv_weight_, - self.qkv_bias_, - self.o_weight_, - self.o_bias_, - - self.ffn_norm_weight_, - self.ffn_norm_bias_, - self.ffn_1_weight_, - self.ffn_1_bias_, - self.ffn_2_weight_, - self.ffn_2_bias_, - ] + weights = [ + self.att_norm_weight_, + self.att_norm_bias_, + self.q_weight_, + self.kv_weight_, + self.q_bias_, + self.kv_bias_, + self.o_weight_, + self.o_bias_, + self.ffn_norm_weight_, + self.ffn_norm_bias_, + self.ffn_1_weight_, + self.ffn_1_bias_, + self.ffn_2_weight_, + self.ffn_2_bias_, + ] for i in range(len(weights)): assert weights[i] is not None, "index:" + str(i) + " " + errors return @@ -51,26 +55,35 @@ 
def _load_qkvo_weights(self, weights): head_dim = self.network_config_["hidden_size"] // self.network_config_["num_attention_heads"] split_n_embed = n_embed // self.world_size_ if f"transformer.h.{self.layer_num_}.attn.c_attn.weight" in weights: - qkv_weight_ = weights[f"transformer.h.{self.layer_num_}.attn.c_attn.weight"].transpose(0, 1).contiguous().to(self.data_type_) - q_weight_ = qkv_weight_[:, :n_embed][:, split_n_embed * self.tp_rank_: split_n_embed * (self.tp_rank_ + 1)] - k_weight_ = qkv_weight_[:, n_embed:n_embed + head_dim] - v_weight_ = qkv_weight_[:, n_embed + head_dim:n_embed + 2 * head_dim] - qkv_fused_weight = torch.cat((q_weight_, k_weight_, v_weight_), dim=1) - - self.qkv_weight_ = self.quantize_weight(qkv_fused_weight) + qkv_weight_ = ( + weights[f"transformer.h.{self.layer_num_}.attn.c_attn.weight"] + .transpose(0, 1) + .contiguous() + .to(self.data_type_) + ) + self.q_weight_ = qkv_weight_[:, :n_embed][ + :, split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1) + ] + self.k_weight_ = qkv_weight_[:, n_embed : n_embed + head_dim] + self.v_weight_ = qkv_weight_[:, n_embed + head_dim : n_embed + 2 * head_dim] + self.q_weight_ = self.quantize_weight(self.q_weight_) + + self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1, handle_func=self.quantize_weight) if f"transformer.h.{self.layer_num_}.attn.c_attn.bias" in weights: qkv_bias_ = weights[f"transformer.h.{self.layer_num_}.attn.c_attn.bias"].to(self.data_type_) - q_bias_ = qkv_bias_[:n_embed][split_n_embed * self.tp_rank_: split_n_embed * (self.tp_rank_ + 1)] - k_bias_ = qkv_bias_[n_embed : n_embed + head_dim] - v_bias_ = qkv_bias_[n_embed + head_dim : n_embed + 2 * head_dim] - - self.qkv_bias_ = self._cuda(torch.cat((q_bias_, k_bias_, v_bias_), dim=0)) + self.q_bias_ = qkv_bias_[:n_embed][split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1)] + self.k_bias_ = qkv_bias_[n_embed : n_embed + head_dim] + self.v_bias_ = qkv_bias_[n_embed + head_dim : n_embed + 2 * head_dim] + + self.q_bias_ = self._cuda(self.q_bias_) + + self._try_cat_to(["k_bias_", "v_bias_"], "kv_bias_", cat_dim=0) # attention output dense params if f"transformer.h.{self.layer_num_}.attn.c_proj.weight" in weights: o_weight_ = weights[f"transformer.h.{self.layer_num_}.attn.c_proj.weight"] - o_weight_ = o_weight_[:,split_n_embed * self.tp_rank_: split_n_embed * (self.tp_rank_ + 1)] + o_weight_ = o_weight_[:, split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1)] o_weight_ = o_weight_.transpose(0, 1).contiguous().to(self.data_type_) self.o_weight_ = self.quantize_weight(o_weight_) @@ -91,19 +104,29 @@ def _load_ffn_weights(self, weights): split_inter_size = intermediate_size // self.world_size_ if f"transformer.h.{self.layer_num_}.mlp.c_fc.weight" in weights: self.ffn_1_weight_ = weights[f"transformer.h.{self.layer_num_}.mlp.c_fc.weight"].to(self.data_type_) - self.ffn_1_weight_ = self.ffn_1_weight_[split_inter_size * self.tp_rank_: split_inter_size * - (self.tp_rank_ + 1), :].transpose(0, 1).contiguous().to(self.data_type_) + self.ffn_1_weight_ = ( + self.ffn_1_weight_[split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), :] + .transpose(0, 1) + .contiguous() + .to(self.data_type_) + ) self.ffn_1_weight_ = self.quantize_weight(self.ffn_1_weight_) if f"transformer.h.{self.layer_num_}.mlp.c_fc.bias" in weights: self.ffn_1_bias_ = weights[f"transformer.h.{self.layer_num_}.mlp.c_fc.bias"] - self.ffn_1_bias_ = self.ffn_1_bias_[split_inter_size * self.tp_rank_: split_inter_size * 
(self.tp_rank_ + 1)] + self.ffn_1_bias_ = self.ffn_1_bias_[ + split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1) + ] self.ffn_1_bias_ = self._cuda(self.ffn_1_bias_) if f"transformer.h.{self.layer_num_}.mlp.c_proj.weight" in weights: self.ffn_2_weight_ = weights[f"transformer.h.{self.layer_num_}.mlp.c_proj.weight"].to(self.data_type_) - self.ffn_2_weight_ = self.ffn_2_weight_[:, split_inter_size * - self.tp_rank_: split_inter_size * (self.tp_rank_ + 1)].transpose(0, 1).contiguous().to(self.data_type_) + self.ffn_2_weight_ = ( + self.ffn_2_weight_[:, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1)] + .transpose(0, 1) + .contiguous() + .to(self.data_type_) + ) self.ffn_2_weight_ = self.quantize_weight(self.ffn_2_weight_) if f"transformer.h.{self.layer_num_}.mlp.c_proj.bias" in weights: diff --git a/lightllm/models/yi/layer_weights/transformer_layer_weight.py b/lightllm/models/yi/layer_weights/transformer_layer_weight.py index f8b78481c..c22c19977 100644 --- a/lightllm/models/yi/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/yi/layer_weights/transformer_layer_weight.py @@ -4,11 +4,12 @@ from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight + class YiTransformerLayerWeight(LlamaTransformerLayerWeight): def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[]): super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode) return - + def _load_qkvo_weights(self, weights): # input layernorm params if f"model.layers.{self.layer_num_}.ln1.weight" in weights: @@ -16,49 +17,61 @@ def _load_qkvo_weights(self, weights): n_embed = self.network_config_["hidden_size"] q_split_n_embed = n_embed // self.world_size_ - kv_split_n_embed = n_embed // self.network_config_["num_attention_heads"] * self.network_config_["num_key_value_heads"] // self.world_size_ + kv_split_n_embed = ( + n_embed + // self.network_config_["num_attention_heads"] + * self.network_config_["num_key_value_heads"] + // self.world_size_ + ) # q k v weights for llama if f"model.layers.{self.layer_num_}.self_attn.q_proj.weight" in weights: self.q_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.weight"] - self.q_weight_ = self.q_weight_[q_split_n_embed * self.tp_rank_: q_split_n_embed * (self.tp_rank_ + 1), :] + self.q_weight_ = self.q_weight_[q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1), :] self.q_weight_ = self._cuda(self.q_weight_.transpose(0, 1)) if f"model.layers.{self.layer_num_}.self_attn.k_proj.weight" in weights: - self.k_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.weight"] - self.k_weight_ = self.k_weight_[kv_split_n_embed * self.tp_rank_: kv_split_n_embed * (self.tp_rank_ + 1), :] - self.k_weight_ = self._cuda(self.k_weight_.transpose(0, 1)) + k_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.weight"] + k_weight_ = k_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] + self.k_weight_ = k_weight_.transpose(0, 1) if f"model.layers.{self.layer_num_}.self_attn.v_proj.weight" in weights: - self.v_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.weight"] - self.v_weight_ = self.v_weight_[kv_split_n_embed * self.tp_rank_: kv_split_n_embed * (self.tp_rank_ + 1), :] - self.v_weight_ = self._cuda(self.v_weight_.transpose(0, 1)) - + v_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.weight"] + v_weight_ = 
v_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] + self.v_weight_ = v_weight_.transpose(0, 1) + + self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1) + # attention output dense params if f"model.layers.{self.layer_num_}.self_attn.o_proj.weight" in weights: self.o_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.weight"] - self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_: q_split_n_embed * (self.tp_rank_ + 1)] + self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1)] self.o_weight_ = self._cuda(self.o_weight_.transpose(0, 1)) return - + def _load_ffn_weights(self, weights): if f"model.layers.{self.layer_num_}.ln2.weight" in weights: self.ffn_norm_weight_ = self._cuda(weights[f"model.layers.{self.layer_num_}.ln2.weight"]) - - inter_size = self.network_config_['intermediate_size'] + + inter_size = self.network_config_["intermediate_size"] split_inter_size = inter_size // self.world_size_ if f"model.layers.{self.layer_num_}.mlp.up_proj.weight" in weights: - self.up_proj = weights[f"model.layers.{self.layer_num_}.mlp.up_proj.weight"][split_inter_size * - self.tp_rank_: split_inter_size * (self.tp_rank_ + 1), :] - self.up_proj = self._cuda(self.up_proj.transpose(0, 1)) + up_proj = weights[f"model.layers.{self.layer_num_}.mlp.up_proj.weight"][ + split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : + ] + self.up_proj = up_proj.transpose(0, 1) if f"model.layers.{self.layer_num_}.mlp.gate_proj.weight" in weights: - self.gate_proj = weights[f"model.layers.{self.layer_num_}.mlp.gate_proj.weight"][split_inter_size * - self.tp_rank_: split_inter_size * (self.tp_rank_ + 1), :] - self.gate_proj = self._cuda(self.gate_proj.transpose(0, 1)) + gate_proj = weights[f"model.layers.{self.layer_num_}.mlp.gate_proj.weight"][ + split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : + ] + self.gate_proj = gate_proj.transpose(0, 1) + + self._try_cat_to(["gate_proj", "up_proj"], "gate_up_proj", cat_dim=1) if f"model.layers.{self.layer_num_}.mlp.down_proj.weight" in weights: - self.down_proj = weights[f"model.layers.{self.layer_num_}.mlp.down_proj.weight"][:, - split_inter_size * self.tp_rank_: split_inter_size * (self.tp_rank_ + 1)] + self.down_proj = weights[f"model.layers.{self.layer_num_}.mlp.down_proj.weight"][ + :, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1) + ] self.down_proj = self._cuda(self.down_proj.transpose(0, 1)) - return \ No newline at end of file + return diff --git a/test/model/test_settings/process_utils.py b/test/model/test_settings/process_utils.py index 776f6995e..352e6f03d 100644 --- a/test/model/test_settings/process_utils.py +++ b/test/model/test_settings/process_utils.py @@ -1,28 +1,33 @@ import subprocess import re + def kill_gpu_processes(): try: - output = subprocess.check_output(['nvidia-smi', '-q', '-x']) - output = output.decode('utf-8') - + output = subprocess.check_output(["nvidia-smi", "-q", "-x"]) + output = output.decode("utf-8") + # 使用正则表达式提取进程信息 - process_info = re.findall(r'<process_info>(.*?)</process_info>', output, re.DOTALL) - + process_info = re.findall(r"<process_info>(.*?)</process_info>", output, re.DOTALL) + if process_info: print("找到以下占用显卡的进程:") for info in process_info: - pid = re.search(r'<pid>(.*?)</pid>', info).group(1) - process_name = re.search(r'<process_name>(.*?)</process_name>', info).group(1) + pid = re.search(r"<pid>(.*?)</pid>", info).group(1) + process_name = re.search(r"<process_name>(.*?)</process_name>", info).group(1) print("进程ID:", pid) print("进程名字:", process_name) - + 
for info in process_info: - pid = re.search(r'<pid>(.*?)</pid>', info).group(1) - subprocess.call(['sudo', 'kill', '-9', pid]) + pid = re.search(r"<pid>(.*?)</pid>", info).group(1) + subprocess.call(["sudo", "kill", "-9", pid]) print("进程ID", pid, "被终止") else: print("没有找到占用显卡的进程") - + except subprocess.CalledProcessError: - print("无法执行nvidia-smi命令") \ No newline at end of file + print("无法执行nvidia-smi命令") + + +if __name__ == "__main__": + kill_gpu_processes()
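
Note on the merged-weight helper used throughout this patch: every layer-weight class now keeps the per-rank K/V (and gate/up) slices on CPU and calls self._try_cat_to(...) instead of tracking a manual step counter, and the mem manager exposes a single kv_buffer whose head dimension is sliced as [:, 0:k_heads] and [:, k_heads:k_heads+v_heads]. The helper itself lives in lightllm/common/basemodel/layer_weights/base_layer_weight.py and its body is not part of these hunks; the sketch below only illustrates the behaviour implied by its call sites (merge the named source attributes along cat_dim once all are loaded, apply an optional handle_func such as quantize_weight, store the result under the destination name, drop the sources). Class and method details here are assumptions, not the actual implementation.

import torch

class _MergedWeightSketch:
    """Illustration only: mimics how the patch merges k/v (or gate/up) slices."""

    def _cuda(self, t):
        # The real helper moves tensors to the current GPU; kept on CPU so the sketch runs anywhere.
        return t.contiguous()

    def _try_cat_to(self, source_names, dest_name, cat_dim, handle_func=None):
        # Merge only once every partial tensor (e.g. k_weight_ and v_weight_) has been loaded.
        if all(getattr(self, name, None) is not None for name in source_names):
            cat_tensor = torch.cat([getattr(self, name) for name in source_names], dim=cat_dim)
            # At the call sites handle_func is quantize_weight (wquant models) or omitted.
            cat_tensor = handle_func(cat_tensor) if handle_func is not None else self._cuda(cat_tensor)
            setattr(self, dest_name, cat_tensor)
            for name in source_names:
                delattr(self, name)  # verify_load then only checks the merged attribute

# Usage mirroring _load_qkvo_weights: slices arrive separately, the merge fires on the second call.
w = _MergedWeightSketch()
w.k_weight_ = torch.randn(128, 32)
w._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1)  # no-op, v not loaded yet
w.v_weight_ = torch.randn(128, 32)
w._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1)  # produces kv_weight_ of shape [128, 64]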