Commit 414aa08
Co-authored-by: hiworldwzj <[email protected]>
1 parent f7b9937
Showing 9 changed files with 525 additions and 58 deletions.
Empty file.
46 changes: 46 additions & 0 deletions
lightllm/models/qwen2/infer_struct.py
@@ -0,0 +1,46 @@
import torch
import numpy as np
from lightllm.common.basemodel import InferStateInfo
from lightllm.common.req_manager import ReqManager


class Qwen2InferStateInfo(InferStateInfo):
    def __init__(self):
        super().__init__()
        self.sliding_window = None
        self.b_start_loc_window = None
        self.b_att_seq_len = None
        self.b_att_start_loc = None
        self.total_cache_num = None
        # self.window_postion = None

    def init_some_extra_state(self, model, input_ids: torch.Tensor):
        self.sliding_window = model.config["sliding_window"]
        if self.is_prefill:
            b_seq_len_numpy = self.b_seq_len.cpu().numpy()
            position_ids = torch.from_numpy(
                np.concatenate([np.arange(0, b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))], axis=0)
            ).cuda()
            self.position_cos = torch.index_select(model._cos_cached, 0, position_ids).view(position_ids.shape[0], -1)
            self.position_sin = torch.index_select(model._sin_cached, 0, position_ids).view(position_ids.shape[0], -1)
            position_ids = None
        else:
            position_ids = self.b_seq_len - 1
            self.position_cos = torch.index_select(model._cos_cached, 0, position_ids).view(self.b_seq_len.shape[0], -1)
            self.position_sin = torch.index_select(model._sin_cached, 0, position_ids).view(self.b_seq_len.shape[0], -1)
            self.other_kv_index = self.req_manager.req_to_token_indexs[self.b_req_idx[0], 0].item()

        self.b_att_seq_len = self.b_seq_len.clone()
        self.b_att_start_loc = self.b_start_loc.clone()
        self.b_start_loc_window = self.b_start_loc.clone()
        self.total_cache_num = 0
        for i in range(0, self.batch_size):
            if self.sliding_window < self.b_seq_len[i]:
                self.b_start_loc_window[i] = self.b_seq_len[i] - self.sliding_window
                self.b_att_seq_len[i] = self.sliding_window
            else:
                self.b_start_loc_window[i] = 0
                self.b_att_seq_len[i] = self.b_seq_len[i]
            self.b_att_start_loc[i] = self.total_cache_num
            self.total_cache_num += self.b_att_seq_len[i]
        return
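As a side note (not part of the commit), the sliding-window bookkeeping in init_some_extra_state can be checked with a small standalone sketch using plain Python lists and made-up batch values: sequences longer than the window keep only their last sliding_window positions, and b_att_start_loc packs the retained spans contiguously.

# Minimal sketch of the bookkeeping loop above, with illustrative values only.
sliding_window = 4
b_seq_len = [3, 10, 6]          # per-request sequence lengths (hypothetical batch)

b_att_seq_len = []               # tokens actually attended per request
b_start_loc_window = []          # offset of the first attended token within each sequence
b_att_start_loc = []             # start of each request's span in the packed attention buffer
total_cache_num = 0

for seq_len in b_seq_len:
    if sliding_window < seq_len:
        b_start_loc_window.append(seq_len - sliding_window)
        b_att_seq_len.append(sliding_window)
    else:
        b_start_loc_window.append(0)
        b_att_seq_len.append(seq_len)
    b_att_start_loc.append(total_cache_num)
    total_cache_num += b_att_seq_len[-1]

print(b_att_seq_len)       # [3, 4, 4]
print(b_start_loc_window)  # [0, 6, 2]
print(b_att_start_loc)     # [0, 3, 7]
print(total_cache_num)     # 11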
Empty file.
140 changes: 140 additions & 0 deletions
lightllm/models/qwen2/layer_infer/transformer_layer_infer.py
@@ -0,0 +1,140 @@
import torch
import torch.functional as F
import torch.distributed as dist
import numpy as np
from typing import Tuple
import triton
from lightllm.models.llama.infer_struct import LlamaInferStateInfo

from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer, rotary_emb_fwd
from lightllm.models.qwen2.infer_struct import Qwen2InferStateInfo

from lightllm.models.mistral.triton_kernel.context_flashattention_nopad import context_attention_fwd
from lightllm.models.mistral.triton_kernel.token_attention_nopad_att1 import token_att_fwd
from lightllm.models.mistral.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2
from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd

from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv
from lightllm.models.qwen2.layer_weights.transformer_layer_weight import Qwen2TransformerLayerWeight


class Qwen2TransformerLayerInfer(LlamaTransformerLayerInfer):
    def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
        super().__init__(layer_num, tp_rank, world_size, network_config, mode)
        return

    def _get_qkv(
        self, input, cache_kv, infer_state: Qwen2InferStateInfo, layer_weight: Qwen2TransformerLayerWeight
    ) -> torch.Tensor:
        q = torch.addmm(
            layer_weight.q_bias_,
            input.view(-1, self.embed_dim_),
            layer_weight.q_weight_,
            beta=1.0,
            alpha=1.0,
        )
        torch.addmm(
            layer_weight.kv_bias_,
            input.view(-1, self.embed_dim_),
            layer_weight.kv_weight_,
            beta=1.0,
            alpha=1.0,
            out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_),
        )
        rotary_emb_fwd(
            q.view(-1, self.tp_q_head_num_, self.head_dim_),
            cache_kv[:, 0 : self.tp_k_head_num_, :],
            infer_state.position_cos,
            infer_state.position_sin,
        )
        return q, cache_kv

    def _bind_func(self):
        self._token_attention_kernel = self._token_decode_attention_normal
        self._copy_kv_to_mem_cache = self._copy_kv_to_mem_cache_normal
        return

    def _context_attention_kernel(
        self, q, kv, infer_state: Qwen2InferStateInfo, layer_weight, out=None
    ) -> torch.Tensor:
        o_tensor = torch.empty_like(q) if out is None else out
        context_attention_fwd(
            q.view(-1, self.tp_q_head_num_, self.head_dim_),
            kv[:, 0 : self.tp_k_head_num_, :],
            kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :],
            o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_),
            infer_state.b_start_loc,
            infer_state.b_seq_len,
            infer_state.max_len_in_batch,
            infer_state.sliding_window,
        )
        return o_tensor

    def _token_decode_attention_normal(self, q, infer_state: Qwen2InferStateInfo, layer_weight, out=None):
        total_token_num = infer_state.total_cache_num
        batch_size = infer_state.batch_size
        calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_)

        att_m_tensor = torch.empty((self.tp_q_head_num_, total_token_num), dtype=q.dtype, device="cuda")

        token_att_fwd(
            q.view(calcu_shape1),
            infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :],
            att_m_tensor,
            infer_state.req_manager.req_to_token_indexs,
            infer_state.b_req_idx,
            infer_state.b_start_loc,
            infer_state.b_seq_len,
            infer_state.b_start_loc_window,
            infer_state.b_att_start_loc,
            infer_state.b_att_seq_len,
            infer_state.sliding_window,
        )

        o_tensor = torch.empty_like(q) if out is None else out

        if triton.__version__ == "2.0.0":
            prob = torch.empty_like(att_m_tensor)
            token_softmax_fwd(
                att_m_tensor, infer_state.b_att_start_loc, infer_state.b_att_seq_len, prob, infer_state.sliding_window
            )
            att_m_tensor = None
            token_att_fwd2(
                prob,
                infer_state.mem_manager.kv_buffer[self.layer_num_][
                    :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
                ],
                o_tensor.view(calcu_shape1),
                infer_state.req_manager.req_to_token_indexs,
                infer_state.b_req_idx,
                infer_state.b_start_loc,
                infer_state.b_seq_len,
                infer_state.b_start_loc_window,
                infer_state.b_att_start_loc,
                infer_state.b_att_seq_len,
            )
            prob = None
            return o_tensor
        elif triton.__version__ >= "2.1.0":
            from lightllm.models.mistral.triton_kernel.token_attention_softmax_and_reducev import (
                token_softmax_reducev_fwd,
            )

            token_softmax_reducev_fwd(
                att_m_tensor,
                infer_state.mem_manager.kv_buffer[self.layer_num_][
                    :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
                ],
                o_tensor.view(calcu_shape1),
                infer_state.req_manager.req_to_token_indexs,
                infer_state.b_req_idx,
                infer_state.b_start_loc,
                infer_state.b_seq_len,
                infer_state.b_start_loc_window,
                infer_state.b_att_start_loc,
                infer_state.b_att_seq_len,
                infer_state.other_kv_index,
            )
            return o_tensor
        else:
            raise Exception("not support triton version")
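For orientation (not part of the commit), the decode path above delegates the actual math to the mistral/llama Triton kernels. The following plain-PyTorch sketch shows what that path computes conceptually: each request's single new query attends only to the last att_seq_len cached positions. It ignores the paged req_to_token_indexs layout and grouped K/V heads (it assumes equal q and k head counts), and the scaling/softmax follow standard attention; the function name and shapes are illustrative assumptions.

import torch

def sliding_window_decode_reference(q, k_cache, v_cache, b_seq_len, sliding_window):
    # q:        (batch, heads, head_dim)   one new token per request
    # k_cache:  list of per-request tensors of shape (seq_len_i, heads, head_dim)
    # v_cache:  same layout as k_cache
    outs = []
    for i in range(q.shape[0]):
        seq_len = int(b_seq_len[i])
        att_len = min(seq_len, sliding_window)   # mirrors b_att_seq_len[i]
        start = seq_len - att_len                # mirrors b_start_loc_window[i]
        k = k_cache[i][start:seq_len]            # (att_len, heads, head_dim)
        v = v_cache[i][start:seq_len]
        scores = torch.einsum("hd,thd->ht", q[i], k) / (q.shape[-1] ** 0.5)
        probs = torch.softmax(scores, dim=-1)
        outs.append(torch.einsum("ht,thd->hd", probs, v))
    return torch.stack(outs)                     # (batch, heads, head_dim)

# Illustrative usage with tiny random tensors.
batch, heads, head_dim, window = 2, 4, 16, 3
b_seq_len = [2, 5]
q = torch.randn(batch, heads, head_dim)
k_cache = [torch.randn(s, heads, head_dim) for s in b_seq_len]
v_cache = [torch.randn(s, heads, head_dim) for s in b_seq_len]
out = sliding_window_decode_reference(q, k_cache, v_cache, b_seq_len, window)
print(out.shape)  # torch.Size([2, 4, 16])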
Empty file.
40 changes: 40 additions & 0 deletions
lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py
@@ -0,0 +1,40 @@
import torch
import numpy as np
from lightllm.common.basemodel import PreAndPostLayerWeight


class Qwen2PreAndPostLayerWeight(PreAndPostLayerWeight):
    def __init__(self, tp_rank, world_size, data_type, network_config, mode):
        super().__init__(tp_rank, world_size, data_type, network_config, mode)
        return

    def load_hf_weights(self, weights):

        vob_size = self.network_config_["vocab_size"]
        split_vob_size = vob_size // self.world_size_

        if "model.embed_tokens.weight" in weights:
            self.wte_weight_ = self._cuda(
                weights["model.embed_tokens.weight"][
                    split_vob_size * self.tp_rank_ : split_vob_size * (self.tp_rank_ + 1), :
                ]
            )
        if "lm_head.weight" in weights:
            self.lm_head_weight_ = self._cuda(
                weights["lm_head.weight"][split_vob_size * self.tp_rank_ : split_vob_size * (self.tp_rank_ + 1), :]
            )
        if "model.norm.weight" in weights:
            self.final_norm_weight_ = self._cuda(weights["model.norm.weight"])

        return

    def verify_load(self):
        errors = "weights load not ok"
        weights = [
            self.wte_weight_,
            self.lm_head_weight_,
            self.final_norm_weight_,
        ]
        for i in range(len(weights)):
            assert weights[i] is not None, "index:" + str(i) + " " + errors
        return
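The embedding and lm_head matrices above are split row-wise across tensor-parallel ranks. A small standalone sketch of that slicing (illustrative sizes and plain tensors rather than the HF state dict, not part of the commit):

import torch

# Each rank keeps a contiguous block of vocabulary rows, mirroring load_hf_weights.
vocab_size, hidden_size, world_size = 12, 4, 3   # hypothetical sizes
embed_tokens = torch.randn(vocab_size, hidden_size)
split_vob_size = vocab_size // world_size        # 4 rows per rank

shards = [
    embed_tokens[split_vob_size * rank : split_vob_size * (rank + 1), :]
    for rank in range(world_size)
]
# Concatenating the per-rank shards recovers the full matrix.
assert torch.equal(torch.cat(shards, dim=0), embed_tokens)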
113 changes: 113 additions & 0 deletions
lightllm/models/qwen2/layer_weights/transformer_layer_weight.py
@@ -0,0 +1,113 @@
import torch
import math
import numpy as np
from lightllm.common.basemodel import TransformerLayerWeight


class Qwen2TransformerLayerWeight(TransformerLayerWeight):
    def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[]):
        super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode)

    def _load_qkvo_weights(self, weights):
        # input norm
        if f"model.layers.{self.layer_num_}.input_layernorm.weight" in weights:
            self.att_norm_weight_ = self._cuda(weights[f"model.layers.{self.layer_num_}.input_layernorm.weight"])

        n_embed = self.network_config_["hidden_size"]
        q_split_n_embed = n_embed // self.world_size_
        kv_split_n_embed = (
            n_embed
            // self.network_config_["num_attention_heads"]
            * self.network_config_["num_key_value_heads"]
            // self.world_size_
        )

        # q k v weights
        if f"model.layers.{self.layer_num_}.self_attn.q_proj.weight" in weights:
            self.q_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.weight"]
            self.q_weight_ = self.q_weight_[q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1), :]
            self.q_weight_ = self._cuda(self.q_weight_.transpose(0, 1))

        if f"model.layers.{self.layer_num_}.self_attn.k_proj.weight" in weights:
            k_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.weight"]
            k_weight_ = k_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :]
            self.k_weight_ = k_weight_.transpose(0, 1)

        if f"model.layers.{self.layer_num_}.self_attn.v_proj.weight" in weights:
            v_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.weight"]
            v_weight_ = v_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :]
            self.v_weight_ = v_weight_.transpose(0, 1)

        # attention output dense params
        if f"model.layers.{self.layer_num_}.self_attn.o_proj.weight" in weights:
            self.o_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.weight"]
            self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1)]
            self.o_weight_ = self._cuda(self.o_weight_.transpose(0, 1))

        # q k v bias
        if f"model.layers.{self.layer_num_}.self_attn.q_proj.bias" in weights:
            q_bias_ = self._cuda(weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.bias"])
            self.q_bias_ = q_bias_[q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1)]

        if f"model.layers.{self.layer_num_}.self_attn.k_proj.bias" in weights:
            k_bias = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.bias"]
            self.k_bias_ = k_bias[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1)]

        if f"model.layers.{self.layer_num_}.self_attn.v_proj.bias" in weights:
            v_bias = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.bias"]
            self.v_bias_ = v_bias[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1)]

        self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1)
        self._try_cat_to(["k_bias_", "v_bias_"], "kv_bias_", cat_dim=0)

    def _load_ffn_weights(self, weights):
        if f"model.layers.{self.layer_num_}.post_attention_layernorm.weight" in weights:
            self.ffn_norm_weight_ = self._cuda(
                weights[f"model.layers.{self.layer_num_}.post_attention_layernorm.weight"]
            )

        inter_size = self.network_config_["intermediate_size"]
        split_inter_size = inter_size // self.world_size_

        if f"model.layers.{self.layer_num_}.mlp.up_proj.weight" in weights:
            up_proj = weights[f"model.layers.{self.layer_num_}.mlp.up_proj.weight"][
                split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), :
            ]
            self.up_proj = up_proj.transpose(0, 1)

        if f"model.layers.{self.layer_num_}.mlp.gate_proj.weight" in weights:
            gate_proj = weights[f"model.layers.{self.layer_num_}.mlp.gate_proj.weight"][
                split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), :
            ]
            self.gate_proj = gate_proj.transpose(0, 1)

        self._try_cat_to(["gate_proj", "up_proj"], "gate_up_proj", cat_dim=1)

        if f"model.layers.{self.layer_num_}.mlp.down_proj.weight" in weights:
            self.down_proj = weights[f"model.layers.{self.layer_num_}.mlp.down_proj.weight"][
                :, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1)
            ]
            self.down_proj = self._cuda(self.down_proj.transpose(0, 1))
        return

    def load_hf_weights(self, weights):
        self._load_qkvo_weights(weights)
        self._load_ffn_weights(weights)
        return

    def verify_load(self):
        errors = "weights load not ok"
        weights = [
            self.att_norm_weight_,
            self.q_weight_,
            self.kv_weight_,
            self.q_bias_,
            self.kv_bias_,
            self.o_weight_,
            self.ffn_norm_weight_,
            self.gate_up_proj,
            self.down_proj,
        ]
        for i in range(len(weights)):
            assert weights[i] is not None, "index:" + str(i) + " " + errors
        return
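The kv_split_n_embed expression above accounts for grouped-query attention: the k_proj/v_proj matrices have num_key_value_heads output heads rather than num_attention_heads, and only that smaller width is split across tensor-parallel ranks. A quick numeric check of the arithmetic, using hypothetical config values (not read from the commit):

# Illustrative arithmetic behind q_split_n_embed and kv_split_n_embed.
hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 8
world_size = 2

head_dim = hidden_size // num_attention_heads              # 128
q_split_n_embed = hidden_size // world_size                # 2048 q_proj rows per rank
kv_split_n_embed = (
    hidden_size // num_attention_heads * num_key_value_heads // world_size
)                                                          # 128 * 8 // 2 = 512 k/v rows per rank

assert kv_split_n_embed == head_dim * num_key_value_heads // world_size
print(q_split_n_embed, kv_split_n_embed)                   # 2048 512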