
[flashinfer.rope] refactor CUDA backend, add CuTe-DSL backend for unfused RoPE APIs#2470

Open
kahyunnam wants to merge 10 commits into flashinfer-ai:main from kahyunnam:knam/RoPE-cute-dsl

Conversation

@kahyunnam
Collaborator

@kahyunnam kahyunnam commented Feb 3, 2026

📌 Description

This PR refactors the RoPE API and adds a CuTe-DSL backend option for the unfused RoPE APIs in flashinfer.rope: apply_rope, apply_rope_inplace, apply_rope_pos_ids, apply_rope_pos_ids_inplace, apply_llama31_rope, apply_llama31_rope_inplace, apply_llama31_rope_pos_ids, apply_llama31_rope_pos_ids_inplace, apply_rope_with_cos_sin_cache, and apply_rope_with_cos_sin_cache_inplace.
Refactored structure:

# Python API Layer:
flashinfer/rope/
├── rope.py                    # Public API layer with backend dispatch
├── custom_ops.py              # CUDA C++ backend Python bindings (TVM FFI)
├── kernels/                   # CuTe-DSL backend implementation
│   ├── kernels.py             # Core CuTe-DSL kernel implementations
│   ├── wrappers.py            # Python wrapper functions for CuTe-DSL
│   ├── compile.py             # CuTe-DSL kernel compilation and caching
│   ├── ptx_ops.py             # PTX intrinsics (math, memory ops)
│   └── helpers.py             # Utility functions
└── __init__.py                # Module exports

# CUDA C++ Backend:
csrc/
└── rope.cu                    # TVM FFI bindings for CUDA RoPE kernels

include/flashinfer/
├── pos_enc.cuh                # Inline RoPE helper functions for fused attention kernels
└── rope/                      # Standalone RoPE CUDA kernel implementations
    ├── types.cuh              # Type definitions and enums
    ├── kernels.cuh            # Core CUDA kernel implementations
    ├── launchers.cuh          # Kernel launcher functions with dispatch logic
    └── pos_enc_kernels.cuh    # Positional encoding kernel utilities
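The backend dispatch in rope.py described above can be modeled with a small sketch. The function and module names below follow the directory layout in this description, but the returned strings merely stand in for the real kernel calls (TVM FFI bindings on the CUDA path, compiled CuTe-DSL kernels on the other):

```python
def select_rope_backend(backend: str) -> str:
    """Toy model of the dispatch in rope.py: route to the CUDA C++
    bindings (custom_ops.py) or the CuTe-DSL wrappers
    (kernels/wrappers.py). Unknown backends raise instead of silently
    falling back to CUDA."""
    if backend == "cute-dsl":
        return "kernels.wrappers"  # CuTe-DSL path
    if backend == "cuda":
        return "custom_ops"        # default CUDA C++ path
    raise ValueError(f"Unsupported backend: {backend!r}")
```

The explicit ValueError mirrors the reviewer suggestion later in this thread that unrecognized backend strings should fail loudly rather than fall through to CUDA.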

🚗 Performance Analysis

Overall summary: solid performance improvements for decode cases, especially the llama31-related functions. Some problem sizes are slightly slower; these can be investigated later (the default backend remains cuda for now).

Example benchmarking summary for B200 (benchmarking matrix averaged for MHA, MLA, GQA configurations, decode batch 1 - 2048, prefill seq 1k - 32k):

CuTe-DSL speedups relative to CUDA: 

apply_rope: min=1.00x, max=1.78x, avg=1.23x
apply_rope_inplace: min=0.97x, max=1.69x, avg=1.18x
apply_rope_pos_ids: min=0.91x, max=1.53x, avg=1.27x
apply_rope_pos_ids_inplace: min=0.94x, max=1.41x, avg=1.16x
apply_llama31_rope: min=0.86x, max=1.78x, avg=1.23x
apply_llama31_rope_inplace: min=0.97x, max=1.69x, avg=1.18x
apply_llama31_rope_pos_ids: min=0.91x, max=15.00x, avg=3.83x
apply_llama31_rope_pos_ids_inplace: min=0.94x, max=5.15x, avg=1.89x
apply_rope_with_cos_sin_cache: min=0.90x, max=1.25x, avg=1.09x
apply_rope_with_cos_sin_cache_inplace: min=0.94x, max=1.18x, avg=1.04x
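The min/max/avg figures above can be reproduced from per-workload timings with a small helper (a sketch only; the actual aggregation lives in benchmarks/bench_rope_workloads.py and may differ):

```python
def summarize_speedups(cuda_ms, cute_ms):
    """Per-workload speedup of CuTe-DSL over CUDA, summarized as
    (min, max, avg). Zero or negative CuTe-DSL timings are skipped so the
    division is always safe."""
    ratios = [c / d for c, d in zip(cuda_ms, cute_ms) if d > 0]
    return min(ratios), max(ratios), sum(ratios) / len(ratios)
```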

B300 benchmarking
B200 benchmarking
H100 benchmarking
A100 benchmarking

For source script, see benchmarks/bench_rope_workloads.py.

✅ Testing

pytest -x tests/attention/test_rope.py

====================================================================== 53290 passed in 490.33s (0:08:10) =======================================================================

🔍 Related Issues

#2491

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features
    • Added CuTe-DSL backend support for RoPE operations via new backend parameter ("cuda" or "cute-dsl") across all RoPE functions.
    • Created dedicated flashinfer.rope module with comprehensive RoPE API consolidation.
    • Added benchmark tooling for comparing backend performance.
    • Extended RoPE support for cos/sin cache variants and quantization workflows.

@kahyunnam kahyunnam marked this pull request as draft February 3, 2026 01:03
@coderabbitai
Contributor

coderabbitai bot commented Feb 3, 2026

📝 Walkthrough

This PR introduces comprehensive CuTe-DSL backend support for RoPE (Rotary Positional Embeddings) operations, adding a new backend parameter to multiple RoPE-related functions and implementing kernel variants for the CuTe-DSL path alongside existing CUDA kernels, along with a new benchmark suite.

Changes

Cohort / File(s) Summary
RoPE Public API Layer
flashinfer/rope.py, flashinfer/rope/__init__.py, flashinfer/rope/utils.py
Introduces public-facing RoPE functions with backend parameter supporting "cuda" and "cute-dsl"; routes operations to appropriate implementations; includes availability probes and module loading utilities.
RoPE Custom Ops Registration
flashinfer/rope/custom_ops.py
Registers PyTorch custom ops for RoPE variants with fake op counterparts; covers standard RoPE, Llama 3.1 RoPE, position-ID variants, cos-sin cache variants, and quantization operations.
CuTe-DSL Kernel Infrastructure
flashinfer/rope/kernels/__init__.py, flashinfer/rope/kernels/compile.py, flashinfer/rope/kernels/kernels.py, flashinfer/rope/kernels/ptx_ops.py, flashinfer/rope/kernels/helpers.py, flashinfer/rope/kernels/wrappers.py
Comprehensive CuTe-DSL RoPE kernel implementation stack including low-level PTX intrinsics, kernel class definitions with interleaved/non-interleaved variants, cached JIT compilation, and high-level wrappers with kernel selection logic.
CUDA RoPE Kernels and Launchers
include/flashinfer/rope/kernels.cuh, include/flashinfer/rope/launchers.cuh, include/flashinfer/rope/types.cuh, include/flashinfer/rope/pos_enc_kernels.cuh
New C++ CUDA headers providing kernel implementations for various RoPE strategies (head-parallel, sequential-heads, cos-sin cache), quantization paths, paged KV cache integration, and host-side launchers; includes type definitions for complex parameter structs.
Refactored Positional Encoding
include/flashinfer/pos_enc.cuh, csrc/rope.cu
Removes large kernel declaration surface area and replaces with inline helpers (vec_apply_llama_rope variants, scale_store_partial_chunk); adds PosEncodingModeToString and get_alibi_slope utilities; updates include directive.
Testing
tests/attention/test_rope.py
Extends existing RoPE tests with backend parameter and conditional skipping for unavailable CuTe-DSL backend.
Benchmarking
benchmarks/bench_rope_workloads.py
New comprehensive benchmark script with configurable workloads, dimension sweeps, and per-backend timing/speedup reporting for RoPE operations across CUDA and CuTe-DSL backends.

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs

Suggested labels

benchmark, feature, cuda, kernel

Suggested reviewers

  • kaixih
  • aleozlx
  • yzh119
  • cyx-6
  • bkryu
  • nvmbreughe
  • jimmyzho

Poem

🐰 A rope of light, now spins both ways,
CuTe-DSL and CUDA's praise!
Kernels cascade in PTX glow,
Benchmarks dance, watch speedups flow! 🚀

🚥 Pre-merge checks | ✅ 2 | ❌ 2
❌ Failed checks (2 warnings)
  • Docstring Coverage ⚠️ Warning — docstring coverage is 60.65%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.
  • Merge Conflict Detection ⚠️ Warning — ❌ merge conflicts detected (15 files):

⚔️ csrc/flashinfer_sampling_binding.cu (content)
⚔️ csrc/rope.cu (content)
⚔️ csrc/sampling.cu (content)
⚔️ flashinfer/aot.py (content)
⚔️ flashinfer/gemm/gemm_base.py (content)
⚔️ flashinfer/rope.py (content)
⚔️ flashinfer/sampling.py (content)
⚔️ flashinfer/triton/__init__.py (content)
⚔️ flashinfer/utils.py (content)
⚔️ include/flashinfer/pos_enc.cuh (content)
⚔️ include/flashinfer/sampling.cuh (content)
⚔️ scripts/task_run_unit_tests.sh (content)
⚔️ scripts/test_utils.sh (content)
⚔️ tests/attention/test_rope.py (content)
⚔️ tests/gemm/test_bmm_fp8.py (content)

These conflicts must be resolved before merging into main.
Resolve conflicts locally and push changes to this branch.
✅ Passed checks (2 passed)
  • Title check ✅ Passed — the title clearly and concisely summarizes the main objective: refactoring the CUDA backend and adding a CuTe-DSL backend for unfused RoPE APIs.
  • Description check ✅ Passed — the description covers objectives, related issues, performance analysis, testing results, and pre-commit checklist status.


@gemini-code-assist
Contributor

Summary of Changes

Hello @kahyunnam, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a new, highly optimized backend for Rotary Positional Embeddings (RoPE) leveraging the CuTe-DSL framework. This refactoring provides an alternative implementation to the existing CUDA C++ kernels, supporting both interleaved and non-interleaved RoPE styles, along with Llama 3.1 frequency scaling. The changes integrate this new backend into the public API, allowing users to select their preferred implementation, and include comprehensive unit tests to ensure correctness.

Highlights

  • New CuTe-DSL Backend for RoPE: A complete implementation of Rotary Positional Embeddings (RoPE) using the CuTe-DSL framework has been introduced, offering an alternative to the existing CUDA C++ backend.
  • Support for Interleaved and Non-Interleaved RoPE: The new CuTe-DSL kernels are optimized to support both GPT-J (interleaved) and NeoX (non-interleaved) RoPE styles, utilizing 128-bit vectorized memory access for enhanced efficiency.
  • Llama 3.1 Frequency Scaling: The CuTe-DSL implementation incorporates support for Llama 3.1's advanced frequency scaling, enabling smooth interpolation between scaled and unscaled frequencies.
  • Backend Selection in Public API: Public API functions within the flashinfer.rope module (e.g., apply_rope, apply_llama31_rope) now include a backend parameter, allowing users to explicitly choose between the 'cuda' (default) and 'cute-dsl' implementations.
  • Comprehensive PTX Intrinsics: The CuTe-DSL implementation leverages direct PTX intrinsics for approximate trigonometric functions (sin/cos), vectorized memory operations, and half2/bfloat2 conversions, providing fine-grained control and performance.
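The two RoPE styles named above differ only in how elements are paired before rotation; the per-pair math is identical. A minimal sketch of that math (standard RoPE inverse-frequency formula; not taken from the PR's kernels):

```python
import math

def rope_rotate_pair(x1, x2, pos, pair_idx, rotary_dim, theta=1e4):
    """Rotate one element pair by angle pos * theta**(-2*pair_idx/rotary_dim).
    NeoX (non-interleaved) style pairs (x[i], x[i + rotary_dim // 2]);
    GPT-J (interleaved) style pairs (x[2i], x[2i + 1]). Only the pairing
    differs -- the 2x2 rotation applied to each pair is the same."""
    inv_freq = theta ** (-2.0 * pair_idx / rotary_dim)
    c, s = math.cos(pos * inv_freq), math.sin(pos * inv_freq)
    return x1 * c - x2 * s, x1 * s + x2 * c
```

Since the rotation is orthogonal, it preserves the magnitude of each pair, which is a handy sanity check when validating a new backend against the reference.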

Changelog
  • flashinfer/cute_dsl/rope.py
    • Added a new module for CuTe-DSL based RoPE kernels.
    • Implemented PTX intrinsics for approximate math functions (e.g., sin_approx, cos_approx, sincos_approx).
    • Implemented PTX intrinsics for 128-bit vectorized global memory loads/stores.
    • Implemented PTX intrinsics for half2/bfloat2 to float2 conversions and vice-versa.
    • Defined RopeKernelNonInterleavedVec and RopeKernelInterleavedVec classes for NeoX and GPT-J style RoPE, respectively, with vectorized optimizations.
    • Included a kernel caching mechanism (_get_compiled_kernel) for efficient kernel compilation.
    • Exposed public API functions (apply_rope_cute_dsl, apply_rope_with_indptr_cute_dsl, apply_llama31_rope_with_indptr_cute_dsl) for applying RoPE with CuTe-DSL.
    • Added a helper function _compute_pos_ids_from_indptr_offsets to derive position IDs from indptr and offsets.
  • flashinfer/rope.py
    • Added Literal import from typing.
    • Introduced _is_cute_dsl_available() and _get_cute_dsl_rope_kernel() helper functions.
    • Modified apply_rope, apply_rope_pos_ids, apply_llama31_rope, and apply_llama31_rope_pos_ids functions to accept a new backend parameter ('cuda' or 'cute-dsl').
    • Updated the implementation of these functions to dispatch to the CuTe-DSL backend if specified and available.
    • Updated docstrings for the modified functions to reflect the new backend parameter and clarify input tensor descriptions.
  • tests/attention/test_rope.py
    • Imported is_cute_dsl_available from flashinfer.cute_dsl.utils.
    • Added backend parameter to test_rope and test_rope_pos_ids pytest fixtures.
    • Implemented conditional skipping for CuTe-DSL tests if the backend is not available or if inplace operations are attempted (as CuTe-DSL currently doesn't support them).
    • Passed the backend argument to the flashinfer.apply_rope and flashinfer.apply_llama31_rope calls in the tests.
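The _compute_pos_ids_from_indptr_offsets helper mentioned in the changelog can be modeled in plain Python (a sketch of the semantics only; the real helper operates on CUDA tensors, and a vectorized torch version is suggested in a review comment further down):

```python
def compute_pos_ids(indptr, offsets):
    """For each sequence i spanning indptr[i]:indptr[i+1] in the ragged
    batch, emit positions offsets[i], offsets[i] + 1, ... for its tokens."""
    pos_ids = []
    for i in range(len(indptr) - 1):
        seq_len = indptr[i + 1] - indptr[i]
        pos_ids.extend(range(offsets[i], offsets[i] + seq_len))
    return pos_ids
```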
Activity
  • The pull request is currently a draft and marked as Work In Progress (WIP).
  • The author, kahyunnam, has stated the need to clean up code and benchmark.
  • The PR was opened as a draft to track merge conflicts with ongoing CuTe-DSL refactoring.
  • There is no explicit human activity (comments, reviews) beyond the author's initial submission and description in the provided context.

@kahyunnam kahyunnam self-assigned this Feb 3, 2026
@kahyunnam kahyunnam changed the title [wip] Rope kernels: refactor cuda backend to cute dsl [wip] RoPE kernels: refactor cuda backend to cute dsl Feb 3, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new CuTe-DSL backend for RoPE kernels as an alternative to the existing CUDA C++ implementation. The changes include a new flashinfer/cute_dsl/rope.py file with the kernel implementations and modifications to flashinfer/rope.py to add a backend switch. The tests are also updated to cover the new backend.

My review focuses on the new CuTe-DSL implementation. I've identified opportunities for performance optimization and code refactoring to improve maintainability. Specifically, I've suggested using a more efficient PTX instruction for sincos, vectorizing a CPU-bound helper function, and refactoring duplicated code within the main kernel loops.

Comment on lines +120 to +154
def sincos_approx(x: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
    """
    Compute approximate sin and cos simultaneously using PTX.

    This is more efficient than calling sin_approx and cos_approx separately
    as it computes both values in a single operation.

    Returns (sin_val, cos_val).
    """
    # PTX doesn't have a combined sincos.approx, so we compute separately
    # but this function serves as documentation that both are needed
    # and allows future optimization if PTX adds such instruction
    sin_val = Float32(
        llvm.inline_asm(
            T.f32(),
            [Float32(x).ir_value(loc=loc, ip=ip)],
            "sin.approx.f32 $0, $1;",
            "=f,f",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )
    cos_val = Float32(
        llvm.inline_asm(
            T.f32(),
            [Float32(x).ir_value(loc=loc, ip=ip)],
            "cos.approx.f32 $0, $1;",
            "=f,f",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )
    return sin_val, cos_val

medium

The comment on line 129 is incorrect. PTX ISA has supported a combined sin/cos approximate instruction (sinc.approx.f32) since SM 7.0, which is more efficient than calling sin.approx.f32 and cos.approx.f32 separately. I suggest refactoring this function to use a single llvm.inline_asm call with sinc.approx.f32 for better performance.

def sincos_approx(x: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
    """
    Compute approximate sin and cos simultaneously using PTX sinc.approx.f32.

    This is more efficient than calling sin_approx and cos_approx separately
    as it computes both values in a single operation.

    Returns (sin_val, cos_val).
    """
    result = llvm.inline_asm(
        llvm.StructType.get_literal([T.f32(), T.f32()]),
        [Float32(x).ir_value(loc=loc, ip=ip)],
        """{
    .reg .v2 .f32 sincos;
    sinc.approx.f32 sincos, $2;
    mov.b32 $0, sincos.x;
    mov.b32 $1, sincos.y;
}""",
        "=f,=f,f",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
        loc=loc,
        ip=ip,
    )
    sin_val = llvm.extractvalue(T.f32(), result, [0], loc=loc, ip=ip)
    cos_val = llvm.extractvalue(T.f32(), result, [1], loc=loc, ip=ip)
    return Float32(sin_val), Float32(cos_val)


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/rope.py (1)

771-905: ⚠️ Potential issue | 🟡 Minor

Validate backend values to avoid silent CUDA fallback.

If backend is anything other than "cuda" or "cute-dsl", the code currently falls through to CUDA. Please raise a clear error and apply the same guard in the other backend‑parameterized APIs.

Suggested guard for apply_rope (mirror in other APIs)
     if rotary_dim is None:
         rotary_dim = q.size(-1)
 
     if backend == "cute-dsl":
         if not _is_cute_dsl_available():
             raise RuntimeError(
                 "CuTe-DSL backend is not available. Please install CuTe-DSL."
             )
         from .cute_dsl.rope import apply_rope_with_indptr_cute_dsl
 
         return apply_rope_with_indptr_cute_dsl(
             q,
             k,
             indptr,
             offsets,
             rotary_dim=rotary_dim,
             interleave=interleave,
             rope_scale=rope_scale,
             rope_theta=rope_theta,
         )
 
-    # Default: CUDA C++ backend
+    if backend != "cuda":
+        raise ValueError(f"Unsupported backend: {backend}")
+    # Default: CUDA C++ backend
     _apply_rope(
         q,
         k,
         q_rope,
         k_rope,
         indptr,
         offsets,
         rotary_dim,
         interleave,
         rope_scale,
         rope_theta,
     )
🤖 Fix all issues with AI agents
In `@flashinfer/cute_dsl/rope.py`:
- Around line 1566-1569: The __all__ export list is unsorted and triggers Ruff
RUF022; reorder the entries in the __all__ list alphabetically so
"RopeKernelInterleavedVec" appears before "RopeKernelNonInterleavedVec" (i.e.,
sort the list containing RopeKernelNonInterleavedVec and
RopeKernelInterleavedVec) to satisfy the linter.
- Around line 1322-1330: Add explicit vector-alignment assertions for head_dim
and rotary_dim to match the kernel's 8-element vector and single-pair thread
assumptions: ensure head_dim is a multiple of 8 (replace or augment the existing
head_dim % 2 check) and, depending on the layout mode, ensure rotary_dim is
aligned so the kernel won't straddle halves — specifically assert rotary_dim % 8
== 0 when using interleaved vectors, and assert (rotary_dim // 2) % 8 == 0 in
the non-interleaved/half-pair case; update the validation near the existing
rotary_dim/head_dim checks (the block referencing rotary_dim and head_dim) and
include clear assertion messages referencing rotary_dim and head_dim.

In `@tests/attention/test_rope.py`:
- Around line 47-53: The cute-dsl test gating only checks availability and
inplace support but misses GPU compute-capability guards; update the two
CuTe-DSL skip blocks (the branches checking backend == "cute-dsl" that call
is_cute_dsl_available() and pytest.skip()) to also call the appropriate
flashinfer.utils helpers (e.g., get_compute_capability() and the relevant
predicate like is_sm90a_supported() or a project-specific is_sm_supported(...)
helper) and skip with pytest.skip(...) when the current GPU SM is unsupported;
ensure you import the helpers from flashinfer.utils at the top of the test file
and apply the same change to the other CuTe-DSL skip site mentioned in the
review.
🧹 Nitpick comments (1)
flashinfer/cute_dsl/rope.py (1)

1435-1445: Avoid per‑sequence GPU syncs when building pos_ids.

The Python loop with .item() on CUDA tensors forces a sync per sequence. A vectorized construction keeps this on‑GPU and scales better for large batches.

Vectorized alternative
-    # Create output tensor
-    pos_ids = torch.empty(nnz, dtype=torch.int32, device=device)
-
-    # For each sequence, compute positions
-    # This is a simple CPU loop - could be optimized with a Triton kernel if needed
-    for i in range(batch_size):
-        start = indptr[i].item()
-        end = indptr[i + 1].item()
-        offset = offsets[i].item()
-        seq_len = end - start
-        if seq_len > 0:
-            pos_ids[start:end] = (
-                torch.arange(seq_len, dtype=torch.int32, device=device) + offset
-            )
+    lengths = (indptr[1:] - indptr[:-1]).to(torch.int32)
+    seq_ids = torch.repeat_interleave(
+        torch.arange(batch_size, device=device, dtype=torch.int32), lengths
+    )
+    base = indptr[:-1].to(torch.int32)
+    pos_ids = (
+        torch.arange(nnz, device=device, dtype=torch.int32)
+        - base[seq_ids]
+        + offsets.to(torch.int32)[seq_ids]
+    )

Comment on lines +47 to +53
# Skip cute-dsl backend if not available or for inplace operations
if backend == "cute-dsl":
    if not is_cute_dsl_available():
        pytest.skip("CuTe-DSL not available")
    if inplace:
        pytest.skip("CuTe-DSL backend does not support inplace operations")


⚠️ Potential issue | 🟠 Major

Add compute-capability skips for CuTe‑DSL tests.

The cute‑dsl parametrization only checks availability/inplace; on unsupported SMs these tests can still run and fail. Please gate cute‑dsl runs with flashinfer.utils compute‑capability helpers in both skip blocks.

Suggested guard (adjust helper to the actual CuTe‑DSL support matrix)
 import flashinfer
 from flashinfer.cute_dsl.utils import is_cute_dsl_available
+from flashinfer.utils import get_compute_capability, is_sm90a_supported
 ...
     if backend == "cute-dsl":
         if not is_cute_dsl_available():
             pytest.skip("CuTe-DSL not available")
+        cc = get_compute_capability()
+        if not is_sm90a_supported(cc):
+            pytest.skip(f"CuTe-DSL backend not supported on {cc}")
         if inplace:
             pytest.skip("CuTe-DSL backend does not support inplace operations")

As per coding guidelines: tests/**/*.py: Test files must use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, etc.) to skip tests on unsupported GPU architectures.

Also applies to: 177-183


@kahyunnam kahyunnam changed the title [wip] RoPE kernels: refactor cuda backend to cute dsl RoPE kernels: refactor CUDA backend, add CuTe-DSL backend for unfused RoPE Feb 12, 2026
@kahyunnam kahyunnam changed the title RoPE kernels: refactor CUDA backend, add CuTe-DSL backend for unfused RoPE RoPE kernels: refactor CUDA backend, add CuTe-DSL backend for unfused RoPE APIs Feb 12, 2026
@kahyunnam kahyunnam changed the title RoPE kernels: refactor CUDA backend, add CuTe-DSL backend for unfused RoPE APIs [flashinfer.rope] refactor CUDA backend, add CuTe-DSL backend for unfused RoPE APIs Feb 12, 2026
@kahyunnam
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !315 has been created, and the CI pipeline #43935806 is currently running. I'll report back once the pipeline job completes.

@kahyunnam kahyunnam marked this pull request as ready for review February 13, 2026 17:59
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 11

🤖 Fix all issues with AI agents
In `@benchmarks/bench_rope_workloads.py`:
- Around line 544-546: The calculation speedup = cuda_ms / cute_ms can raise
ZeroDivisionError if cute_ms is 0; modify the code around where speedup is
computed (variables cuda_ms and cute_ms) to guard against zero or near-zero
cute_ms (e.g., check if cute_ms == 0 or abs(cute_ms) < epsilon) and handle it
deterministically (return float('inf'), None, or a default value and/or log a
warning) so division is never attempted when cute_ms is zero; update any
downstream uses of speedup to account for the chosen sentinel.

In `@flashinfer/rope.py`:
- Around line 874-892: The backend switch silently defaults to the CUDA C++
backend when an unknown backend string is passed; update the backend-selection
logic (the branch handling backend, including the "cute-dsl" case and the
default CUDA path) to explicitly validate supported values and raise a
ValueError for unsupported strings instead of falling through; for example,
check backend against the allowed set (e.g., "cute-dsl", "cuda" or whichever
exact names your APIs expect) before importing/dispatching and raise
ValueError("Unsupported backend: {backend}") if it’s not recognized; apply the
same explicit validation pattern to the other backend-switching APIs mentioned
(the branches around apply_rope_with_indptr_cute_dsl, and the other blocks at
the indicated ranges).

In `@flashinfer/rope/__init__.py`:
- Around line 1-105: Duplicate module conflict: remove the old top-level module
file flashinfer/rope.py so Python imports the new package flashinfer.rope (the
package's __init__.py imports symbols from .rope like apply_rope,
apply_rope_pos_ids, apply_llama31_rope, rope_quantize_fp8, etc.). Delete
flashinfer/rope.py from the repo and CI, ensure no other files import the old
top-level module path, and run the typecheck/test suite to confirm the Duplicate
module error is resolved.

In `@flashinfer/rope/custom_ops.py`:
- Around line 268-279: The fake op _fake_apply_rope_pos_ids_cos_sin_cache has
the wrong parameter list (it currently takes cos_cache and sin_cache
separately); change its signature to match the real op by replacing the two
params with a single cos_sin_cache parameter so it accepts (q, k, q_rope,
k_rope, cos_sin_cache, pos_ids, interleave), update any internal references to
use cos_sin_cache, and ensure the `@register_fake_op` name remains
"flashinfer::apply_rope_pos_ids_cos_sin_cache" to match the real implementation.
- Around line 287-323: The Python wrapper _rope_quantize has the parameters
cos_sin_cache and pos_ids placed before the output tensors, but the underlying
C++ expects those two after the outputs; update the function signature of
_rope_quantize to list q_rope_out, k_rope_out, q_nope_out, k_nope_out before
cos_sin_cache and pos_ids (so the order is q_rope_in, k_rope_in, q_nope_in,
k_nope_in, q_rope_out, k_rope_out, q_nope_out, k_nope_out, cos_sin_cache,
pos_ids, quant_scale_q, quant_scale_kv, interleave, enable_pdl) and then adjust
the call to get_rope_module().rope_quantize to use matching parameter order (you
can then remove the current manual reordering in that call); this change affects
the _rope_quantize function signature and its invocation to match the C++
argument order.

In `@flashinfer/rope/kernels/kernels.py`:
- Around line 71-88: The constructor (__init__) assumes 8-element vector
alignment (elems_per_thread=8) and half-rotary pairs, which breaks when head_dim
% 8 != 0 or rotary_dim % 16 != 0 and can make bdx==0 leading to bdy
divide-by-zero; add explicit validation at the start of __init__ to raise a
clear ValueError if head_dim % 8 != 0 or rotary_dim % 16 != 0, and guard the
bdx/bdy calculation by ensuring bdx = max(1, head_dim // self.elems_per_thread)
and computing num_threads and bdy accordingly (e.g., num_threads = max(128, bdx)
then bdy = max(1, num_threads // bdx)); apply the same validation/guards in the
other non-interleaved kernel variants that use elems_per_thread, bdx,
num_threads, bdy, head_dim, rotary_dim, and half_rotary.

In `@flashinfer/rope/kernels/wrappers.py`:
- Around line 147-169: Add vector-alignment and rotary-dim guards to the
CuTe-DSL wrappers: in apply_rope_cute_dsl and apply_rope_with_indptr_cute_dsl
validate that head_dim is 8-element aligned (e.g., assert head_dim % 8 == 0) and
preserve the existing rotary_dim checks by asserting rotary_dim <= head_dim and
rotary_dim % 2 == 0; additionally, when running in non-interleaved mode enforce
rotary_dim is divisible by 16 (e.g., assert rotary_dim % 16 == 0 when
interleaving flag indicates non-interleaved). Mirror these checks in both the
indptr/cos-sin-cache wrappers so both apply_rope_cute_dsl and
apply_rope_with_indptr_cute_dsl perform the same alignment and rotary-dimension
validation.
- Around line 23-34: The exported CuTe-DSL wrapper functions in this module (the
functions listed in __all__ and the wrapper functions around the compiled
kernels) are missing the `@flashinfer_api` decorator; import flashinfer_api and
add `@flashinfer_api` above each exported wrapper function (including the
functions in the ranges mentioned: 90-103, 312-326, 455-469, 495-503) so they
participate in API logging and crash-safe input capture; ensure you place the
decorator directly above the wrapper function definitions that call the compiled
helpers (_get_compiled_kernel, _get_compiled_kernel_seq_heads,
_get_compiled_kernel_with_indptr, _get_compiled_cos_sin_cache_kernel,
_get_compiled_cos_sin_cache_kernel_seq_heads) and that the import for
flashinfer_api is added to the top-level imports.

In `@flashinfer/rope/rope.py`:
- Around line 402-406: The See Also cross-reference is incorrect: update the
reference in the docstring of apply_rope_pos_ids to point to its true inplace
counterpart apply_rope_pos_ids_inplace (replace apply_rope_inplace with
apply_rope_pos_ids_inplace) so the documentation correctly links
apply_rope_pos_ids to apply_rope_pos_ids_inplace.

In `@include/flashinfer/rope/kernels.cuh`:
- Around line 795-824: RopeQuantizeAppendPagedKVCacheKernel currently does
unchecked cast_load/cast_store in the K-nope and Q-nope paths (uses k_nope_in,
elem_offset, tx, vec_size, cast_load/cast_store and paged_kv_like.get_ckv_ptr /
get_k_ptr), which can read/write past no_rope_dim when no_rope_dim %
rope_chunk_size != 0; change these sections to mirror the non-paged kernel by
computing the per-chunk valid count (chunk_valid = max(0, min(vec_size,
no_rope_dim - (elem_offset + tx*vec_size)))) and use the existing partial-store
helper (scale_store_partial_chunk) or otherwise guard loads/stores by
chunk_valid so that only valid elements are read/written; keep the same
quant_scale_kv scaling before storing and apply the check for both K-nope and
Q-nope branches.

In `@include/flashinfer/rope/launchers.cuh`:
- Around line 675-679: The computation of smooth_a and smooth_b can divide by
zero when high_freq_factor == low_freq_factor; update both Llama‑3.1 launcher
functions to guard this by checking if fabsf(high_freq_factor - low_freq_factor)
< 1e-6f (or similar small epsilon) and in that case set smooth_a = 0.0f and
smooth_b = 0.0f (matching the CuTe‑DSL path); otherwise compute smooth_a and
smooth_b with the existing formulas using high_freq_factor and low_freq_factor.
Ensure you apply the same guard for the other occurrence of smooth_a/smooth_b
(the lines around 740-743) so no division by zero occurs.
🧹 Nitpick comments (3)
benchmarks/bench_rope_workloads.py (1)

195-505: Consider reducing duplication in benchmark_api.

This ~310-line function consists of a long if/elif chain where each branch defines nearly identical run_cuda/run_cute closures differing only in the API name and a few arguments. A data-driven approach (mapping API name → callable + arg builder) would cut this significantly and make adding new APIs a one-liner.

That said, for a benchmark script, explicitness has value, so this is optional.
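One possible shape for such a table-driven dispatch (the registry entries below are toy stand-ins, not the benchmark's actual APIs or runners):

```python
# Hypothetical registry: API name -> (cuda_impl, cute_impl, argument builder).
# This replaces a long if/elif chain; adding an API is a one-line entry.
REGISTRY = {
    "square": (lambda x: x * x, lambda x: x ** 2, lambda: 3),
    "double": (lambda x: 2 * x, lambda x: x + x, lambda: 5),
}

def benchmark_api(name):
    cuda_impl, cute_impl, build_args = REGISTRY[name]
    arg = build_args()
    # In the real script these closures would be handed to bench_gpu_time.
    run_cuda = lambda: cuda_impl(arg)
    run_cute = lambda: cute_impl(arg)
    return run_cuda(), run_cute()
```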

flashinfer/rope/rope.py (1)

129-154: Invalid backend values are silently ignored, falling through to CUDA.

All backend dispatch points use if backend == "cute-dsl": ... <fallthrough to CUDA>. A typo like backend="cutedsl" or backend="cute_dsl" will silently use the CUDA path instead of raising an error. Consider validating the backend value up-front.

Proposed fix (apply once, e.g., as a helper, or at each entry point)
+_VALID_BACKENDS = {"cuda", "cute-dsl"}
+
+def _validate_backend(backend: str) -> None:
+    if backend not in _VALID_BACKENDS:
+        raise ValueError(
+            f"Unknown backend {backend!r}. Must be one of {_VALID_BACKENDS}."
+        )
+

Then call _validate_backend(backend) at the top of each public function.

include/flashinfer/rope/kernels.cuh (1)

254-257: rope_chunk_size = rope_dim makes rope_chunks always 1 — the chunking machinery is effectively a no-op for the RoPE dimension.

This is clearly intentional scaffolding for future flexibility (e.g., splitting large rope dimensions across blocks). Just noting that rope_chunks = (rope_dim + rope_dim - 1) / rope_dim is always 1, so rope_chunk_idx is always 0 and elem_offset for RoPE sections is always 0. No action needed.

Also applies to: 698-700

Comment on lines +544 to +546
use_nonzero_offsets=use_nonzero_offsets,
)
speedup = cuda_ms / cute_ms

⚠️ Potential issue | 🟡 Minor

Potential division by zero if cute_ms is zero.

speedup = cuda_ms / cute_ms will raise ZeroDivisionError if cute_ms happens to be 0. While unlikely with real GPU timings, a guard would be cheap insurance.

Proposed fix
-            speedup = cuda_ms / cute_ms
+            speedup = cuda_ms / cute_ms if cute_ms > 0 else float("inf")
🤖 Prompt for AI Agents
In `@benchmarks/bench_rope_workloads.py` around lines 544 - 546, The calculation
speedup = cuda_ms / cute_ms can raise ZeroDivisionError if cute_ms is 0; modify
the code around where speedup is computed (variables cuda_ms and cute_ms) to
guard against zero or near-zero cute_ms (e.g., check if cute_ms == 0 or
abs(cute_ms) < epsilon) and handle it deterministically (return float('inf'),
None, or a default value and/or log a warning) so division is never attempted
when cute_ms is zero; update any downstream uses of speedup to account for the
chosen sentinel.

Comment on lines +874 to +892
if backend == "cute-dsl":
if not _is_cute_dsl_available():
raise RuntimeError(
"CuTe-DSL backend is not available. Please install CuTe-DSL."
)
from .cute_dsl.rope import apply_rope_with_indptr_cute_dsl

return apply_rope_with_indptr_cute_dsl(
q,
k,
indptr,
offsets,
rotary_dim=rotary_dim,
interleave=interleave,
rope_scale=rope_scale,
rope_theta=rope_theta,
)

# Default: CUDA C++ backend

⚠️ Potential issue | 🟡 Minor

Validate unsupported backend values explicitly.
Unknown backend strings currently fall through to CUDA silently, which can hide configuration errors. Consider raising a ValueError in each backend‑switching API.

✅ Proposed pattern (apply to each backend‑switching API)
@@
     if backend == "cute-dsl":
         if not _is_cute_dsl_available():
             raise RuntimeError(
                 "CuTe-DSL backend is not available. Please install CuTe-DSL."
             )
         from .cute_dsl.rope import apply_rope_with_indptr_cute_dsl
@@
         return apply_rope_with_indptr_cute_dsl(
             q,
             k,
             indptr,
             offsets,
             rotary_dim=rotary_dim,
             interleave=interleave,
             rope_scale=rope_scale,
             rope_theta=rope_theta,
         )
+
+    if backend != "cuda":
+        raise ValueError(f"Unsupported backend: {backend}")

Also applies to: 969-986, 1105-1126, 1214-1234

🧰 Tools
🪛 Ruff (0.15.0)

[warning] 876-878: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In `@flashinfer/rope.py` around lines 874 - 892, The backend switch silently
defaults to the CUDA C++ backend when an unknown backend string is passed;
update the backend-selection logic (the branch handling backend, including the
"cute-dsl" case and the default CUDA path) to explicitly validate supported
values and raise a ValueError for unsupported strings instead of falling
through; for example, check backend against the allowed set (e.g., "cute-dsl",
"cuda" or whichever exact names your APIs expect) before importing/dispatching
and raise ValueError("Unsupported backend: {backend}") if it’s not recognized;
apply the same explicit validation pattern to the other backend-switching APIs
mentioned (the branches around apply_rope_with_indptr_cute_dsl, and the other
blocks at the indicated ranges).

Comment on lines +1 to +105
"""
Copyright (c) 2024-2026 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

FlashInfer RoPE (Rotary Positional Embeddings) Module
=====================================================

This module provides efficient implementations of RoPE for LLM inference.
It supports both CUDA C++ and CuTe-DSL backends.

Public APIs
-----------

Standard RoPE:
apply_rope : Apply RoPE using indptr/offsets (batched sequences)
apply_rope_inplace : Apply RoPE inplace using indptr/offsets
apply_rope_pos_ids : Apply RoPE using explicit position IDs
apply_rope_pos_ids_inplace : Apply RoPE inplace using position IDs

Llama 3.1 Style RoPE:
apply_llama31_rope : Apply Llama 3.1 RoPE with adaptive frequency scaling
apply_llama31_rope_inplace : Apply Llama 3.1 RoPE inplace
apply_llama31_rope_pos_ids : Apply Llama 3.1 RoPE using position IDs
apply_llama31_rope_pos_ids_inplace : Apply Llama 3.1 RoPE inplace using position IDs

RoPE with Precomputed cos/sin Cache (vLLM/SGLang compatible):
apply_rope_with_cos_sin_cache : Apply RoPE with precomputed cos/sin
apply_rope_with_cos_sin_cache_inplace : Apply RoPE with cos/sin cache inplace

Combined RoPE + Quantize Operations:
rope_quantize_fp8 : Apply RoPE and quantize to FP8
mla_rope_quantize_fp8 : Alias for rope_quantize_fp8
rope_quantize_fp8_append_paged_kv_cache : RoPE + quantize + append to paged cache

Backend Support
---------------
All APIs support a ``backend`` parameter:
- ``"cuda"`` (default): CUDA C++ backend with JIT compilation
- ``"cute-dsl"``: CuTe-DSL Python-based backend (requires CuTe-DSL installation)

Example
-------
>>> import torch
>>> import flashinfer
>>>
>>> # Basic RoPE with position IDs
>>> q = torch.randn(1024, 32, 128, dtype=torch.float16, device="cuda")
>>> k = torch.randn(1024, 8, 128, dtype=torch.float16, device="cuda")
>>> pos_ids = torch.arange(1024, dtype=torch.int32, device="cuda")
>>>
>>> q_rope, k_rope = flashinfer.apply_rope_pos_ids(q, k, pos_ids)
"""

from .rope import (
# Standard RoPE with indptr/offsets
apply_rope,
apply_rope_inplace,
# Standard RoPE with position IDs
apply_rope_pos_ids,
apply_rope_pos_ids_inplace,
# Llama 3.1 style RoPE with indptr/offsets
apply_llama31_rope,
apply_llama31_rope_inplace,
# Llama 3.1 style RoPE with position IDs
apply_llama31_rope_pos_ids,
apply_llama31_rope_pos_ids_inplace,
# RoPE with cos/sin cache
apply_rope_with_cos_sin_cache,
apply_rope_with_cos_sin_cache_inplace,
# RoPE + Quantize
rope_quantize_fp8,
mla_rope_quantize_fp8,
rope_quantize_fp8_append_paged_kv_cache,
)

__all__ = [
# Standard RoPE
"apply_rope",
"apply_rope_inplace",
"apply_rope_pos_ids",
"apply_rope_pos_ids_inplace",
# Llama 3.1 RoPE
"apply_llama31_rope",
"apply_llama31_rope_inplace",
"apply_llama31_rope_pos_ids",
"apply_llama31_rope_pos_ids_inplace",
# RoPE with cos/sin cache
"apply_rope_with_cos_sin_cache",
"apply_rope_with_cos_sin_cache_inplace",
# RoPE + Quantize
"rope_quantize_fp8",
"mla_rope_quantize_fp8",
"rope_quantize_fp8_append_paged_kv_cache",
]

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify whether both flashinfer/rope.py and flashinfer/rope/__init__.py exist
echo "=== Checking for duplicate rope module ==="
fd -t f 'rope.py' flashinfer/ --max-depth 1
echo "---"
fd -t f '__init__.py' flashinfer/rope/ --max-depth 1
echo "---"
echo "=== All rope-related files ==="
fd 'rope' flashinfer/ --max-depth 2

Repository: flashinfer-ai/flashinfer

Length of output: 279


Critical: Remove duplicate module flashinfer/rope.py — it conflicts with the new package structure.

Both flashinfer/rope.py (old module file) and flashinfer/rope/__init__.py (new package) exist simultaneously. Python cannot resolve both a module file and a package directory with the same name, which causes the Mypy pipeline error: Duplicate module named "flashinfer.rope" (also at "flashinfer/rope.py").

Delete the old flashinfer/rope.py file — it has been superseded by the package at flashinfer/rope/ where the implementation now lives in flashinfer/rope/rope.py (which is imported here).

🧰 Tools
🪛 GitHub Actions: pre-commit

[error] 1-3: Mypy error: Duplicate module named "flashinfer.rope" (also at "flashinfer/rope.py"). See mypy docs for mapping-file-paths-to-modules.

🪛 Ruff (0.15.0)

[warning] 87-105: __all__ is not sorted

Apply an isort-style sorting to __all__

(RUF022)

🤖 Prompt for AI Agents
In `@flashinfer/rope/__init__.py` around lines 1 - 105, Duplicate module conflict:
remove the old top-level module file flashinfer/rope.py so Python imports the
new package flashinfer.rope (the package's __init__.py imports symbols from
.rope like apply_rope, apply_rope_pos_ids, apply_llama31_rope,
rope_quantize_fp8, etc.). Delete flashinfer/rope.py from the repo and CI, ensure
no other files import the old top-level module path, and run the typecheck/test
suite to confirm the Duplicate module error is resolved.

Comment on lines +268 to +279
@register_fake_op("flashinfer::apply_rope_pos_ids_cos_sin_cache")
def _fake_apply_rope_pos_ids_cos_sin_cache(
q: torch.Tensor,
k: torch.Tensor,
q_rope: torch.Tensor,
k_rope: torch.Tensor,
cos_cache: torch.Tensor,
sin_cache: torch.Tensor,
pos_ids: torch.Tensor,
interleave: bool,
) -> None:
pass

⚠️ Potential issue | 🔴 Critical

Fake op signature mismatch: _fake_apply_rope_pos_ids_cos_sin_cache has wrong parameters.

The real op (lines 247–265) accepts 7 parameters: (q, k, q_rope, k_rope, cos_sin_cache, pos_ids, interleave) with a single combined cos_sin_cache tensor. The fake op has 8 parameters, splitting it into cos_cache and sin_cache. This mismatch will cause failures during torch.compile or any tracing that relies on the fake op.

🐛 Proposed fix: align fake op signature with real op
 `@register_fake_op`("flashinfer::apply_rope_pos_ids_cos_sin_cache")
 def _fake_apply_rope_pos_ids_cos_sin_cache(
     q: torch.Tensor,
     k: torch.Tensor,
     q_rope: torch.Tensor,
     k_rope: torch.Tensor,
-    cos_cache: torch.Tensor,
-    sin_cache: torch.Tensor,
+    cos_sin_cache: torch.Tensor,
     pos_ids: torch.Tensor,
     interleave: bool,
 ) -> None:
     pass
🤖 Prompt for AI Agents
In `@flashinfer/rope/custom_ops.py` around lines 268 - 279, The fake op
_fake_apply_rope_pos_ids_cos_sin_cache has the wrong parameter list (it
currently takes cos_cache and sin_cache separately); change its signature to
match the real op by replacing the two params with a single cos_sin_cache
parameter so it accepts (q, k, q_rope, k_rope, cos_sin_cache, pos_ids,
interleave), update any internal references to use cos_sin_cache, and ensure the
`@register_fake_op` name remains "flashinfer::apply_rope_pos_ids_cos_sin_cache" to
match the real implementation.

Comment on lines +287 to +323
@register_custom_op(
"flashinfer::rope_quantize",
mutates_args=("q_rope_out", "k_rope_out", "q_nope_out", "k_nope_out"),
)
def _rope_quantize(
q_rope_in: torch.Tensor,
k_rope_in: torch.Tensor,
q_nope_in: torch.Tensor,
k_nope_in: torch.Tensor,
cos_sin_cache: torch.Tensor,
pos_ids: torch.Tensor,
q_rope_out: torch.Tensor,
k_rope_out: torch.Tensor,
q_nope_out: torch.Tensor,
k_nope_out: torch.Tensor,
quant_scale_q: float,
quant_scale_kv: float,
interleave: bool,
enable_pdl: bool,
) -> None:
"""Combined RoPE application and FP8 quantization."""
get_rope_module().rope_quantize(
q_rope_in,
k_rope_in,
q_nope_in,
k_nope_in,
q_rope_out,
k_rope_out,
q_nope_out,
k_nope_out,
cos_sin_cache,
pos_ids,
quant_scale_q,
quant_scale_kv,
interleave,
enable_pdl,
)

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Find C++ rope_quantize function signature
echo "=== Searching for rope_quantize in csrc/rope.cu ==="
rg -n 'rope_quantize' csrc/rope.cu -A 10 -B 2

echo ""
echo "=== Searching for rope_quantize in all csrc files ==="
rg -n 'void rope_quantize|rope_quantize.*\(' csrc/ -A 5

echo ""
echo "=== Checking for TVM FFI or registry patterns ==="
rg -n 'TVM_REGISTER|rope_quantize' csrc/ -B 2 -A 5


Length of output: 7371


Fix Python function signature to match the underlying C++ function argument order.

The Python signature declares cos_sin_cache and pos_ids before the output tensors, but the C++ function expects them after. The actual call correctly reorders arguments to match the C++ signature (line 308: outputs before cache/ids), but the Python signature is misleading. Update lines 291–306 to place cos_sin_cache and pos_ids after the output tensors (q_rope_out, k_rope_out, q_nope_out, k_nope_out) to match the C++ signature and remove the need for reordering in the call.

🤖 Prompt for AI Agents
In `@flashinfer/rope/custom_ops.py` around lines 287 - 323, The Python wrapper
_rope_quantize has the parameters cos_sin_cache and pos_ids placed before the
output tensors, but the underlying C++ expects those two after the outputs;
update the function signature of _rope_quantize to list q_rope_out, k_rope_out,
q_nope_out, k_nope_out before cos_sin_cache and pos_ids (so the order is
q_rope_in, k_rope_in, q_nope_in, k_nope_in, q_rope_out, k_rope_out, q_nope_out,
k_nope_out, cos_sin_cache, pos_ids, quant_scale_q, quant_scale_kv, interleave,
enable_pdl) and then adjust the call to get_rope_module().rope_quantize to use
matching parameter order (you can then remove the current manual reordering in
that call); this change affects the _rope_quantize function signature and its
invocation to match the C++ argument order.

Comment on lines +23 to +34
import math
from typing import Optional, Tuple

import torch

from .compile import (
_get_compiled_kernel,
_get_compiled_kernel_seq_heads,
_get_compiled_kernel_with_indptr,
_get_compiled_cos_sin_cache_kernel,
_get_compiled_cos_sin_cache_kernel_seq_heads,
)

⚠️ Potential issue | 🟠 Major

Add @flashinfer_api to exported CuTe‑DSL wrapper functions.
These functions are exported in __all__; they should be logged consistently with the rest of the Python API surface.

✅ Proposed decorator + import
@@
-import torch
+import torch
+
+from ...api_logging import flashinfer_api
@@
-def apply_rope_cute_dsl(
+@flashinfer_api
+def apply_rope_cute_dsl(
@@
-def apply_rope_with_indptr_cute_dsl(
+@flashinfer_api
+def apply_rope_with_indptr_cute_dsl(
@@
-def apply_llama31_rope_with_indptr_cute_dsl(
+@flashinfer_api
+def apply_llama31_rope_with_indptr_cute_dsl(
@@
-def apply_rope_with_cos_sin_cache_cute_dsl(
+@flashinfer_api
+def apply_rope_with_cos_sin_cache_cute_dsl(
As per coding guidelines: "Use `@flashinfer_api` decorator on Python functions for API logging with crash-safe input capture before execution."

Also applies to: 90-103, 312-326, 455-469, 495-503

🤖 Prompt for AI Agents
In `@flashinfer/rope/kernels/wrappers.py` around lines 23 - 34, The exported
CuTe-DSL wrapper functions in this module (the functions listed in __all__ and
the wrapper functions around the compiled kernels) are missing the
`@flashinfer_api` decorator; import flashinfer_api and add `@flashinfer_api` above
each exported wrapper function (including the functions in the ranges mentioned:
90-103, 312-326, 455-469, 495-503) so they participate in API logging and
crash-safe input capture; ensure you place the decorator directly above the
wrapper function definitions that call the compiled helpers
(_get_compiled_kernel, _get_compiled_kernel_seq_heads,
_get_compiled_kernel_with_indptr, _get_compiled_cos_sin_cache_kernel,
_get_compiled_cos_sin_cache_kernel_seq_heads) and that the import for
flashinfer_api is added to the top-level imports.

Comment on lines +147 to +169
# Validate inputs
assert q.ndim == 3, f"q must be 3D, got {q.ndim}D"
assert k.ndim == 3, f"k must be 3D, got {k.ndim}D"
assert q.size(0) == k.size(0), "q and k must have same nnz"
assert q.size(2) == k.size(2), "q and k must have same head_dim"
assert q.is_cuda, "q must be on CUDA"
assert k.is_cuda, "k must be on CUDA"
assert pos_ids.is_cuda, "pos_ids must be on CUDA"

nnz = q.size(0)
num_qo_heads = q.size(1)
num_kv_heads = k.size(1)
head_dim = q.size(2)

if rotary_dim is None:
rotary_dim = head_dim

assert rotary_dim <= head_dim, (
f"rotary_dim must be <= head_dim, got {rotary_dim} > {head_dim}"
)
assert rotary_dim % 2 == 0, "rotary_dim must be even"
assert head_dim % 2 == 0, "head_dim must be even"


⚠️ Potential issue | 🟠 Major

Add vector‑alignment and rotary‑dim guards for CuTe‑DSL kernels.
The vectorized kernels assume 8‑element alignment; non‑interleaved mode also needs rotary_dim divisible by 16 to avoid cross‑half pairing. Also, apply_rope_with_indptr_cute_dsl lacks the rotary_dim <= head_dim / evenness checks present in apply_rope_cute_dsl.

🐛 Proposed guards (apply here and mirror in the indptr/cos‑sin‑cache wrappers)
@@
     assert rotary_dim <= head_dim, (
         f"rotary_dim must be <= head_dim, got {rotary_dim} > {head_dim}"
     )
     assert rotary_dim % 2 == 0, "rotary_dim must be even"
     assert head_dim % 2 == 0, "head_dim must be even"
+    if head_dim % 8 != 0:
+        raise ValueError("head_dim must be a multiple of 8 for CuTe-DSL RoPE kernels")
+    if (not interleave) and (rotary_dim % 16 != 0):
+        raise ValueError("rotary_dim must be a multiple of 16 in non-interleaved mode")
@@
     if rotary_dim is None:
         rotary_dim = head_dim
+    if rotary_dim > head_dim:
+        raise ValueError(f"rotary_dim must be <= head_dim, got {rotary_dim} > {head_dim}")
+    if rotary_dim % 2 != 0 or head_dim % 2 != 0:
+        raise ValueError("rotary_dim and head_dim must be even")
+    if head_dim % 8 != 0:
+        raise ValueError("head_dim must be a multiple of 8 for CuTe-DSL RoPE kernels")
+    if (not interleave) and (rotary_dim % 16 != 0):
+        raise ValueError("rotary_dim must be a multiple of 16 in non-interleaved mode")

Also applies to: 379-410

🤖 Prompt for AI Agents
In `@flashinfer/rope/kernels/wrappers.py` around lines 147 - 169, Add
vector-alignment and rotary-dim guards to the CuTe-DSL wrappers: in
apply_rope_cute_dsl and apply_rope_with_indptr_cute_dsl validate that head_dim
is 8-element aligned (e.g., assert head_dim % 8 == 0) and preserve the existing
rotary_dim checks by asserting rotary_dim <= head_dim and rotary_dim % 2 == 0;
additionally, when running in non-interleaved mode enforce rotary_dim is
divisible by 16 (e.g., assert rotary_dim % 16 == 0 when interleaving flag
indicates non-interleaved). Mirror these checks in both the indptr/cos-sin-cache
wrappers so both apply_rope_cute_dsl and apply_rope_with_indptr_cute_dsl perform
the same alignment and rotary-dimension validation.

Comment on lines +402 to +406

See Also
--------
apply_rope_inplace
"""

⚠️ Potential issue | 🟡 Minor

Incorrect See Also cross-reference.

Line 405 references apply_rope_inplace, but apply_rope_pos_ids's inplace counterpart is apply_rope_pos_ids_inplace.

Fix
     See Also
     --------
-    apply_rope_inplace
+    apply_rope_pos_ids_inplace
🤖 Prompt for AI Agents
In `@flashinfer/rope/rope.py` around lines 402 - 406, The See Also cross-reference
is incorrect: update the reference in the docstring of apply_rope_pos_ids to
point to its true inplace counterpart apply_rope_pos_ids_inplace (replace
apply_rope_inplace with apply_rope_pos_ids_inplace) so the documentation
correctly links apply_rope_pos_ids to apply_rope_pos_ids_inplace.

Comment on lines +795 to +824
} else if (by < k_nope_end) {
// K Non-RoPE processing & cache append
uint32_t k_head_idx = (by - k_rope_end) / no_rope_chunks;
uint32_t nope_chunk_idx = (by - k_rope_end) % no_rope_chunks;
uint32_t elem_offset = nope_chunk_idx * rope_chunk_size;

DType* k_nope_in_ptr;
if constexpr (IS_MLA) {
k_nope_in_ptr = k_nope_in + idx * k_nope_in_stride + elem_offset;
} else {
k_nope_in_ptr = k_nope_in + get_elem_offset_impl(idx, k_head_idx, elem_offset,
k_nope_in_stride, k_nope_in_stride_h);
}

vec_t<float, vec_size> k_nope_vec;
k_nope_vec.cast_load(k_nope_in_ptr + tx * vec_size);
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
k_nope_vec[i] = k_nope_vec[i] * quant_scale_kv;
}

if constexpr (IS_MLA) {
QuantType* ckv_ptr =
paged_kv_like.get_ckv_ptr(page_iter, entry_idx, elem_offset + tx * vec_size);
k_nope_vec.cast_store(ckv_ptr);
} else {
QuantType* k_ptr = paged_kv_like.get_k_ptr(page_iter, k_head_idx, entry_idx,
rope_dim + elem_offset + tx * vec_size);
k_nope_vec.cast_store(k_ptr);
}

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# First, verify the file exists and get basic stats
echo "=== File check ==="
ls -la include/flashinfer/rope/kernels.cuh 2>&1 || echo "File not found"

echo ""
echo "=== File line count ==="
wc -l include/flashinfer/rope/kernels.cuh

echo ""
echo "=== Find scale_store_partial_chunk usage ==="
rg -n "scale_store_partial_chunk" include/flashinfer/rope/kernels.cuh -C 3


Length of output: 1009


🏁 Script executed:

#!/bin/bash
# Extract specific sections mentioned in the review
echo "=== Lines 795-824 (K-nope section) ==="
sed -n '795,824p' include/flashinfer/rope/kernels.cuh

echo ""
echo "=== Lines 851-873 (Q-nope section) ==="
sed -n '851,873p' include/flashinfer/rope/kernels.cuh


Length of output: 2447


🏁 Script executed:

#!/bin/bash
# Look at RopeQuantizeKernel implementation for comparison
echo "=== Search for RopeQuantizeKernel function ==="
rg -n "RopeQuantizeKernel" include/flashinfer/rope/kernels.cuh | head -5

echo ""
echo "=== Extract lines around the RopeQuantizeKernel to see the pattern ==="
rg -n "scale_store_partial_chunk" include/flashinfer/rope/kernels.cuh -B 5 -A 5


Length of output: 1426


🏁 Script executed:

#!/bin/bash
# Search for scale_store_partial_chunk definition
echo "=== Find scale_store_partial_chunk definition ==="
rg -n "scale_store_partial_chunk" include/ -t cuda -t cpp -t cuh --type-add 'cuda:*.cuh'


Length of output: 162


🏁 Script executed:

#!/bin/bash
# Search for scale_store_partial_chunk definition without type filtering
echo "=== Find scale_store_partial_chunk definition ==="
rg -n "scale_store_partial_chunk" include/flashinfer/rope/ -B 2 -A 10

echo ""
echo "=== Search for function definition of scale_store_partial_chunk ==="
rg -n "template.*scale_store_partial_chunk|scale_store_partial_chunk.*{" include/flashinfer/


Length of output: 2619


🏁 Script executed:

#!/bin/bash
# Get the function signature and definition
echo "=== Get RopeQuantizeAppendPagedKVCacheKernel function start ==="
rg -n "RopeQuantizeAppendPagedKVCacheKernel" include/flashinfer/rope/kernels.cuh -A 20

echo ""
echo "=== Check for any bounds checking in the paged kernel nope sections ==="
sed -n '700,880p' include/flashinfer/rope/kernels.cuh | grep -n "chunk_valid\|no_rope_dim\|elem_offset"


Length of output: 3893


🏁 Script executed:

#!/bin/bash
# Look at the specific cast_load and cast_store calls in the paged kernel
echo "=== Examine K-nope section more closely ==="
sed -n '790,835p' include/flashinfer/rope/kernels.cuh

echo ""
echo "=== Examine Q-nope section more closely ==="
sed -n '845,880p' include/flashinfer/rope/kernels.cuh


Length of output: 3478


Add bounds checking for non-RoPE sections in RopeQuantizeAppendPagedKVCacheKernel.

The K-nope (line 810) and Q-nope (line 866) sections perform cast_load/cast_store without checking if elem_offset + tx * vec_size < no_rope_dim. When no_rope_dim is not a multiple of rope_chunk_size (= rope_dim), the last nope chunk will have threads reading/writing beyond the valid boundary. The non-paged RopeQuantizeKernel correctly uses scale_store_partial_chunk with chunk_valid bounds for these sections (lines 345–349, 364–368). Apply the same bounds protection to the paged kernel's nope sections, or verify that the memory layout guarantees safe out-of-bounds access.

🤖 Prompt for AI Agents
In `@include/flashinfer/rope/kernels.cuh` around lines 795 - 824,
RopeQuantizeAppendPagedKVCacheKernel currently does unchecked
cast_load/cast_store in the K-nope and Q-nope paths (uses k_nope_in,
elem_offset, tx, vec_size, cast_load/cast_store and paged_kv_like.get_ckv_ptr /
get_k_ptr), which can read/write past no_rope_dim when no_rope_dim %
rope_chunk_size != 0; change these sections to mirror the non-paged kernel by
computing the per-chunk valid count (chunk_valid = max(0, min(vec_size,
no_rope_dim - (elem_offset + tx*vec_size)))) and use the existing partial-store
helper (scale_store_partial_chunk) or otherwise guard loads/stores by
chunk_valid so that only valid elements are read/written; keep the same
quant_scale_kv scaling before storing and apply the check for both K-nope and
Q-nope branches.
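The per-chunk valid count proposed above is simple enough to model directly; this Python sketch mirrors the suggested formula (variable names follow the kernel's):

```python
def chunk_valid(no_rope_dim: int, elem_offset: int, tx: int, vec_size: int) -> int:
    """Number of in-bounds elements this thread may load/store in a nope chunk."""
    return max(0, min(vec_size, no_rope_dim - (elem_offset + tx * vec_size)))
```

For example, with `no_rope_dim = 72` and a 64-element chunk size, the second chunk starts at `elem_offset = 64`, so only the first thread's 8-element vector is fully valid and later threads must store nothing.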

Comment on lines +675 to +679
float rope_rcp_scale = 1.0f / rope_scale;
float rope_rcp_theta = 1.0f / rope_theta;
float smooth_a = old_context_length / (2 * M_PI * high_freq_factor - 2 * M_PI * low_freq_factor);
float smooth_b = -1.0f / (high_freq_factor / low_freq_factor - 1.0f);


⚠️ Potential issue | 🟠 Major

Guard Llama‑3.1 smooth parameters when freq factors are equal.
high_freq_factor == low_freq_factor triggers division by zero and can yield NaNs; the CuTe‑DSL path already guards this. Please add the same protection here.

🐛 Proposed fix (apply in both Llama31 launcher functions)
-  float smooth_a = old_context_length / (2 * M_PI * high_freq_factor - 2 * M_PI * low_freq_factor);
-  float smooth_b = -1.0f / (high_freq_factor / low_freq_factor - 1.0f);
+  float smooth_a = 0.f;
+  float smooth_b = 0.f;
+  if (high_freq_factor != low_freq_factor) {
+    smooth_a = old_context_length / (2 * M_PI * (high_freq_factor - low_freq_factor));
+    smooth_b = -1.0f / ((high_freq_factor / low_freq_factor) - 1.0f);
+  }

Also applies to: 740-743
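The guarded computation in the proposed fix can be sketched in Python as follows (a minimal restatement for checking the math, not the C++ launcher itself):

```python
import math

def llama31_smooth_params(old_context_length: float,
                          high_freq_factor: float,
                          low_freq_factor: float) -> tuple[float, float]:
    """Return (smooth_a, smooth_b), guarding the degenerate equal-factor case."""
    if high_freq_factor == low_freq_factor:
        # Both formulas below divide by (high - low) terms; return neutral
        # values instead of inf/NaN, matching the CuTe-DSL path's guard.
        return 0.0, 0.0
    smooth_a = old_context_length / (2 * math.pi * (high_freq_factor - low_freq_factor))
    smooth_b = -1.0 / (high_freq_factor / low_freq_factor - 1.0)
    return smooth_a, smooth_b
```

Note that `smooth_a` here uses the algebraically equivalent factored form `2π·(high − low)` from the proposed diff rather than the original `2π·high − 2π·low`.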

🤖 Prompt for AI Agents
In `@include/flashinfer/rope/launchers.cuh` around lines 675 - 679, The
computation of smooth_a and smooth_b can divide by zero when high_freq_factor ==
low_freq_factor; update both Llama‑3.1 launcher functions to guard this by
checking if fabsf(high_freq_factor - low_freq_factor) < 1e-6f (or similar small
epsilon) and in that case set smooth_a = 0.0f and smooth_b = 0.0f (matching the
CuTe‑DSL path); otherwise compute smooth_a and smooth_b with the existing
formulas using high_freq_factor and low_freq_factor. Ensure you apply the same
guard for the other occurrence of smooth_a/smooth_b (the lines around 740-743)
so no division by zero occurs.

Comment on lines +499 to +505
cuda_times = bench_gpu_time(run_cuda)
cuda_ms = np.median(cuda_times)

cute_times = bench_gpu_time(run_cute)
cute_ms = np.median(cute_times)

return cuda_ms, cute_ms
Should be fine, but I recommend enable_cupti=True, dry_run_iters=10, and repeat_iters=30 for RoPE: CUPTI gives more accurate kernel-time measurements, and the smaller iteration counts make benchmarking faster.

See bench_gpu_time definition

The logical organization of new files in rope/kernels/ seems a bit strange. It appears that

  • compile.py could be absorbed into wrappers.py
  • ptx_ops.py could probably be absorbed into helpers.py and then renamed as utils.py

These files are not especially long, so I don't see the need to split them across multiple files.

Index type for indptr/offsets: "int32" or "int64". Default "int32".
"""
dtype = get_cutlass_dtype(dtype_str)
idtype = get_cutlass_idtype(idtype_str)

If you take a close look at the lifetime of these dtypes, you go from torch.dtype to str to CUTLASS data types. The intermediate string conversion is an unnecessary step that adds a small amount of overhead.

The overhead is generally small, on the order of 1 µs, but RoPE kernels run fast, so it can still make a measurable difference.
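A sketch of the suggested shortcut, using stand-in enums for the torch and CUTLASS dtype objects (the real mapping would key on torch.dtype values and return cutlass types directly, skipping the string hop):

```python
from enum import Enum

class TorchDtype(Enum):       # stand-in for torch.dtype values
    float16 = "float16"
    bfloat16 = "bfloat16"

class CutlassDtype(Enum):     # stand-in for CUTLASS dtype objects
    F16 = "f16"
    BF16 = "bf16"

# One dict lookup: torch.dtype -> cutlass dtype, no intermediate str.
_TORCH_TO_CUTLASS = {
    TorchDtype.float16: CutlassDtype.F16,
    TorchDtype.bfloat16: CutlassDtype.BF16,
}

def get_cutlass_dtype_direct(torch_dtype: TorchDtype) -> CutlassDtype:
    return _TORCH_TO_CUTLASS[torch_dtype]
```

A single dict lookup replaces the two-step torch.dtype → str → cutlass conversion, which matters only because RoPE kernels themselves run in microseconds.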

@flashinfer-bot

[FAILED] Pipeline #43935806: 12/20 passed
