Add Qwen2 Model (#340)
Co-authored-by: hiworldwzj <[email protected]>
flyinglandlord and hiworldwzj authored Mar 7, 2024
1 parent f7b9937 commit 414aa08
Showing 9 changed files with 525 additions and 58 deletions.
Empty file.
46 changes: 46 additions & 0 deletions lightllm/models/qwen2/infer_struct.py
@@ -0,0 +1,46 @@
import torch
import numpy as np
from lightllm.common.basemodel import InferStateInfo
from lightllm.common.req_manager import ReqManager


class Qwen2InferStateInfo(InferStateInfo):
def __init__(self):
super().__init__()
self.sliding_window = None
self.b_start_loc_window = None
self.b_att_seq_len = None
self.b_att_start_loc = None
self.total_cache_num = None
# self.window_postion = None

def init_some_extra_state(self, model, input_ids: torch.Tensor):
self.sliding_window = model.config["sliding_window"]
if self.is_prefill:
b_seq_len_numpy = self.b_seq_len.cpu().numpy()
position_ids = torch.from_numpy(
np.concatenate([np.arange(0, b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))], axis=0)
).cuda()
self.position_cos = torch.index_select(model._cos_cached, 0, position_ids).view(position_ids.shape[0], -1)
self.position_sin = torch.index_select(model._sin_cached, 0, position_ids).view(position_ids.shape[0], -1)
position_ids = None
else:
position_ids = self.b_seq_len - 1
self.position_cos = torch.index_select(model._cos_cached, 0, position_ids).view(self.b_seq_len.shape[0], -1)
self.position_sin = torch.index_select(model._sin_cached, 0, position_ids).view(self.b_seq_len.shape[0], -1)
self.other_kv_index = self.req_manager.req_to_token_indexs[self.b_req_idx[0], 0].item()

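        # Per-request sliding-window bookkeeping: each attention span is clipped to at
        # most `sliding_window` tokens, b_start_loc_window records where the window
        # starts inside each sequence, and b_att_start_loc lays the clipped spans out
        # contiguously (total_cache_num is their combined length).
        # Illustrative example (not from the commit): with sliding_window=4 and
        # b_seq_len=[2, 6, 9] this yields b_start_loc_window=[0, 2, 5],
        # b_att_seq_len=[2, 4, 4], b_att_start_loc=[0, 2, 6], total_cache_num=10.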
self.b_att_seq_len = self.b_seq_len.clone()
self.b_att_start_loc = self.b_start_loc.clone()
self.b_start_loc_window = self.b_start_loc.clone()
self.total_cache_num = 0
for i in range(0, self.batch_size):
if self.sliding_window < self.b_seq_len[i]:
self.b_start_loc_window[i] = self.b_seq_len[i] - self.sliding_window
self.b_att_seq_len[i] = self.sliding_window
else:
self.b_start_loc_window[i] = 0
self.b_att_seq_len[i] = self.b_seq_len[i]
self.b_att_start_loc[i] = self.total_cache_num
self.total_cache_num += self.b_att_seq_len[i]
return
Empty file.
140 changes: 140 additions & 0 deletions lightllm/models/qwen2/layer_infer/transformer_layer_infer.py
@@ -0,0 +1,140 @@
import torch
import torch.nn.functional as F
import torch.distributed as dist
import numpy as np
from typing import Tuple
import triton
from lightllm.models.llama.infer_struct import LlamaInferStateInfo

from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer, rotary_emb_fwd
from lightllm.models.qwen2.infer_struct import Qwen2InferStateInfo

from lightllm.models.mistral.triton_kernel.context_flashattention_nopad import context_attention_fwd
from lightllm.models.mistral.triton_kernel.token_attention_nopad_att1 import token_att_fwd
from lightllm.models.mistral.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2
from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd

from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv
from lightllm.models.qwen2.layer_weights.transformer_layer_weight import Qwen2TransformerLayerWeight


class Qwen2TransformerLayerInfer(LlamaTransformerLayerInfer):
def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
super().__init__(layer_num, tp_rank, world_size, network_config, mode)
return

def _get_qkv(
self, input, cache_kv, infer_state: Qwen2InferStateInfo, layer_weight: Qwen2TransformerLayerWeight
) -> torch.Tensor:
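        # Qwen2 adds biases to the q/k/v projections: q is computed with its own addmm
        # against q_bias_, while k and v are fused into a single kv projection whose
        # output is written directly into cache_kv.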
q = torch.addmm(
layer_weight.q_bias_,
input.view(-1, self.embed_dim_),
layer_weight.q_weight_,
beta=1.0,
alpha=1.0,
)
torch.addmm(
layer_weight.kv_bias_,
input.view(-1, self.embed_dim_),
layer_weight.kv_weight_,
beta=1.0,
alpha=1.0,
out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_),
)
rotary_emb_fwd(
q.view(-1, self.tp_q_head_num_, self.head_dim_),
cache_kv[:, 0 : self.tp_k_head_num_, :],
infer_state.position_cos,
infer_state.position_sin,
)
return q, cache_kv

def _bind_func(self):
self._token_attention_kernel = self._token_decode_attention_normal
self._copy_kv_to_mem_cache = self._copy_kv_to_mem_cache_normal
return

def _context_attention_kernel(
self, q, kv, infer_state: Qwen2InferStateInfo, layer_weight, out=None
) -> torch.Tensor:
o_tensor = torch.empty_like(q) if out is None else out
context_attention_fwd(
q.view(-1, self.tp_q_head_num_, self.head_dim_),
kv[:, 0 : self.tp_k_head_num_, :],
kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :],
o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_),
infer_state.b_start_loc,
infer_state.b_seq_len,
infer_state.max_len_in_batch,
infer_state.sliding_window,
)
return o_tensor

def _token_decode_attention_normal(self, q, infer_state: Qwen2InferStateInfo, layer_weight, out=None):
total_token_num = infer_state.total_cache_num
batch_size = infer_state.batch_size
calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_)

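        # The attention score buffer is sized by total_cache_num (the sum of the clipped
        # window lengths), not by the full sequence lengths, since decode attention only
        # covers the sliding window of each request.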
att_m_tensor = torch.empty((self.tp_q_head_num_, total_token_num), dtype=q.dtype, device="cuda")

token_att_fwd(
q.view(calcu_shape1),
infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :],
att_m_tensor,
infer_state.req_manager.req_to_token_indexs,
infer_state.b_req_idx,
infer_state.b_start_loc,
infer_state.b_seq_len,
infer_state.b_start_loc_window,
infer_state.b_att_start_loc,
infer_state.b_att_seq_len,
infer_state.sliding_window,
)

o_tensor = torch.empty_like(q) if out is None else out

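        # Triton 2.0.0 needs separate softmax (token_softmax_fwd) and reduce-V
        # (token_att_fwd2) kernels; Triton >= 2.1.0 uses the fused
        # token_softmax_reducev_fwd kernel from the Mistral sliding-window implementation.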
if triton.__version__ == "2.0.0":
prob = torch.empty_like(att_m_tensor)
token_softmax_fwd(
att_m_tensor, infer_state.b_att_start_loc, infer_state.b_att_seq_len, prob, infer_state.sliding_window
)
att_m_tensor = None
token_att_fwd2(
prob,
infer_state.mem_manager.kv_buffer[self.layer_num_][
:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
],
o_tensor.view(calcu_shape1),
infer_state.req_manager.req_to_token_indexs,
infer_state.b_req_idx,
infer_state.b_start_loc,
infer_state.b_seq_len,
infer_state.b_start_loc_window,
infer_state.b_att_start_loc,
infer_state.b_att_seq_len,
)
prob = None
return o_tensor
elif triton.__version__ >= "2.1.0":
from lightllm.models.mistral.triton_kernel.token_attention_softmax_and_reducev import (
token_softmax_reducev_fwd,
)

token_softmax_reducev_fwd(
att_m_tensor,
infer_state.mem_manager.kv_buffer[self.layer_num_][
:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
],
o_tensor.view(calcu_shape1),
infer_state.req_manager.req_to_token_indexs,
infer_state.b_req_idx,
infer_state.b_start_loc,
infer_state.b_seq_len,
infer_state.b_start_loc_window,
infer_state.b_att_start_loc,
infer_state.b_att_seq_len,
infer_state.other_kv_index,
)
return o_tensor
else:
raise Exception("unsupported triton version")
Empty file.
40 changes: 40 additions & 0 deletions lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py
@@ -0,0 +1,40 @@
import torch
import numpy as np
from lightllm.common.basemodel import PreAndPostLayerWeight


class Qwen2PreAndPostLayerWeight(PreAndPostLayerWeight):
def __init__(self, tp_rank, world_size, data_type, network_config, mode):
super().__init__(tp_rank, world_size, data_type, network_config, mode)
return

def load_hf_weights(self, weights):

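        # The embedding table and lm_head are split along the vocabulary dimension
        # across tensor-parallel ranks.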
vob_size = self.network_config_["vocab_size"]
split_vob_size = vob_size // self.world_size_

if "model.embed_tokens.weight" in weights:
self.wte_weight_ = self._cuda(
weights["model.embed_tokens.weight"][
split_vob_size * self.tp_rank_ : split_vob_size * (self.tp_rank_ + 1), :
]
)
if "lm_head.weight" in weights:
self.lm_head_weight_ = self._cuda(
weights["lm_head.weight"][split_vob_size * self.tp_rank_ : split_vob_size * (self.tp_rank_ + 1), :]
)
if "model.norm.weight" in weights:
self.final_norm_weight_ = self._cuda(weights["model.norm.weight"])

return

def verify_load(self):
errors = "weights load not ok"
weights = [
self.wte_weight_,
self.lm_head_weight_,
self.final_norm_weight_,
]
for i in range(len(weights)):
assert weights[i] is not None, "index:" + str(i) + " " + errors
return
113 changes: 113 additions & 0 deletions lightllm/models/qwen2/layer_weights/transformer_layer_weight.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import torch
import math
import numpy as np
from lightllm.common.basemodel import TransformerLayerWeight


class Qwen2TransformerLayerWeight(TransformerLayerWeight):
def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[]):
super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode)

def _load_qkvo_weights(self, weights):
# input norm
if f"model.layers.{self.layer_num_}.input_layernorm.weight" in weights:
self.att_norm_weight_ = self._cuda(weights[f"model.layers.{self.layer_num_}.input_layernorm.weight"])

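        # Qwen2 uses grouped-query attention: q is split per rank over num_attention_heads,
        # while k/v are split over num_key_value_heads, hence the different split sizes below.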
n_embed = self.network_config_["hidden_size"]
q_split_n_embed = n_embed // self.world_size_
kv_split_n_embed = (
n_embed
// self.network_config_["num_attention_heads"]
* self.network_config_["num_key_value_heads"]
// self.world_size_
)

# q k v weights
if f"model.layers.{self.layer_num_}.self_attn.q_proj.weight" in weights:
self.q_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.weight"]
self.q_weight_ = self.q_weight_[q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1), :]
self.q_weight_ = self._cuda(self.q_weight_.transpose(0, 1))

if f"model.layers.{self.layer_num_}.self_attn.k_proj.weight" in weights:
k_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.weight"]
k_weight_ = k_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :]
self.k_weight_ = k_weight_.transpose(0, 1)

if f"model.layers.{self.layer_num_}.self_attn.v_proj.weight" in weights:
v_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.weight"]
v_weight_ = v_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :]
self.v_weight_ = v_weight_.transpose(0, 1)

# attention output dense params
if f"model.layers.{self.layer_num_}.self_attn.o_proj.weight" in weights:
self.o_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.weight"]
self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1)]
self.o_weight_ = self._cuda(self.o_weight_.transpose(0, 1))

# q k v bias
if f"model.layers.{self.layer_num_}.self_attn.q_proj.bias" in weights:
q_bias_ = self._cuda(weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.bias"])
self.q_bias_ = q_bias_[q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1)]

if f"model.layers.{self.layer_num_}.self_attn.k_proj.bias" in weights:
k_bias = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.bias"]
self.k_bias_ = k_bias[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1)]

if f"model.layers.{self.layer_num_}.self_attn.v_proj.bias" in weights:
v_bias = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.bias"]
self.v_bias_ = v_bias[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1)]

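        # Concatenate the k/v weights and biases into fused kv tensors so _get_qkv can
        # compute them with a single addmm into cache_kv.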
self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1)
self._try_cat_to(["k_bias_", "v_bias_"], "kv_bias_", cat_dim=0)

def _load_ffn_weights(self, weights):
if f"model.layers.{self.layer_num_}.post_attention_layernorm.weight" in weights:
self.ffn_norm_weight_ = self._cuda(
weights[f"model.layers.{self.layer_num_}.post_attention_layernorm.weight"]
)

inter_size = self.network_config_["intermediate_size"]
split_inter_size = inter_size // self.world_size_

if f"model.layers.{self.layer_num_}.mlp.up_proj.weight" in weights:
up_proj = weights[f"model.layers.{self.layer_num_}.mlp.up_proj.weight"][
split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), :
]
self.up_proj = up_proj.transpose(0, 1)

if f"model.layers.{self.layer_num_}.mlp.gate_proj.weight" in weights:
gate_proj = weights[f"model.layers.{self.layer_num_}.mlp.gate_proj.weight"][
split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), :
]
self.gate_proj = gate_proj.transpose(0, 1)

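        # Fuse gate_proj and up_proj into gate_up_proj so the inherited LLaMA-style FFN
        # can apply both projections in a single matmul.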
self._try_cat_to(["gate_proj", "up_proj"], "gate_up_proj", cat_dim=1)

if f"model.layers.{self.layer_num_}.mlp.down_proj.weight" in weights:
self.down_proj = weights[f"model.layers.{self.layer_num_}.mlp.down_proj.weight"][
:, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1)
]
self.down_proj = self._cuda(self.down_proj.transpose(0, 1))
return

def load_hf_weights(self, weights):
self._load_qkvo_weights(weights)
self._load_ffn_weights(weights)
return

def verify_load(self):
errors = "weights load not ok"
weights = [
self.att_norm_weight_,
self.q_weight_,
self.kv_weight_,
self.q_bias_,
self.kv_bias_,
self.o_weight_,
self.ffn_norm_weight_,
self.gate_up_proj,
self.down_proj,
]
for i in range(len(weights)):
assert weights[i] is not None, "index:" + str(i) + " " + errors
return