fix: support int64 IdType for RoPE part argument in rope_quantize_fp8_append_paged_kv_cache#2255
Conversation
|
Note Other AI code review bot(s) detectedCodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review. WalkthroughMoves the pos_id dtype dispatch outwards and splits a single Changes
Sequence Diagram(s)(omitted) Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes Possibly related PRs
Suggested reviewers
Pre-merge checks and finishing touches❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
📜 Recent review detailsConfiguration used: defaults Review profile: CHILL Plan: Pro 📒 Files selected for processing (3)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
🔇 Additional comments (7)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
Summary of ChangesHello @elvischenv, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request addresses a type mismatch issue within the Highlights
🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console. Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request correctly separates RoPEIdType from PagedKVIdType to enable uint64 support for pos_ids, which is a good improvement. However, I've identified a couple of areas for improvement. Firstly, the PagedKVIdType is hardcoded to int32_t for various paged KV cache indices, while the pull request description suggests it should be uint32_t. This could lead to correctness issues. Secondly, the new functionality for 64-bit pos_ids is not covered by automated tests. Adding tests would ensure the feature works as expected and prevent future regressions. I've provided a specific comment to address the type issue.
rope_quantize_fp8_append_paged_kv_cacherope_quantize_fp8_append_paged_kv_cache
|
/bot run |
|
Add some unittest and more type checks in 0fbf196, @elvischenv @kahyunnam let me know if they look good to you. |
|
/bot run |
There was a problem hiding this comment.
To summarize this:
- We have the following integer arguments in the API:
- RoPE part integer argument:
pos_ids - KV update part integer arguments:
batch_indices,positions,kv_indices,kv_indptr
- RoPE part integer argument:
- Before the PR, KV update part integer arguments are always using
int32so thatpos_idswill also be limited inint32, because they are using the same template idtype - KV update part integer arguments are fine for
int32type, butpos_idsmay beint64from the framework. So we need to separate the RoPE part integer type and KV update part integer type. - Actually this PR just wants to extend the idtype support for only
pos_ids.- For the KV update part, it can also be done but should be optional.
csrc/rope.cu
Outdated
| // Validate that all index tensors have the same dtype as pos_ids | ||
| if (kv_indices.dtype() != pos_ids.dtype()) { | ||
| TVM_FFI_LOG_AND_THROW(TypeError) << "kv_indices dtype (" << kv_indices.dtype() | ||
| << ") must match pos_ids dtype (" << pos_ids.dtype() << ")"; | ||
| } | ||
| if (kv_indptr.dtype() != pos_ids.dtype()) { | ||
| TVM_FFI_LOG_AND_THROW(TypeError) << "kv_indptr dtype (" << kv_indptr.dtype() | ||
| << ") must match pos_ids dtype (" << pos_ids.dtype() << ")"; | ||
| } | ||
| if (batch_indices.dtype() != pos_ids.dtype()) { | ||
| TVM_FFI_LOG_AND_THROW(TypeError) << "batch_indices dtype (" << batch_indices.dtype() | ||
| << ") must match pos_ids dtype (" << pos_ids.dtype() << ")"; | ||
| } | ||
| if (positions.dtype() != pos_ids.dtype()) { | ||
| TVM_FFI_LOG_AND_THROW(TypeError) << "positions dtype (" << positions.dtype() | ||
| << ") must match pos_ids dtype (" << pos_ids.dtype() << ")"; | ||
| } | ||
|
|
There was a problem hiding this comment.
@yzh119 This won't work.
pos_ids is for RoPE part:
Lines 1298 to 1314 in df82616
batch_indices, positions, kv_indices, kv_indptr are for KV cache update part:
Lines 255 to 265 in df82616
rope_quantize_fp8_append_paged_kv_cache is a merged version that has pos_ids, batch_indices, positions, kv_indices, kv_indptr:
Lines 1438 to 1460 in df82616
- From my observation from SGLang, RoPE and KV update part can have different integer types.
- That is, for RoPE part, we need a standalone IdType
typename RoPEIdTypeforpos_ids. For KV update part, we need another standalone IdTypetypename PagedKVIdTypeforbatch_indices,positions,kv_indices,kv_indptr. - So that we could support the combination like
pos_idsin int64, andbatch_indices,positions,kv_indices,kv_indptrin int32.
csrc/rope.cu
Outdated
| auto kpe_strides = kpe_cache.strides(); | ||
|
|
||
| paged_kv_mla_t<c_quant_type, int32_t> paged_kv_mla( | ||
| paged_kv_mla_t<c_quant_type, c_idtype> paged_kv_mla( |
There was a problem hiding this comment.
- For RoPE part, I have added a
DISPATCH_DLPACK_IDTYPE_TO_CTYPE(pos_ids.dtype()...for dispatching the idtype for RoPE part integer type. - If we also want to dispatch a idtype for KV update part, we need another nest dispatcher like
DISPATCH_DLPACK_IDTYPE_TO_CTYPE(kv_indptr.dtype()...for integer type in KV update code path.
|
[SUCCESS] Pipeline #40605057: 11/20 passed |
elvischenv
left a comment
There was a problem hiding this comment.
@yzh119 I pushed a commit that make the int64 support only applied on RoPE part argument(pos_ids), also updated the unit tests. Though I don't see the requirements for using int64 KV idtype, we could support them in follow-up PR when we need it.
|
@elvischenv thank you for the clarification. |
📌 Description
rope_quantize_fp8_append_paged_kv_cacheis a merged API ofrope_quantizeandappend_paged_kv_cache(#2037). However,typename IdTypefromRopeQuantizeandAppendPagedKVCacheshould not be merged into the same one since they could be in different dtype.AppendPagedKVCache'sIdTypeis hardcoded toint32butRopeQuantize'sIdTypemay beint64in frameworks.This PR splits
typename IdTypeinto separatedtypename RoPEIdType, typename PagedKVIdType, and this will fix the accuracy issue when passing int64pos_ids(RoPE part argument that withRoPEIdTypetype) to API.cc @kahyunnam @yzh119
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
pre-commitby runningpip install pre-commit(or used your preferred method).pre-commit install.pre-commit run --all-filesand fixed any reported issues.🧪 Tests
unittest, etc.).Reviewer Notes
Summary by CodeRabbit
✏️ Tip: You can customize this high-level summary in your review settings.