feat: add cohere model for v1 and plus #416

Open · wants to merge 1 commit into base: main
139 changes: 139 additions & 0 deletions lightllm/models/cohere/layer_infer/post_layer_infer.py
@@ -0,0 +1,139 @@
import torch
import torch.distributed as dist
import numpy as np

from lightllm.models.cohere.layer_weights.pre_and_post_layer_weight import CoherePreAndPostLayerWeight
from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward
from lightllm.common.basemodel.layer_weights.base_layer_weight import BaseLayerWeight
from lightllm.common.basemodel.splitfuse_infer_struct import SplitFuseInferStateInfo

from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight
from einops import rearrange
from lightllm.models.llama.infer_struct import LlamaInferStateInfo
from lightllm.common.basemodel import PostLayerInferTpl


class CoherePostLayerInfer(PostLayerInferTpl):
def __init__(self, tp_rank, world_size, network_config, mode):
super().__init__(tp_rank, world_size, network_config, mode)
self.eps_ = network_config["layer_norm_eps"]
self.vocab_size_ = network_config["vocab_size"]
self.embed_dim_ = network_config["n_embed"]
self.logits_scale = network_config["logit_scale"]
return

def _norm(self, input, infer_state, layer_weight: CoherePreAndPostLayerWeight) -> torch.Tensor:
return layernorm_forward(input, layer_weight.final_norm_weight_, eps=self.eps_)

def _slice_get_last_input(self, input_embdings, infer_state: LlamaInferStateInfo):
if infer_state.is_splitfuse:
# for SplitFuse
batch_size = infer_state.batch_size
last_input = torch.empty(
(batch_size, self.embed_dim_), device=input_embdings.device, dtype=input_embdings.dtype
)
tmp_ = torch.cat(
[
torch.ones(infer_state.decode_req_num, dtype=torch.int32, device="cuda"),
infer_state.prefill_b_seq_len - infer_state.prefill_b_split_ready_cache_len,
],
dim=0,
)
last_index = torch.cumsum(tmp_, dim=0, dtype=torch.long) - 1
last_input[:, :] = input_embdings[last_index, :]
return last_input, batch_size

if infer_state.is_prefill and infer_state.is_token_healing:
batch_size = infer_state.batch_size
b_seq_len_numpy = (infer_state.b_seq_len - infer_state.b_ready_cache_len).detach().cpu().numpy()
select_index = []
start_index = 0
select_token_num = 0
for cur_len in b_seq_len_numpy:
if cur_len == 1:
select_index.append(start_index + cur_len - 1)
start_index += cur_len
select_token_num += 1
else:
select_index.append(start_index + cur_len - 2)
select_index.append(start_index + cur_len - 1)
start_index += cur_len
select_token_num += 2

last_index = torch.tensor(select_index, dtype=torch.long, device=input_embdings.device)
last_input = torch.empty(
(select_token_num, self.embed_dim_), device=input_embdings.device, dtype=input_embdings.dtype
)

last_input[:, :] = input_embdings[last_index, :]
return last_input, select_token_num

if not infer_state.is_splitfuse and infer_state.is_prefill and not infer_state.return_all_prompt_logics:
batch_size = infer_state.batch_size
last_input = torch.empty(
(batch_size, self.embed_dim_), device=input_embdings.device, dtype=input_embdings.dtype
)
last_index = (
torch.cumsum(infer_state.b_seq_len - infer_state.b_ready_cache_len, dim=0, dtype=torch.long) - 1
)
last_input[:, :] = input_embdings[last_index, :]
return last_input, batch_size

if not infer_state.is_splitfuse and infer_state.is_prefill and infer_state.return_all_prompt_logics:
total_tokens = infer_state.total_token_num
return input_embdings, total_tokens

if not infer_state.is_splitfuse and not infer_state.is_prefill:
batch_size = infer_state.batch_size
return input_embdings[-batch_size:, :], batch_size

assert False, "Error State"

def soft_max(self, data):
return torch.softmax(data.permute(1, 0).float(), dim=-1)

def token_forward(
self,
input_embdings,
infer_state: LlamaInferStateInfo,
layer_weight: LlamaPreAndPostLayerWeight,
return_logics=False,
):
last_input, token_num = self._slice_get_last_input(input_embdings, infer_state)
input_embdings_dtype = input_embdings.dtype
input_embdings = None
last_input = self._norm(last_input, infer_state, layer_weight)
last_input = rearrange(last_input, "batch embed_dim -> embed_dim batch").contiguous().reshape(-1, token_num)
logic_batch = torch.mm(layer_weight.lm_head_weight_, last_input)

last_input = None
if self.world_size_ == 1:
gather_data = logic_batch
else:
gather_data = torch.empty(
(self.vocab_size_, token_num), device=logic_batch.device, dtype=input_embdings_dtype
)
split_indexes = np.linspace(0, self.vocab_size_, self.world_size_ + 1, dtype=np.int64)
dist.all_gather(
[gather_data[split_indexes[i] : split_indexes[i + 1], :] for i in range(self.world_size_)],
logic_batch,
group=None,
async_op=False,
)
gather_data = gather_data * self.logits_scale
logic_batch = None

if not return_logics:
prob_out = self.soft_max(gather_data)
gather_data = None
return prob_out
else:
ans_logics = gather_data.permute(1, 0).float()
gather_data = None
return ans_logics

# @mark_cost_time("splitfuse post forward")
def splitfuse_forward(
self, input_embdings, infer_state: SplitFuseInferStateInfo, layer_weight: BaseLayerWeight, return_logics=False
):
return self.token_forward(input_embdings, infer_state, layer_weight, return_logics=return_logics)
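
For reference, the post layer above applies Cohere's final LayerNorm to the last hidden state of each request, projects through the vocab-sharded lm_head, gathers the shards across tensor-parallel ranks, and multiplies by logit_scale. A minimal single-rank sketch of the same math follows; the names hidden, final_norm_weight, lm_head_weight and logit_scale are illustrative stand-ins, not part of this PR.

import torch

def cohere_logits_sketch(hidden, final_norm_weight, lm_head_weight, logit_scale, eps=1e-5):
    # Bias-free LayerNorm over the hidden dimension, matching the final norm used above.
    mean = hidden.mean(-1, keepdim=True)
    var = hidden.var(-1, unbiased=False, keepdim=True)
    normed = (hidden - mean) * torch.rsqrt(var + eps) * final_norm_weight
    # Project onto the vocabulary and apply Cohere's logit scaling.
    return (normed @ lm_head_weight.t()) * logit_scale
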
174 changes: 174 additions & 0 deletions lightllm/models/cohere/layer_infer/transformer_layer_infer.py
@@ -0,0 +1,174 @@
import torch
from functools import partial

from lightllm.common.basemodel.infer_struct import InferStateInfo
from lightllm.common.basemodel.layer_infer.template.transformer_layer_infer_template import TransformerLayerInferTpl
from lightllm.common.basemodel.splitfuse_infer_struct import SplitFuseInferStateInfo
from lightllm.models.cohere.layer_weights.transformer_layer_weight import CohereTransformerLayerWeight
from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward, mh_layernorm_forward
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer

from lightllm.models.llama.infer_struct import LlamaInferStateInfo
import torch.distributed as dist

from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
from lightllm.utils.infer_utils import mark_cost_time



class CohereTransformerLayerInfer(LlamaTransformerLayerInfer):
def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
network_config["rms_norm_eps"] = network_config["layer_norm_eps"] # cohere uses layer_norm_eps
self.use_qk_norm = network_config.get("use_qk_norm", False)
super().__init__(layer_num, tp_rank, world_size, network_config, mode)
self.eps_ = network_config["layer_norm_eps"] # overwrite eps
self._bind_func()
return

def _att_norm(self, input, infer_state: LlamaInferStateInfo, layer_weight: CohereTransformerLayerWeight):
return layernorm_forward(input, weight=layer_weight.att_norm_weight_, eps=self.eps_)

def _q_norm(self, input, infer_state: LlamaInferStateInfo, layer_weight: CohereTransformerLayerWeight):
return mh_layernorm_forward(input, weight=layer_weight.q_norm_weight_, eps=self.eps_)

def _k_norm(self, input, infer_state: LlamaInferStateInfo, layer_weight: CohereTransformerLayerWeight):
return mh_layernorm_forward(input, weight=layer_weight.k_norm_weight_, eps=self.eps_)

def _bind_norm(self):
self._att_norm = partial(CohereTransformerLayerInfer._att_norm, self)
self._ffn_norm = None # no ffn norm in cohere models
self._q_norm = partial(CohereTransformerLayerInfer._q_norm, self) if self.use_qk_norm else None
self._k_norm = partial(CohereTransformerLayerInfer._k_norm, self) if self.use_qk_norm else None

    def _get_qkv(self, input, cache_kv, infer_state, layer_weight) -> torch.Tensor:
q = torch.mm(input.view(-1, self.embed_dim_), layer_weight.q_weight_)
torch.mm(
input.view(-1, self.embed_dim_),
layer_weight.kv_weight_,
out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_),
)
if self.use_qk_norm:
q = q.view(-1, self.tp_q_head_num_, self.head_dim_)
k = cache_kv[:, 0 : self.tp_k_head_num_, :]
q = self._q_norm(q, infer_state, layer_weight)
cache_kv[:, 0 : self.tp_k_head_num_, :] = self._k_norm(k, infer_state, layer_weight)
rotary_emb_fwd(
q.view(-1, self.tp_q_head_num_, self.head_dim_),
cache_kv[:, 0 : self.tp_k_head_num_, :],
infer_state.position_cos,
infer_state.position_sin,
)
return q, cache_kv

@mark_cost_time("trans context ffn forward time cost") # dont to remove this, will make performence down, did not know why
def _context_ffn(self, input_embdings, infer_state: InferStateInfo, layer_weight):
input1 = input_embdings
ffn_out = self._ffn(input1, infer_state, layer_weight)
input1 = None
if self.world_size_ > 1:
dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False)
input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
infer_state._ffn_out = ffn_out
return

def _token_ffn(self, input_embdings, infer_state: InferStateInfo, layer_weight):
input1 = input_embdings
ffn_out = self._ffn(input1, infer_state, layer_weight)
input1 = None
if self.world_size_ > 1:
dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False)
infer_state._ffn_out = ffn_out
return

# @mark_cost_time("trans context ffn forward time cost") # dont to remove this, will make performence down, did not know why
def _splitfuse_ffn(self, input_embdings, infer_state: SplitFuseInferStateInfo, layer_weight):
input1 = input_embdings
ffn_out = self._ffn(input1, infer_state, layer_weight)
input1 = None
if self.world_size_ > 1:
dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False)
infer_state._ffn_out = ffn_out
return

@mark_cost_time("trans context flash forward time cost") # dont to remove this, will make performence down, did not know why
def _context_attention(self, input_embding, infer_state: InferStateInfo, layer_weight):
input1 = input_embding
cache_kv = self._pre_cache_kv(infer_state, layer_weight)
q, cache_kv = self._get_qkv(input1, cache_kv, infer_state, layer_weight)
input1 = None
o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight)
self._post_cache_kv(cache_kv, infer_state, layer_weight)
q = None
o = self._get_o(o, infer_state, layer_weight)
if self.world_size_ > 1:
dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
infer_state._attn_out = o
return

    # this implementation does not use @mark_cost_time
def _token_attention(self, input_embding, infer_state: InferStateInfo, layer_weight):
input1 = input_embding
cache_kv = self._pre_cache_kv(infer_state, layer_weight)
q, cache_kv = self._get_qkv(input1, cache_kv, infer_state, layer_weight)
input1 = None
self._post_cache_kv(cache_kv, infer_state, layer_weight)
o = self._token_attention_kernel(q, infer_state, layer_weight)
q = None
o = self._get_o(o, infer_state, layer_weight)
if self.world_size_ > 1:
dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
infer_state._attn_out = o
return

# @mark_cost_time("trans context flash forward time cost") # dont to remove this, will make performence down, did not know why
def _splitfuse_attention(self, input_embding, infer_state: SplitFuseInferStateInfo, layer_weight):
input1 = input_embding
cache_kv = self._pre_cache_kv(infer_state, layer_weight)
q, cache_kv = self._get_qkv(input1, cache_kv, infer_state, layer_weight)
input1 = None
self._post_cache_kv(cache_kv, infer_state, layer_weight)
o = self._splitfuse_attention_kernel(q, infer_state, layer_weight)
q = None
o = self._get_o(o, infer_state, layer_weight)
if self.world_size_ > 1:
dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
infer_state._attn_out = o
return

    def _cohere_residual(self, input_embdings, infer_state: InferStateInfo):
        # The attention and FFN outputs must live in their own buffers; adding an
        # aliased tensor in place would corrupt the residual stream.
        emb_addr = input_embdings.data_ptr()
        attn_out_addr = infer_state._attn_out.data_ptr()
        ffn_addr = infer_state._ffn_out.data_ptr()
        assert emb_addr != attn_out_addr
        assert emb_addr != ffn_addr
        assert attn_out_addr != ffn_addr
        input_embdings.add_(
            infer_state._attn_out.view(-1, self.embed_dim_) + infer_state._ffn_out.view(-1, self.embed_dim_)
        )

def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
input1 = self._att_norm(input_embdings, infer_state, layer_weight)
        self._context_attention(input1, infer_state, layer_weight=layer_weight)
self._context_ffn(input1, infer_state, layer_weight)
self._cohere_residual(input_embdings, infer_state)
return input_embdings

def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
input1 = self._att_norm(input_embdings, infer_state, layer_weight)
        self._token_attention(input1, infer_state, layer_weight=layer_weight)
self._token_ffn(input1, infer_state, layer_weight)
self._cohere_residual(input_embdings, infer_state)
return input_embdings

def splitfuse_forward(self, input_embdings, infer_state: SplitFuseInferStateInfo, layer_weight):
input1 = self._att_norm(input_embdings, infer_state, layer_weight)
        self._splitfuse_attention(input1, infer_state, layer_weight=layer_weight)
self._splitfuse_ffn(input1, infer_state, layer_weight)
self._cohere_residual(input_embdings, infer_state)
return input_embdings
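
The transformer layer above uses Cohere's parallel-residual layout: attention and the FFN both read the same pre-attention LayerNorm output (hence _ffn_norm is set to None), and both branch outputs are added back onto the residual stream in _cohere_residual. A minimal dataflow sketch follows, with norm, attn and ffn as hypothetical stand-ins for the sharded kernels above.

def cohere_block_sketch(x, norm, attn, ffn):
    # One shared pre-norm feeds both branches; Cohere has no second norm before the FFN.
    h = norm(x)
    # Parallel residual: attention and FFN contributions are summed onto the original input.
    return x + attn(h) + ffn(h)
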
36 changes: 36 additions & 0 deletions lightllm/models/cohere/layer_weights/pre_and_post_layer_weight.py
@@ -0,0 +1,36 @@
import torch
import numpy as np

from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight

class CoherePreAndPostLayerWeight(LlamaPreAndPostLayerWeight):
def load_hf_weights(self, weights):
vob_size = self.network_config_["vocab_size"]
tie_weight = self.network_config_.get("tie_word_embeddings", True)
split_indexes = np.linspace(0, vob_size, self.world_size_ + 1, dtype=np.int64)
split_start = split_indexes[self.tp_rank_]
split_end = split_indexes[self.tp_rank_ + 1]
if "model.embed_tokens.weight" in weights:
self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :])
if tie_weight:
self.lm_head_weight_ = self.wte_weight_
if "model.norm.weight" in weights:
self.final_norm_weight_ = self._cuda(weights["model.norm.weight"])
if "model.lm_head.weight" in weights:
self.lm_head_weight_ = self._cuda(weights["model.lm_head.weight"])
return

def verify_load(self):
super().verify_load()

errors = "tie weights load not ok"
tie_weight = self.network_config_.get("tie_word_embeddings", True)
if tie_weight:
assert self.lm_head_weight_ is not None, errors
assert self.wte_weight_ is self.lm_head_weight_, errors
else:
assert self.lm_head_weight_ is not None, errors
assert self.wte_weight_ is not None, errors
assert self.wte_weight_ is not self.lm_head_weight_, errors
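
The loader above shards the token embedding (and, when tie_word_embeddings is set, the shared lm_head) along the vocab dimension, giving each tensor-parallel rank one contiguous slice computed with np.linspace. A small sketch of the split arithmetic, using illustrative sizes rather than values from this PR:

import numpy as np

vocab_size, world_size = 256000, 4
split_indexes = np.linspace(0, vocab_size, world_size + 1, dtype=np.int64)
# Rank r keeps rows [split_indexes[r], split_indexes[r + 1]) of model.embed_tokens.weight.
for rank in range(world_size):
    print(rank, int(split_indexes[rank]), int(split_indexes[rank + 1]))
# -> (0, 0, 64000), (1, 64000, 128000), (2, 128000, 192000), (3, 192000, 256000)
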
