
Conversation

@sufubao (Collaborator) commented Oct 14, 2025

In chunked prefill mode, when a long sequence is split into multiple chunks, the next_token_ids used to fill the draft model's KV cache may be incorrect. This PR adds the first token ID of the next chunk to ModelInput to assist MTP inference.
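A minimal sketch of the chunk splitting this relies on, assuming a request object that exposes its full prompt as input_token_ids plus the length already cached; the attribute names are illustrative rather than the exact fields in infer_batch.py, and the repo implements this as a method on the request rather than a free function:

def get_chuncked_input_token_ids(req, chunk_size):
    # Token ids of the current chunk: the slice that has not been prefilled yet.
    start = req.cur_kv_len
    end = min(start + chunk_size, len(req.input_token_ids))
    chunk_token_ids = req.input_token_ids[start:end]
    # First token id of the next chunk; None when this chunk is the last one,
    # in which case the regular sampled next_token_ids is already correct.
    next_chunk_first_token_id = req.input_token_ids[end] if end < len(req.input_token_ids) else None
    return chunk_token_ids, next_chunk_first_token_id

For a chunk that is not the last one, next_chunk_first_token_id is exactly the token the draft model should see next, which the sampled next_token_ids cannot provide.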

@gemini-code-assist

Summary of Changes

Hello @sufubao, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses an accuracy issue with multi-token prediction (MTP) inference in chunked prefill mode. When a long sequence is processed in multiple chunks, the next_token_ids used to fill the draft model's KV cache can be incorrect. By introducing a new field in ModelInput that stores the first token ID of the next chunk and adjusting the MTP inference logic accordingly, the draft model's KV cache is filled correctly in chunked prefill scenarios, improving inference accuracy.

Highlights

  • New ModelInput field: a b_chunked_prefill_next_token_ids_cpu field is added to the ModelInput class to store the first token ID of the next chunk in chunked prefill mode, assisting MTP inference.
  • Changed return value of get_chuncked_input_token_ids: the get_chuncked_input_token_ids method in infer_batch.py now returns the current chunk's token ID list together with the first token ID of the next chunk, providing the information the MTP fix needs.
  • MTP inference logic fix: in the prefill_mtp methods of the chunked_prefill and dp_backend modes, the draft model's KV cache is filled with either the regular next_token_ids or the newly introduced b_chunked_prefill_next_token_ids_cpu, depending on whether the request has already produced output, which fixes MTP inference in chunked prefill mode.
  • Pre-processing updates: the pre-processing logic in generic_padded_pre_process.py and generic_pre_process.py is updated to collect the first token ID of the next chunk and pass it to ModelInput (see the sketch after this list).
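To make the data flow concrete, here is a self-contained illustration of the collection step described in the last bullet; the tuples and the -1 placeholder are made up for the example, and only get_chuncked_input_token_ids, ModelInput, and b_chunked_prefill_next_token_ids_cpu come from the PR:

# Each tuple stands in for one request's (chunk token ids, next-chunk first token id)
# as returned by get_chuncked_input_token_ids.
chunked_results = [
    ([11, 12, 13], 14),   # still mid-prefill: 14 begins the next chunk
    ([21, 22], None),     # prefill finishes in this chunk
]
# Illustrative -1 placeholder for finished requests; the backend falls back to
# the sampled next_token_ids for those.
b_chunked_prefill_next_token_ids_cpu = [
    -1 if next_id is None else next_id for _, next_id in chunked_results
]
# -> [14, -1]; this list is then passed to ModelInput as
#    b_chunked_prefill_next_token_ids_cpu.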

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This PR fixes the filling of the MTP (multi-token prediction) draft model's KV cache in chunked prefill mode. By adding the first token ID of the next chunk to ModelInput, it ensures that the draft model fills its KV cache with the correct token when processing a chunk that is not the last one, correcting the MTP inference logic.
The overall approach is clear and solves the stated problem. The changes are spread across several files to support the data flow and use of the new field.
I noticed that part of the logic is duplicated in several places; I suggest refactoring it to improve maintainability. See the comments in the files for details.

Comment on lines 357 to 362
b_has_out = torch.tensor(b_has_out_cpu, dtype=torch.bool, device="cuda")
b_chunked_next_token_ids = torch.tensor(
model_input.b_chunked_prefill_next_token_ids_cpu[0:req_num], dtype=torch.int64, device="cuda"
)
mtp_next_token_ids = torch.where(b_has_out, next_token_ids, b_chunked_next_token_ids)
draft_next_token_ids_gpu[0:req_num].copy_(mtp_next_token_ids)


medium

The logic here for computing mtp_next_token_ids duplicates the prefill_mtp method in chunked_prefill/impl.py as well as the prefill_overlap_mtp method in this file (for both micro_batch0 and micro_batch1).

To improve maintainability and reuse, consider extracting this logic into a standalone helper function (for example, in the ModeBackend base class).

For example, a helper like this could be created:

def _get_mtp_next_token_ids(self, b_prefill_has_output_cpu, b_chunked_prefill_next_token_ids_cpu, next_token_ids_gpu):
    b_has_out = torch.tensor(b_prefill_has_output_cpu, dtype=torch.bool, device="cuda")
    b_chunked_next_token_ids = torch.tensor(
        b_chunked_prefill_next_token_ids_cpu, dtype=torch.int64, device="cuda"
    )
    # Per request: use the sampled next token when the prefill has produced
    # output, otherwise use the first token of the next chunk.
    mtp_next_token_ids = torch.where(b_has_out, next_token_ids_gpu, b_chunked_next_token_ids)
    return mtp_next_token_ids

The duplicated code here and in prefill_overlap_mtp can then be simplified to a call to this new function.
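For illustration, a standalone usage of such a helper, written here as a free function and with made-up tensor values (running it requires a CUDA device):

import torch

def _get_mtp_next_token_ids(b_prefill_has_output_cpu, b_chunked_prefill_next_token_ids_cpu, next_token_ids_gpu):
    # Same body as the suggestion above, without self, for a standalone demo.
    b_has_out = torch.tensor(b_prefill_has_output_cpu, dtype=torch.bool, device="cuda")
    b_chunked_next_token_ids = torch.tensor(
        b_chunked_prefill_next_token_ids_cpu, dtype=torch.int64, device="cuda"
    )
    return torch.where(b_has_out, next_token_ids_gpu, b_chunked_next_token_ids)

# Request 0 already has output, so its sampled token (101) is kept; request 1 is
# still mid-prefill, so the next chunk's first token (42) fills the draft KV cache.
next_token_ids = torch.tensor([101, 999], dtype=torch.int64, device="cuda")
print(_get_mtp_next_token_ids([True, False], [-1, 42], next_token_ids))
# tensor([101, 42], device='cuda:0')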

