Commit

Support internlm2 wquant. (#345)

helloyongyang authored Mar 6, 2024
1 parent b1a73ec commit f7b9937

Showing 6 changed files with 121 additions and 11 deletions.
Empty file.
Empty file.
79 changes: 79 additions & 0 deletions lightllm/models/internlm2_wquant/layer_weights/transformer_layer_weight.py
@@ -0,0 +1,79 @@
from lightllm.models.internlm_wquant.layer_weights.transformer_layer_weight import InternlmTransformerLayerWeightQuantized


class Internlm2TransformerLayerWeightQuantized(InternlmTransformerLayerWeightQuantized):
    def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[]):
        super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode)
        return

    def _load_qkvo_weights(self, weights):
        # input layernorm params
        if f"model.layers.{self.layer_num_}.attention_norm.weight" in weights:
            self.att_norm_weight_ = self._cuda(weights[f"model.layers.{self.layer_num_}.attention_norm.weight"])

        n_embed = self.network_config_["hidden_size"]
        q_split_n_embed = n_embed // self.world_size_
        kv_split_n_embed = (
            n_embed
            // self.network_config_["num_attention_heads"]
            * self.network_config_["num_key_value_heads"]
            // self.world_size_
        )
        head_dim = n_embed // self.network_config_["num_attention_heads"]
        # q k v weights (fused wqkv) for internlm2
        if f"model.layers.{self.layer_num_}.attention.wqkv.weight" in weights:
            qkv_weight_ = weights[f"model.layers.{self.layer_num_}.attention.wqkv.weight"]
            q_groups = self.network_config_["num_attention_heads"] // self.network_config_["num_key_value_heads"]
            qkv_weight_ = qkv_weight_.reshape(self.network_config_["num_key_value_heads"], q_groups + 2, head_dim, -1)
            q_weight_ = qkv_weight_[:, :q_groups, :, :].reshape(-1, qkv_weight_.shape[-1])
            q_weight_ = q_weight_[q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1) :].transpose(0, 1)
            self.q_weight_ = self.quantize_weight(q_weight_)

            k_weight_ = qkv_weight_[:, -2, :, :].reshape(-1, qkv_weight_.shape[-1])
            self.k_weight_ = k_weight_[
                kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1) :
            ].transpose(0, 1)
            v_weight_ = qkv_weight_[:, -1, :, :].reshape(-1, qkv_weight_.shape[-1])
            self.v_weight_ = v_weight_[
                kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1) :
            ].transpose(0, 1)

        self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1, handle_func=self.quantize_weight)

        # attention output dense params
        if f"model.layers.{self.layer_num_}.attention.wo.weight" in weights:
            self.o_weight_ = weights[f"model.layers.{self.layer_num_}.attention.wo.weight"]
            self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1)]
            self.o_weight_ = self.quantize_weight(self.o_weight_.transpose(0, 1))
        if f"model.layers.{self.layer_num_}.attention.wo.bias" in weights:
            self.o_bias_ = weights[f"model.layers.{self.layer_num_}.attention.wo.bias"]
            self.o_bias_ = self._cuda(self.o_bias_)
        return

    def _load_ffn_weights(self, weights):
        if f"model.layers.{self.layer_num_}.ffn_norm.weight" in weights:
            self.ffn_norm_weight_ = self._cuda(weights[f"model.layers.{self.layer_num_}.ffn_norm.weight"])

        inter_size = self.network_config_["intermediate_size"]
        split_inter_size = inter_size // self.world_size_

        if f"model.layers.{self.layer_num_}.feed_forward.w3.weight" in weights:
            up_proj = weights[f"model.layers.{self.layer_num_}.feed_forward.w3.weight"][
                split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), :
            ]
            self.up_proj = up_proj.transpose(0, 1)

        if f"model.layers.{self.layer_num_}.feed_forward.w1.weight" in weights:
            gate_proj = weights[f"model.layers.{self.layer_num_}.feed_forward.w1.weight"][
                split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), :
            ]
            self.gate_proj = gate_proj.transpose(0, 1)

        self._try_cat_to(["gate_proj", "up_proj"], "gate_up_proj", cat_dim=1, handle_func=self.quantize_weight)

        if f"model.layers.{self.layer_num_}.feed_forward.w2.weight" in weights:
            self.down_proj = weights[f"model.layers.{self.layer_num_}.feed_forward.w2.weight"][
                :, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1)
            ]
            self.down_proj = self.quantize_weight(self.down_proj.transpose(0, 1))
        return
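For readers skimming the new file above, the least obvious part is the fused wqkv split. Below is a minimal standalone sketch of that split; the shapes and config values (an InternLM2-7B-style layout with 32 query heads and 8 key/value heads) are illustrative assumptions, not values taken from the commit, and the snippet is not lightllm code.

import torch

num_heads = 32          # num_attention_heads (assumed)
num_kv_heads = 8        # num_key_value_heads (assumed)
hidden = 4096           # hidden_size (assumed)
head_dim = hidden // num_heads          # 128
q_groups = num_heads // num_kv_heads    # 4 query heads share one kv head

# InternLM2 stores wqkv as [(q_groups + 2) * num_kv_heads * head_dim, hidden]:
# for each kv head, q_groups query blocks followed by one k block and one v block.
wqkv = torch.randn((q_groups + 2) * num_kv_heads * head_dim, hidden)

w = wqkv.reshape(num_kv_heads, q_groups + 2, head_dim, -1)
q = w[:, :q_groups, :, :].reshape(-1, hidden)   # [num_heads * head_dim, hidden]
k = w[:, -2, :, :].reshape(-1, hidden)          # [num_kv_heads * head_dim, hidden]
v = w[:, -1, :, :].reshape(-1, hidden)          # [num_kv_heads * head_dim, hidden]

assert q.shape == (num_heads * head_dim, hidden)
assert k.shape == v.shape == (num_kv_heads * head_dim, hidden)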
33 changes: 33 additions & 0 deletions lightllm/models/internlm2_wquant/model.py
@@ -0,0 +1,33 @@
import os
import json
import torch

from lightllm.models.internlm2.layer_weights.pre_and_post_layer_weight import Internlm2PreAndPostLayerWeight
from lightllm.models.internlm2_wquant.layer_weights.transformer_layer_weight import Internlm2TransformerLayerWeightQuantized
from lightllm.models.internlm_wquant.model import InternlmTpPartModelWQuant
from lightllm.common.mem_utils import select_mem_manager_class


class Internlm2TpPartModelWQuant(InternlmTpPartModelWQuant):
    # weight class
    pre_and_post_weight_class = Internlm2PreAndPostLayerWeight
    transformer_weight_class = Internlm2TransformerLayerWeightQuantized

    def __init__(self, kvargs):
        super().__init__(kvargs)

    def _verify_params(self):
        assert self.load_way in ["HF", "DS"], "internlm2 only supports HF and DS formats for loading now!"
        assert any("w4a16" in mode_ or "w8a16" in mode_ for mode_ in self.mode), "only for weight quant model"
        assert self.config["num_key_value_heads"] % self.world_size_ == 0
        assert self.config["num_attention_heads"] % self.world_size_ == 0
        return

    def _init_mem_manager(self):
        self.mem_manager = select_mem_manager_class(self.mode)(self.max_total_token_num,
                                                               dtype=torch.float16,
                                                               head_num=self.config["num_key_value_heads"] // self.world_size_,
                                                               head_dim=self.config["hidden_size"] // self.config["num_attention_heads"],
                                                               layer_num=self.config["num_hidden_layers"],
                                                               always_copy=True)
        return
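One point worth noting in _init_mem_manager above: the KV cache is sized by num_key_value_heads per tensor-parallel rank, not by num_attention_heads. A back-of-envelope sketch with assumed config values (illustrative only, not lightllm code):

config = {                      # illustrative InternLM2-7B-style values (assumed)
    "hidden_size": 4096,
    "num_attention_heads": 32,
    "num_key_value_heads": 8,
    "num_hidden_layers": 32,
}
world_size = 2                  # assumed tensor-parallel degree

head_dim = config["hidden_size"] // config["num_attention_heads"]   # 128
head_num = config["num_key_value_heads"] // world_size              # 4 kv heads per rank

# Per-token KV cache per rank: 2 bytes (fp16) * 2 tensors (K and V) * head_num * head_dim * layers
bytes_per_token = 2 * 2 * head_num * head_dim * config["num_hidden_layers"]
print(head_dim, head_num, bytes_per_token)   # -> 128 4 65536 (64 KiB of fp16 K/V per token on each rank)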
7 changes: 1 addition & 6 deletions lightllm/models/internlm_wquant/layer_weights/transformer_layer_weight.py
@@ -16,7 +16,7 @@ def verify_load(self):

        # handle internlm 20b, which has no bias, so set q k v o bias to zero
        if not self.network_config_.get("bias", True):
-            for layer_type in ("q", "k", "v", "o"):
+            for layer_type in ("q", "kv", "o"):
                attr_name = f"{layer_type}_bias_"
                if hasattr(self, attr_name):
                    continue
@@ -44,15 +44,13 @@ def _load_qkvo_weights(self, weights):
            self.att_norm_weight_ = self._cuda(weights[f"model.layers.{self.layer_num_}.input_layernorm.weight"])

        n_embed = self.network_config_["hidden_size"]
-
        q_split_n_embed = n_embed // self.world_size_
        kv_split_n_embed = (
            n_embed
            // self.network_config_["num_attention_heads"]
            * self.network_config_["num_key_value_heads"]
            // self.world_size_
        )
-
        # q k v weights for llama
        if f"model.layers.{self.layer_num_}.self_attn.q_proj.weight" in weights:
            q_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.weight"]
@@ -71,17 +69,14 @@ def _load_qkvo_weights(self, weights):
            k_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.weight"]
            k_weight_ = k_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :]
            self.k_weight_ = k_weight_.transpose(0, 1).to(self.data_type_)
-
        if f"model.layers.{self.layer_num_}.self_attn.k_proj.bias" in weights:
            self.k_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.bias"][
                kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1)
            ]
-
        if f"model.layers.{self.layer_num_}.self_attn.v_proj.weight" in weights:
            v_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.weight"]
            v_weight_ = v_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :]
            self.v_weight_ = v_weight_.transpose(0, 1).to(self.data_type_)
-
        if f"model.layers.{self.layer_num_}.self_attn.v_proj.bias" in weights:
            self.v_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.bias"][
                kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1)
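The kv_split_n_embed slicing in the hunks above can be checked with a quick worked example; all numbers below are assumed, illustrative values rather than values from the commit.

hidden_size = 4096              # assumed
num_attention_heads = 32        # assumed
num_key_value_heads = 8         # assumed
world_size = 2                  # assumed

q_split_n_embed = hidden_size // world_size                                                  # 2048
kv_split_n_embed = hidden_size // num_attention_heads * num_key_value_heads // world_size   # 128 * 8 // 2 = 512

for tp_rank in range(world_size):
    rows = range(kv_split_n_embed * tp_rank, kv_split_n_embed * (tp_rank + 1))
    print(f"rank {tp_rank}: k/v rows [{rows.start}, {rows.stop})")
# rank 0: k/v rows [0, 512)
# rank 1: k/v rows [512, 1024)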
13 changes: 8 additions & 5 deletions lightllm/server/router/model_infer/model_rpc.py
@@ -27,6 +27,7 @@
from lightllm.models.stablelm.model import StablelmTpPartModel
from lightllm.models.internlm2.model import Internlm2TpPartModel
from lightllm.models.internlm_wquant.model import InternlmTpPartModelWQuant
+from lightllm.models.internlm2_wquant.model import Internlm2TpPartModelWQuant
from lightllm.models.yi.model import YiTpPartModel
from lightllm.models.mistral.model import MistralTpPartModel
from lightllm.models.minicpm.model import MiniCPMTpPartModel
@@ -126,14 +127,16 @@ def exposed_init_model(self, kvargs):
            self.model = StarcoderTpPartModel(model_kvargs)
        elif self.model_type == 'chatglm':
            self.model = ChatGlm2TpPartModel(model_kvargs)
-        elif self.model_type == 'internlm' or self.model_type == 'internlm2':
+        elif self.model_type == 'internlm':
            if any('w8a16' in mode_ or 'w4a16' in mode_ for mode_ in self.mode):
                self.model = InternlmTpPartModelWQuant(model_kvargs)
            else:
-                if model_cfg["architectures"][0] == 'InternLM2ForCausalLM':
-                    self.model = Internlm2TpPartModel(model_kvargs)
-                else:
-                    self.model = InternlmTpPartModel(model_kvargs)
+                self.model = InternlmTpPartModel(model_kvargs)
+        elif self.model_type == 'internlm2':
+            if any('w8a16' in mode_ or 'w4a16' in mode_ for mode_ in self.mode):
+                self.model = Internlm2TpPartModelWQuant(model_kvargs)
+            else:
+                self.model = Internlm2TpPartModel(model_kvargs)
        elif self.model_type == "Yi":
            self.model = YiTpPartModel(model_kvargs)
        elif self.model_type == "mistral":
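The net effect of the model_rpc.py change above, written out as a standalone sketch: class names are returned as strings, this helper is not part of the lightllm API, and the 'triton_w4a16' mode string is only an assumed example value.

def select_model_class(model_type, mode):
    # mirrors the new elif branches: weight-quant modes pick the *WQuant variant
    wquant = any("w8a16" in m or "w4a16" in m for m in mode)
    if model_type == "internlm":
        return "InternlmTpPartModelWQuant" if wquant else "InternlmTpPartModel"
    if model_type == "internlm2":
        return "Internlm2TpPartModelWQuant" if wquant else "Internlm2TpPartModel"
    raise ValueError(f"unhandled model_type {model_type!r} in this sketch")

assert select_model_class("internlm2", ["triton_w4a16"]) == "Internlm2TpPartModelWQuant"  # mode string assumed
assert select_model_class("internlm2", []) == "Internlm2TpPartModel"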