fix: MTP in chunked prefill mode #1079
base: main
Conversation
Summary of Changes
Hello @sufubao, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request addresses an accuracy issue with multi-token prediction (MTP) inference in chunked prefill mode. When a long sequence is split into multiple chunks, the token IDs used to fill the draft model's KV cache can be incorrect; the PR adds the first token ID of the next chunk to ModelInput to assist MTP inference.
Code Review
This PR fixes KV-cache filling for the MTP (multi-token prediction) draft model in chunked prefill mode. By adding the first token ID of the next chunk to ModelInput, it ensures that when a non-final chunk is processed, the draft model fills its KV cache with the correct token, correcting the MTP inference logic.
The overall approach is clear and solves the stated problem. The changes span multiple files to support the data flow and use of the new field.
I noticed that part of the logic is duplicated in several places and recommend refactoring it to improve maintainability. See the inline comments for details.
# For requests whose prefill finishes in this step, use the sampled next_token_ids;
# for requests still in a chunked prefill, use the first token of the next chunk so
# the draft model's KV cache is filled with the true continuation.
b_has_out = torch.tensor(b_has_out_cpu, dtype=torch.bool, device="cuda")
b_chunked_next_token_ids = torch.tensor(
    model_input.b_chunked_prefill_next_token_ids_cpu[0:req_num], dtype=torch.int64, device="cuda"
)
mtp_next_token_ids = torch.where(b_has_out, next_token_ids, b_chunked_next_token_ids)
draft_next_token_ids_gpu[0:req_num].copy_(mtp_next_token_ids)
The logic here for computing mtp_next_token_ids duplicates the prefill_mtp method in chunked_prefill/impl.py and the prefill_overlap_mtp method in this file (for micro_batch0 and micro_batch1).
To improve maintainability and reuse, consider extracting this logic into a standalone helper function (for example, in the ModeBackend base class).
For example, a helper function like this could be created:
def _get_mtp_next_token_ids(self, b_prefill_has_output_cpu, b_chunked_prefill_next_token_ids_cpu, next_token_ids_gpu):
    b_has_out = torch.tensor(b_prefill_has_output_cpu, dtype=torch.bool, device="cuda")
    b_chunked_next_token_ids = torch.tensor(
        b_chunked_prefill_next_token_ids_cpu, dtype=torch.int64, device="cuda"
    )
    mtp_next_token_ids = torch.where(b_has_out, next_token_ids_gpu, b_chunked_next_token_ids)
    return mtp_next_token_ids

This way, the duplicated code here and in prefill_overlap_mtp can be simplified to a call to this new function.
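A call site could then be reduced to something like the sketch below. This is only an illustration: the surrounding variables are taken from the diff above, and whether b_prefill_has_output_cpu is accessed via model_input in the real code is an assumption.

# Hypothetical simplified call site; names mirror the snippet above.
mtp_next_token_ids = self._get_mtp_next_token_ids(
    model_input.b_prefill_has_output_cpu[0:req_num],
    model_input.b_chunked_prefill_next_token_ids_cpu[0:req_num],
    next_token_ids,
)
draft_next_token_ids_gpu[0:req_num].copy_(mtp_next_token_ids)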
In chunked prefill mode, when a long sequence is split into multiple chunks for processing, the next_token_ids used to fill the draft model's KV cache may be incorrect. This PR adds the first token ID of the next chunk to ModelInput to assist MTP inference.
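For intuition, the core idea can be sketched as follows. This is a minimal illustration under assumptions, not the PR's actual code: the build_chunk_inputs helper is hypothetical, while b_prefill_has_output_cpu and b_chunked_prefill_next_token_ids_cpu mirror the field names seen in the diff above.

# Minimal sketch (hypothetical helper): when scheduling chunk [start:end) of a prompt,
# record the first token of the *next* chunk so the draft (MTP) model can fill its
# KV cache with the true continuation instead of a sampled token.
def build_chunk_inputs(prompt_ids, start, end):
    is_last_chunk = end >= len(prompt_ids)
    return {
        "input_ids": prompt_ids[start:end],
        # Only the final chunk actually samples an output token.
        "b_prefill_has_output_cpu": is_last_chunk,
        # For non-final chunks, the correct "next token" is simply the first token
        # of the following chunk; use a placeholder (-1) for the last chunk.
        "b_chunked_prefill_next_token_ids_cpu": prompt_ids[end] if not is_last_chunk else -1,
    }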