-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP][wenet/LLM] support LLMs #2460
base: main
Are you sure you want to change the base?
Conversation
为什么不把embeding和out 放到decoderonly里边? 其他模态的注入是从embeding开始的,保持decoder only 有embeding的入参。 如果embeing和out share weight,fsdp 需要embeding 和out 在同一个level上, 我们经常会扩充词表,resize embed 和resize out,放最外层不影响decoderonly |
gemma 精度测试 # configs = {"decoder": "decoder_only", "output_dim": 256000, "model_conf": {}}
import torch
from wenet.text.LLM.script.convert_gemma_to_wenet_config_and_ckpt import (
get_config_for_2b, get_config_for_7b)
from wenet.utils.init_model import init_model
from gemma.model import GemmaForCausalLM
from gemma.config import (get_config_for_2b as google_2b_config_fn,
get_config_for_7b as google_7b_config_fn)
import argparse
def get_args():
parser = argparse.ArgumentParser(description='')
parser.add_argument(
'--gemma_ckpt',
required=True,
help='https://www.kaggle.com/models/google/gemma/frameworks/pyTorch')
parser.add_argument(
'--gemma_tokenizer',
required=True,
help='https://www.kaggle.com/models/google/gemma/frameworks/pyTorch')
parser.add_argument(
'--wenet_gemma_ckpt',
required=True,
help='https://www.kaggle.com/models/google/gemma/frameworks/pyTorch')
parser.add_argument('--model_size', type=str, required=True)
args = parser.parse_args()
return args
args = get_args()
args.jit = False
layers = 18 if args.model_size == '2b' else 28
if args.model_size == '2b':
config = get_config_for_2b()
else:
config = get_config_for_7b()
model_conf = {
'model': 'causal_lm',
'output_dim': config.vocab_size,
'decoder': 'decoder_only',
'tokenizer_conf': {
"special_tokens": {
'sos': 0,
'eos': 1
}
}
}
decoder_conf = {}
decoder_conf['n_kv_head'] = config.num_key_value_heads
decoder_conf['head_dim'] = config.head_dim
decoder_conf['hidden_size'] = config.hidden_size
decoder_conf['attention_heads'] = config.num_attention_heads
decoder_conf['linear_units'] = config.intermediate_size
decoder_conf['num_blocks'] = layers
decoder_conf['max_position_embeding'] = 8192
decoder_conf['activation_type'] = 'gelu'
decoder_conf['gelu_approximate'] = 'tanh'
decoder_conf['norm_eps'] = config.rms_norm_eps
decoder_conf['use_sdpa'] = True
model_conf['decoder_conf'] = decoder_conf
model_conf['model_conf'] = {}
args.checkpoint = args.wenet_gemma_ckpt
model, _ = init_model(args, model_conf)
model.eval()
# get google gemma model
if args.model_size == '2b':
google_config = google_2b_config_fn()
else:
google_config = google_7b_config_fn()
google_config.tokenizer = args.gemma_tokenizer
google_gemma = GemmaForCausalLM(google_config)
google_gemma.load_weights(
args.gemma_ckpt)
google_gemma.eval()
scale = google_config.hidden_size
batch_size = torch.randint(2, 10, ())
seq_len = torch.randint(3, 20, ())
text = torch.randint(0, config.vocab_size, (batch_size, seq_len))
def google_forward(google_gemma,
batch_size,
token_ids,
seq_len,
scale,
layers=18):
google_freqs_cis = google_gemma.freqs_cis
google_emb = google_gemma.embedder
google_gemma = google_gemma.model
input_positions_tensor = torch.arange(0, seq_len)
google_freqs_cis = google_freqs_cis.index_select(0, input_positions_tensor)
google_hidden_states = google_emb(token_ids)
google_hidden_states = google_hidden_states * (scale**0.5)
# mask_tensor = torch.full((2, 1, 10, 10), -2.3819763e38).to(torch.float)
mask_tensor = torch.full((batch_size, 1, seq_len, seq_len),
0).to(torch.float)
kv_caches = []
for _ in range(layers):
size = (batch_size, seq_len, google_config.num_key_value_heads,
google_config.head_dim)
k_cache = torch.zeros(size=size)
v_cache = torch.zeros(size=size)
kv_caches.append((k_cache, v_cache))
google_output = google_gemma(
google_hidden_states,
google_freqs_cis,
input_positions_tensor,
kv_caches,
mask_tensor,
)
google_output = torch.matmul(google_output, google_emb.weight.T)
return google_output
def wenet_forward(wenet_model, batch_size, token_ids, seq_len, layers=18):
hidden_states = wenet_model.embed(token_ids)
wenet_kv_caches = []
for _ in range(layers):
size = (0, 0, 0, 0)
k_cache = torch.zeros(size=size)
v_cache = torch.zeros(size=size)
wenet_kv_caches.append((k_cache, v_cache))
att_mask_tensor = torch.ones(batch_size,
seq_len,
seq_len,
dtype=torch.bool)
wenet_output, _ = model.decoder(hidden_states,
att_mask_tensor.squeeze(1),
kv_caches=wenet_kv_caches)
wenet_output = model.out(wenet_output)
return wenet_output
wenet_output = wenet_forward(model, batch_size, text, seq_len, layers)
google_output = google_forward(google_gemma, batch_size, text, seq_len, scale,
layers)
print(wenet_output)
print(google_output)
assert torch.allclose(wenet_output, google_output) |
1ea0839
to
64ff835
Compare
31a8dcd
to
782998d
Compare
解释下这里为什么要把shape 变成[bs, seq_len,head, head_dim] 实测[bs, seq_len,head, head_dim], 对head_dim 上apply pos等操作要慢于[bs,head,seq_len, head_dim] 6s vs 2s (长度为300) ref: https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L256 所以其他xxx attention 是否也需要有对应修改? |
周神,torch官方也有个llama微调的代码:https://github.com/pytorch/torchtune |
嗯 这个有看过。 不过我们最终目的不是llm 而是为了语音理解大模型和语音合成 而且大模型训练 有自己的设计原则和技巧 我们需要把优秀的组件 继承过来 |
该pr会拆分成以下加个pr
|
为下一步的SpeechLLM 打基础
TODO
make it works
convert some model
Llama3 70b 模型并行度为8, 分别在attention的q,k,v 和feed fowrad weight上进行了col row等的切分,需要引入fairscale, 做模型并行。 并且官方给了8个pt, 每个16G左右。
TODO