[Feature] I have a piece of code and don't know how to accelerate it with LMDeploy #2958

CallmeZhangChenchen opened this issue Dec 26, 2024 · 5 comments

@CallmeZhangChenchen

Motivation

There is a piece of code in my project that I would like to accelerate with LMDeploy:
https://github.com/FunAudioLLM/CosyVoice/blob/main/cosyvoice/llm/llm.py

Related resources

I have extracted its main structure below.

The network is a standard Qwen2-0.5B:

import torch
from transformers import Qwen2ForCausalLM

class Qwen2Encoder(torch.nn.Module):
    def __init__(self, pretrain_path):
        super().__init__()
        self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)

    def forward_one_step(self, xs, masks, cache=None):
        # xs: input embeddings; masks: causal mask, only its last row is needed
        input_masks = masks[:, -1, :]
        outs = self.model(
            inputs_embeds=xs,
            attention_mask=input_masks,
            output_hidden_states=True,
            return_dict=True,
            use_cache=True,
            past_key_values=cache,
        )
        # return the last layer's hidden states and the updated KV cache
        xs = outs.hidden_states[-1]
        new_cache = outs.past_key_values

        return xs, new_cache

The decode loop:

       ...
        # 4. cal min/max_length
        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)

        # 5. step by step decode
        out_tokens = []
        cache = None
        for i in range(max_len):
            begin = time.time()
            y_pred, cache = self.llm.forward_one_step(lm_input,
                                                      masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
                                                      cache=cache)
            print('self.llm.forward_one_step:', time.time() - begin)

            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
            if top_ids == self.speech_token_size:
                break
            if top_ids > self.speech_token_size:
                continue
            # in stream mode, yield token one by one
            yield top_ids
            out_tokens.append(top_ids)
            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)

I would like to ask: can this code be accelerated directly with LMDeploy, or do I need to change something?

PS: I am considering using LMDeploy to accelerate just forward_one_step (i.e. a single decode step) and reusing the rest of the existing project structure.
I found a demo, https://github.com/InternLM/lmdeploy/blob/main/lmdeploy/turbomind/decode.py , but I don't know how to pass the parameters to decode ^_^~~

Additional context

No response

@CallmeZhangChenchen
Author

This code differs from a conventional LLM in a few ways:

  1. The input is embeddings directly, not tokens.
  2. if top_ids > self.speech_token_size: if this decode step's result does not meet the requirement, the previous input is reused.
  3. Each step's input is the embedding looked up with the previous output id.

@irexyc
Collaborator

irexyc commented Dec 27, 2024

I am not sure whether this can be done.

decode returns logits, not hidden_states. The llm_decoder in your code doesn't seem to be the model's lm_head? If llm_decoder can replace the model's lm_head, then pipeline.get_logits would be equivalent to self.llm.forward_one_step followed by self.llm_decoder(y_pred[:, -1]).
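A rough sketch of that replacement, for illustration only (pretrain_path and llm_decoder are placeholders from your CosyVoice code, and whether the converter accepts llm_decoder's output size as the vocabulary is an assumption you would need to verify):

    # Sketch: graft llm_decoder onto the HF model as its lm_head before
    # exporting the checkpoint that LMDeploy will load, so that the logits
    # returned by decode / get_logits are already llm_decoder's output.
    from transformers import Qwen2ForCausalLM

    model = Qwen2ForCausalLM.from_pretrained(pretrain_path)   # placeholder path
    model.lm_head = llm_decoder                                # assumes a plain nn.Linear head
    model.save_pretrained('qwen2_with_speech_head')            # point LMDeploy at this directory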

Regarding passing embeddings to decode, the format is roughly as follows: you pass a run of dummy ids (represented by 0), and input_embeddings replaces the features that input_ids would produce through the lookup table, according to input_embedding_ranges; the ranges in input_embedding_ranges are half-open intervals [begin, end).

input_ids: [0, 0, 0, ...]  # list of length n
input_embeddings           # torch.Tensor, shape n x hidden_dim
input_embedding_ranges     # [[0, n]]

Note that your forward_one_step is actually stateful (it carries a cache). sequence_start / sequence_end also need to be set: for each sample, the first step uses sequence_start=True, sequence_end=False; intermediate steps use False, False; and the last step uses False, True. A minimal sketch is shown below.
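Something like the following (generator is a TurboMind instance as in the decode.py demo; the steps bookkeeping across calls and the per-step range format are assumptions on my side, so please double-check against the source):

    # Sketch: first call feeds the whole prompt as embeddings; later calls feed
    # one embedding per generated token while the engine keeps the KV cache.
    n = prompt_embeddings.shape[0]                      # prompt features, n x hidden_dim
    logits = generator.decode([0] * n, steps=[0],
                              input_embeddings=[prompt_embeddings],
                              input_embedding_ranges=[[0, n]],
                              sequence_start=True, sequence_end=False)

    step = n
    for i in range(max_len):
        logits = generator.decode([0], steps=[step],
                                  input_embeddings=[next_embedding],  # 1 x hidden_dim
                                  input_embedding_ranges=[[0, 1]],
                                  sequence_start=False,
                                  sequence_end=(i == max_len - 1))    # False, True on the last step
        step += 1
        # ... sample top_ids from logits and build next_embedding as in your loop ...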

@CallmeZhangChenchen
Author

Thanks a lot, I will give it a try.

@CallmeZhangChenchen
Author

@irexyc Hi! Thanks for the pointers, the decode code works now:

    from lmdeploy import turbomind as tm
    import numpy as np

    tm_model = tm.TurboMind.from_pretrained(model_path)
    generator = tm_model.create_instance()
    input_embeddings = [np.squeeze(np.load('input0.npy'))]  # (1, 107, 896) -> (107, 896)
    input_ids = [0] * 107
    input_embedding_ranges = [[0, 107]]
    logits = generator.decode(input_ids, steps=[0],
                              input_embeddings=input_embeddings,
                              input_embedding_ranges=input_embedding_ranges,
                              sequence_start=True, sequence_end=False)

But I found a problem:

class Qwen2ForCausalLM:
    def forward():
        ...
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
        ...
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

In transformers/models/qwen2/modeling_qwen2.py, logits = self.lm_head() is computed from outputs[0], so I assume LMDeploy does the same.

But in the CosyVoice code, logp = self.llm_decoder() is computed from outputs.hidden_states[-1][:, -1], which is clearly different from LMDeploy's computation.
So I want to ask: does LMDeploy have an interface that returns specific outputs.hidden_states? If there really isn't one, I will have to modify the source code.

Also, it seems that the arguments to generator.decode() are all on the CPU; once everything is running end to end there will probably be a lot of transfer overhead. Is there any way to avoid this at the optimization level?

@irexyc
Collaborator

irexyc commented Dec 27, 2024

The lmdeploy decode interface returns the final output of the LLM (the output of lm_head). That is why I suggested above putting the llm_decoder weights into the model's lm_head; then calling lmdeploy decode is equivalent to doing your self.llm.forward_one_step plus self.llm_decoder(y_pred[:, -1]).

If your lm_input is already a GPU tensor and you want to avoid the transfer overhead, you can try the PyTorch backend; the usage is similar (a sketch follows below).
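A minimal sketch of switching backends (whether this path lets you pass GPU tensors and embeddings directly is something I have not verified, please check against the source):

    # Sketch: select the PyTorch backend through the pipeline API instead of TurboMind.
    from lmdeploy import pipeline, PytorchEngineConfig

    pipe = pipeline(model_path, backend_config=PytorchEngineConfig(tp=1))
    logits = pipe.get_logits(input_ids)   # output of lm_head, cf. get_logits mentioned above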
