fix: support int64 IdType for RoPE part argument in rope_quantize_fp8_append_paged_kv_cache #2255

Merged
yzh119 merged 3 commits into flashinfer-ai:main from elvischenv:elvischenv/fix-rope-idtype
Dec 24, 2025

Conversation

@elvischenv
Contributor

@elvischenv elvischenv commented Dec 22, 2025

📌 Description

rope_quantize_fp8_append_paged_kv_cache is a merged API combining rope_quantize and append_paged_kv_cache (#2037). However, the typename IdType parameters from RopeQuantize and AppendPagedKVCache should not be merged into a single one, since the two can have different dtypes: AppendPagedKVCache's IdType is hardcoded to int32, while RopeQuantize's IdType may be int64 in frameworks.

This PR splits typename IdType into separate typename RoPEIdType and typename PagedKVIdType parameters, which fixes the accuracy issue when passing int64 pos_ids (the RoPE-part argument typed with RoPEIdType) to the API.
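To illustrate the failure mode, here is a hedged NumPy sketch (not the actual kernel code): when int64 position ids are read through code compiled for an int32 id type, the raw bytes are reinterpreted and the positions come out scrambled.

```python
import numpy as np

# Hypothetical sketch of the bug class this PR fixes: reading int64 ids
# through an int32-typed view misinterprets the values. Names are
# illustrative, not taken from the FlashInfer kernels.
pos_ids = np.array([5, 6, 7, 8], dtype=np.int64)
misread = pos_ids.view(np.int32)  # what a hardcoded int32 IdType effectively does

# On a little-endian machine each int64 splits into (low half, high half),
# so the reader sees twice as many "positions", half of them zero here.
print(misread.tolist())
```

Splitting the template parameter lets the RoPE path instantiate with the actual dtype of pos_ids instead of forcing this reinterpretation.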

cc @kahyunnam @yzh119

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Refactor
    • Generalized position-encoding and paged KV cache handling to support multiple integer identifier dtypes and improve type consistency.
  • Bug Fixes
    • Enforced/validated consistent integer dtype for index tensors before processing to reduce dtype-mismatch errors.
  • Tests
  • Expanded tests to cover different integer index dtypes (e.g., int32 and int64) for RoPE and paged KV workflows.


@coderabbitai
Contributor

coderabbitai bot commented Dec 22, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Moves the pos_id dtype dispatch outwards and splits a single IdType into RoPEIdType (pos_ids) and PagedKVIdType (batch_indices/positions) across RoPE kernels and host wrappers; updates CUDA dispatching and tests to parametrize index dtypes. (48 words)

Changes

Cohort / File(s) Summary
CUDA dispatch / runtime
csrc/rope.cu
Moves the pos_id dtype dispatch to wrap the MLA vs GQA/MHA branching, changes inner casts to use the dispatched c_idtype, and adjusts control flow to return from the dispatched lambda while preserving status checks and error messages.
Kernel templates & host wrappers
include/flashinfer/pos_enc.cuh
Replaces IdType with RoPEIdType (for pos_ids) and PagedKVIdType (for batch_indices/positions) in kernel templates and host wrapper signatures; propagates new template args to all paged-KV / MLA instantiations and DISPATCH paths.
Python API callsites
flashinfer/rope.py
Forces/ensures integer dtype for index inputs by casting batch_indices, positions, kv_indices, and kv_indptr to int before calling the underlying op.
Tests: dtype parametrization
tests/attention/test_rope.py
Parameterizes tests with rope_idtype (e.g., torch.int32/torch.int64) and updates index tensors and test call sites to use the chosen integer dtype.

Sequence Diagram(s)

(omitted)

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

Suggested reviewers

  • kahyunnam
  • yzh119
  • nvmbreughe
  • jiahanc

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 75.00% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed The title accurately reflects the main change: splitting IdType into separate types for RoPE and paged KV cache to support int64 pos_ids in RopeQuantize while keeping paged KV cache at int32.
Description check ✅ Passed The description clearly explains the technical issue and solution. The pre-commit checklist is marked complete; however, the testing checklist indicates tests have not been confirmed as added or passing.

📜 Recent review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0fbf196 and 0903058.

📒 Files selected for processing (3)
  • csrc/rope.cu
  • flashinfer/rope.py
  • tests/attention/test_rope.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (7)
csrc/rope.cu (1)

550-620: Dispatch and type-casting correctly implements decoupled IdTypes.

The refactored dispatch logic properly separates:

  • c_idtype (from pos_ids.dtype()) for RoPE position IDs — supports both int32 and int64
  • Hardcoded int32_t for paged KV cache indices (kv_indices, kv_indptr, batch_indices, positions)

This aligns with the PR objective to allow int64 pos_ids while keeping paged KV cache indices as int32, and is consistent with the Python-side .int() casts in rope.py.

flashinfer/rope.py (1)

1644-1648: Explicit int32 casts ensure consistency with C++ kernel expectations.

These casts guarantee that batch_indices, positions, kv_indices, and kv_indptr are int32, matching the hardcoded int32_t* casts in the CUDA kernel. This is correct and ensures type safety at the Python/C++ boundary.
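As a sketch of that guard (the real code calls `.int()` on torch tensors; the helper name and NumPy types here are illustrative), the normalization looks like:

```python
import numpy as np

# Illustrative NumPy version of the Python-side guard described above:
# paged-KV index tensors are normalized to int32 before crossing the
# Python/C++ boundary, while pos_ids keeps its (possibly int64) dtype.
def as_int32(*tensors):
    return tuple(t.astype(np.int32) for t in tensors)

batch_indices = np.array([0, 0, 1], dtype=np.int64)
positions = np.array([0, 1, 0], dtype=np.int64)
kv_indices = np.array([0, 1], dtype=np.int64)
kv_indptr = np.array([0, 1, 2], dtype=np.int64)

batch_indices, positions, kv_indices, kv_indptr = as_int32(
    batch_indices, positions, kv_indices, kv_indptr)
assert all(t.dtype == np.int32
           for t in (batch_indices, positions, kv_indices, kv_indptr))
```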

tests/attention/test_rope.py (5)

519-532: Test parameterization correctly covers both int32 and int64 RoPE IdTypes.

The new rope_idtype parameter ensures comprehensive test coverage for the decoupled IdType support introduced in this PR.


594-594: Correct usage of rope_idtype for pos_ids.

The pos_ids tensor is created with the parameterized rope_idtype, properly exercising both int32 and int64 code paths through the kernel.


838-852: Consistent test parameterization for decode scenario.

The decode/continuation test also includes the rope_idtype parameter, ensuring the int64 path is tested in both prefill and decode scenarios.


944-946: Correct dtype propagation for existing tokens pos_ids.


1134-1139: Correct dtype propagation for new tokens pos_ids.



@gemini-code-assist
Contributor

Summary of Changes

Hello @elvischenv, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a type mismatch issue within the rope_quantize_fp8_append_paged_kv_cache function, which combines Rotary Positional Embedding (RoPE) and paged KV cache operations. Previously, a single IdType was used for both, leading to potential accuracy problems when pos_ids (a RoPE argument) required an int64 type while paged KV cache indices were hardcoded to int32. The changes introduce distinct template parameters, RoPEIdType and PagedKVIdType, to ensure proper type handling for each component, thereby fixing the accuracy issue and improving the robustness of the combined API.

Highlights

  • Type Safety for RoPE pos_ids: The rope_quantize_fp8_append_paged_kv_cache function now correctly handles int64 pos_ids for Rotary Positional Embeddings, resolving previous accuracy issues.
  • Separated ID Types: The single IdType template parameter has been split into RoPEIdType for RoPE arguments and PagedKVIdType for paged KV cache indices, allowing for distinct data types where needed.
  • Dynamic Type Dispatch: The implementation in csrc/rope.cu now utilizes DISPATCH_DLPACK_IDTYPE_TO_CTYPE to dynamically determine the appropriate C++ type for pos_ids at runtime, ensuring type correctness.
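A rough Python analogue of that runtime dispatch may help (the kernel table, names, and return values below are illustrative, not the actual DISPATCH_DLPACK_IDTYPE_TO_CTYPE macro): the concrete instantiation is selected from the runtime dtype of pos_ids.

```python
import numpy as np

# Sketch of dtype-driven dispatch: pick the kernel instantiation from the
# runtime dtype of pos_ids, failing loudly on unsupported dtypes.
def dispatch_on_idtype(pos_ids, kernels):
    try:
        kernel = kernels[pos_ids.dtype.type]
    except KeyError:
        raise TypeError(f"unsupported pos_ids dtype: {pos_ids.dtype}")
    return kernel(pos_ids)

kernels = {
    np.int32: lambda p: ("int32 kernel", int(p.sum())),
    np.int64: lambda p: ("int64 kernel", int(p.sum())),
}
result = dispatch_on_idtype(np.arange(4, dtype=np.int64), kernels)
print(result)  # ('int64 kernel', 6)
```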



Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request correctly separates RoPEIdType from PagedKVIdType to enable uint64 support for pos_ids, which is a good improvement. However, I've identified a couple of areas for improvement. Firstly, the PagedKVIdType is hardcoded to int32_t for various paged KV cache indices, while the pull request description suggests it should be uint32_t. This could lead to correctness issues. Secondly, the new functionality for 64-bit pos_ids is not covered by automated tests. Adding tests would ensure the feature works as expected and prevent future regressions. I've provided a specific comment to address the type issue.

@elvischenv elvischenv changed the title from "fix: support uint64 IdType for RoPE part argument in rope_quantize_fp8_append_paged_kv_cache" to "fix: support int64 IdType for RoPE part argument in rope_quantize_fp8_append_paged_kv_cache" on Dec 22, 2025
Collaborator

@yzh119 yzh119 left a comment


LGTM

@yzh119
Collaborator

yzh119 commented Dec 22, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !210 has been created, and the CI pipeline #40603791 is currently running. I'll report back once the pipeline job completes.

@yzh119
Collaborator

yzh119 commented Dec 22, 2025

Added some unit tests and more type checks in 0fbf196, @elvischenv @kahyunnam let me know if they look good to you.

@yzh119
Collaborator

yzh119 commented Dec 22, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !210 has been updated with latest changes, and the CI pipeline #40605057 is currently running. I'll report back once the pipeline job completes.

Contributor Author

@elvischenv elvischenv left a comment


To summarize this:

  • We have the following integer arguments in the API:
    • RoPE part integer argument: pos_ids
    • KV update part integer arguments: batch_indices, positions, kv_indices, kv_indptr
  • Before this PR, the KV update part integer arguments always use int32, which also limits pos_ids to int32, because both share the same template idtype.
  • int32 is fine for the KV update part arguments, but pos_ids may arrive as int64 from the framework, so we need to separate the RoPE part integer type from the KV update part integer type.
  • This PR only extends the idtype support for pos_ids.
    • The same could be done for the KV update part, but that should be optional.

csrc/rope.cu Outdated
Comment on lines +453 to +470
  // Validate that all index tensors have the same dtype as pos_ids
  if (kv_indices.dtype() != pos_ids.dtype()) {
    TVM_FFI_LOG_AND_THROW(TypeError) << "kv_indices dtype (" << kv_indices.dtype()
                                     << ") must match pos_ids dtype (" << pos_ids.dtype() << ")";
  }
  if (kv_indptr.dtype() != pos_ids.dtype()) {
    TVM_FFI_LOG_AND_THROW(TypeError) << "kv_indptr dtype (" << kv_indptr.dtype()
                                     << ") must match pos_ids dtype (" << pos_ids.dtype() << ")";
  }
  if (batch_indices.dtype() != pos_ids.dtype()) {
    TVM_FFI_LOG_AND_THROW(TypeError) << "batch_indices dtype (" << batch_indices.dtype()
                                     << ") must match pos_ids dtype (" << pos_ids.dtype() << ")";
  }
  if (positions.dtype() != pos_ids.dtype()) {
    TVM_FFI_LOG_AND_THROW(TypeError) << "positions dtype (" << positions.dtype()
                                     << ") must match pos_ids dtype (" << pos_ids.dtype() << ")";
  }

Contributor Author

@elvischenv elvischenv Dec 22, 2025


@yzh119 This won't work.
pos_ids is for RoPE part:

flashinfer/flashinfer/rope.py

Lines 1298 to 1314 in df82616

def rope_quantize_fp8(
    q_rope: torch.Tensor,
    k_rope: torch.Tensor,
    q_nope: Optional[torch.Tensor],
    k_nope: Optional[torch.Tensor],
    cos_sin_cache: torch.Tensor,
    pos_ids: torch.Tensor,
    is_neox: bool = True,
    quantize_dtype: Optional[torch.dtype] = None,
    quant_scale_q: float = 1.0,
    quant_scale_kv: float = 1.0,
    q_rope_out: Optional[torch.Tensor] = None,
    k_rope_out: Optional[torch.Tensor] = None,
    q_nope_out: Optional[torch.Tensor] = None,
    k_nope_out: Optional[torch.Tensor] = None,
    enable_pdl: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

batch_indices, positions, kv_indices, kv_indptr are for KV cache update part:

def append_paged_kv_cache(
    append_key: torch.Tensor,
    append_value: torch.Tensor,
    batch_indices: torch.Tensor,
    positions: torch.Tensor,
    paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    kv_indices: torch.Tensor,
    kv_indptr: torch.Tensor,
    kv_last_page_len: torch.Tensor,
    kv_layout: str = "NHD",
) -> None:

rope_quantize_fp8_append_paged_kv_cache is a merged version that has pos_ids, batch_indices, positions, kv_indices, kv_indptr:

flashinfer/flashinfer/rope.py

Lines 1438 to 1460 in df82616

def rope_quantize_fp8_append_paged_kv_cache(
    q_rope: torch.Tensor,
    k_rope: torch.Tensor,
    q_nope: Optional[torch.Tensor],
    k_nope: Optional[torch.Tensor],
    v: Optional[torch.Tensor],
    cos_sin_cache: torch.Tensor,
    pos_ids: torch.Tensor,
    paged_kv_cache: Tuple[torch.Tensor, torch.Tensor],
    kv_indices: torch.Tensor,
    kv_indptr: torch.Tensor,
    batch_indices: torch.Tensor,
    positions: torch.Tensor,
    is_neox: bool = True,
    quantize_dtype: Optional[torch.dtype] = None,
    quant_scale_q: float = 1.0,
    quant_scale_kv: float = 1.0,
    page_size: int = 16,
    kv_layout: str = "NHD",
    q_rope_out: Optional[torch.Tensor] = None,
    q_nope_out: Optional[torch.Tensor] = None,
    enable_pdl: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:

  • From my observation in SGLang, the RoPE and KV update parts can have different integer types.
  • That is, the RoPE part needs a standalone IdType, typename RoPEIdType, for pos_ids, and the KV update part needs a separate IdType, typename PagedKVIdType, for batch_indices, positions, kv_indices, kv_indptr.
  • This lets us support combinations like pos_ids in int64 with batch_indices, positions, kv_indices, kv_indptr in int32.
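The supported combination described above can be sketched as follows (variable names mirror the API arguments; the values and use of NumPy are illustrative):

```python
import numpy as np

# int64 RoPE position ids alongside int32 paged-KV index tensors,
# each kept in its own dtype rather than forced through one IdType.
pos_ids = np.array([2**40, 2**40 + 1], dtype=np.int64)  # needs 64 bits
kv_indices = np.array([0, 1, 2], dtype=np.int32)        # 32 bits suffice
kv_indptr = np.array([0, 2, 3], dtype=np.int32)

assert pos_ids.dtype == np.int64
assert kv_indices.dtype == np.int32 and kv_indptr.dtype == np.int32
```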

csrc/rope.cu Outdated
Comment on lines +568 to +576
  auto kpe_strides = kpe_cache.strides();

- paged_kv_mla_t<c_quant_type, int32_t> paged_kv_mla(
+ paged_kv_mla_t<c_quant_type, c_idtype> paged_kv_mla(
Contributor Author

@elvischenv elvischenv Dec 22, 2025


  • For the RoPE part, I have added DISPATCH_DLPACK_IDTYPE_TO_CTYPE(pos_ids.dtype()...) to dispatch the idtype for the RoPE part integer type.
  • If we also want to dispatch an idtype for the KV update part, we would need another nested dispatcher like DISPATCH_DLPACK_IDTYPE_TO_CTYPE(kv_indptr.dtype()...) for the integer type in the KV update code path.
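The nested-dispatch idea above can be sketched in Python (function name and the use of NumPy dtypes are illustrative; the real code would nest two DISPATCH macros):

```python
import numpy as np

# First resolve the RoPE id dtype, then (if ever needed) resolve the
# KV id dtype inside, yielding one concrete (RoPEIdType, PagedKVIdType)
# instantiation per combination.
def nested_dispatch(pos_ids, kv_indptr):
    supported = (np.int32, np.int64)
    rope_t = pos_ids.dtype.type
    kv_t = kv_indptr.dtype.type
    if rope_t not in supported or kv_t not in supported:
        raise TypeError("unsupported dtype combination")
    return rope_t.__name__, kv_t.__name__

combo = nested_dispatch(np.zeros(3, dtype=np.int64),
                        np.zeros(3, dtype=np.int32))
print(combo)  # ('int64', 'int32')
```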

@flashinfer-bot
Collaborator

[SUCCESS] Pipeline #40605057: 11/20 passed

Contributor Author

@elvischenv elvischenv left a comment


@yzh119 I pushed a commit that makes the int64 support apply only to the RoPE part argument (pos_ids), and also updated the unit tests. Since I don't see a requirement for an int64 KV idtype yet, we can support it in a follow-up PR when needed.

@yzh119
Collaborator

yzh119 commented Dec 24, 2025

@elvischenv thank you for the clarification.

@yzh119 yzh119 merged commit 0ccf4e3 into flashinfer-ai:main Dec 24, 2025
4 checks passed
@elvischenv elvischenv deleted the elvischenv/fix-rope-idtype branch December 24, 2025 07:37