[flashinfer.rope] refactor CUDA backend, add CuTe-DSL backend for unfused RoPE APIs #2470
kahyunnam wants to merge 10 commits into flashinfer-ai:main
Conversation
📝 Walkthrough

This PR introduces comprehensive CuTe-DSL backend support for RoPE (Rotary Positional Embeddings) operations: it adds a new backend parameter to multiple RoPE-related functions, implements kernel variants for the CuTe-DSL path alongside the existing CUDA kernels, and adds a new benchmark suite.
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
Summary of Changes

Hello @kahyunnam, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces a new, highly optimized backend for Rotary Positional Embeddings (RoPE) leveraging the CuTe-DSL framework. This refactoring provides an alternative implementation to the existing CUDA C++ kernels, supporting both interleaved and non-interleaved RoPE styles, along with Llama 3.1 frequency scaling. The changes integrate this new backend into the public API, allowing users to select their preferred implementation, and include comprehensive unit tests to ensure correctness.
Code Review
This pull request introduces a new CuTe-DSL backend for RoPE kernels as an alternative to the existing CUDA C++ implementation. The changes include a new flashinfer/cute_dsl/rope.py file with the kernel implementations and modifications to flashinfer/rope.py to add a backend switch. The tests are also updated to cover the new backend.
My review focuses on the new CuTe-DSL implementation. I've identified opportunities for performance optimization and code refactoring to improve maintainability. Specifically, I've suggested using a more efficient PTX instruction for sincos, vectorizing a CPU-bound helper function, and refactoring duplicated code within the main kernel loops.
flashinfer/cute_dsl/rope.py
Outdated
```python
def sincos_approx(x: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
    """
    Compute approximate sin and cos simultaneously using PTX.

    This is more efficient than calling sin_approx and cos_approx separately
    as it computes both values in a single operation.

    Returns (sin_val, cos_val).
    """
    # PTX doesn't have a combined sincos.approx, so we compute separately
    # but this function serves as documentation that both are needed
    # and allows future optimization if PTX adds such instruction
    sin_val = Float32(
        llvm.inline_asm(
            T.f32(),
            [Float32(x).ir_value(loc=loc, ip=ip)],
            "sin.approx.f32 $0, $1;",
            "=f,f",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )
    cos_val = Float32(
        llvm.inline_asm(
            T.f32(),
            [Float32(x).ir_value(loc=loc, ip=ip)],
            "cos.approx.f32 $0, $1;",
            "=f,f",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )
    return sin_val, cos_val
```
The comment on line 129 is incorrect. PTX ISA has supported a combined sin/cos approximate instruction (sinc.approx.f32) since SM 7.0, which is more efficient than calling sin.approx.f32 and cos.approx.f32 separately. I suggest refactoring this function to use a single llvm.inline_asm call with sinc.approx.f32 for better performance.
```python
def sincos_approx(x: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
    """
    Compute approximate sin and cos simultaneously using PTX sinc.approx.f32.

    This is more efficient than calling sin_approx and cos_approx separately
    as it computes both values in a single operation.

    Returns (sin_val, cos_val).
    """
    result = llvm.inline_asm(
        llvm.StructType.get_literal([T.f32(), T.f32()]),
        [Float32(x).ir_value(loc=loc, ip=ip)],
        """{
        .reg .v2 .f32 sincos;
        sinc.approx.f32 sincos, $2;
        mov.b32 $0, sincos.x;
        mov.b32 $1, sincos.y;
        }""",
        "=f,=f,f",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
        loc=loc,
        ip=ip,
    )
    sin_val = llvm.extractvalue(T.f32(), result, [0], loc=loc, ip=ip)
    cos_val = llvm.extractvalue(T.f32(), result, [1], loc=loc, ip=ip)
    return Float32(sin_val), Float32(cos_val)
```
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/rope.py (1)
771-905: ⚠️ Potential issue | 🟡 Minor — Validate backend values to avoid silent CUDA fallback.

If `backend` is anything other than `"cuda"` or `"cute-dsl"`, the code currently falls through to CUDA. Please raise a clear error and apply the same guard in the other backend-parameterized APIs.

Suggested guard for apply_rope (mirror in other APIs):

```diff
 if rotary_dim is None:
     rotary_dim = q.size(-1)
 if backend == "cute-dsl":
     if not _is_cute_dsl_available():
         raise RuntimeError(
             "CuTe-DSL backend is not available. Please install CuTe-DSL."
         )
     from .cute_dsl.rope import apply_rope_with_indptr_cute_dsl

     return apply_rope_with_indptr_cute_dsl(
         q,
         k,
         indptr,
         offsets,
         rotary_dim=rotary_dim,
         interleave=interleave,
         rope_scale=rope_scale,
         rope_theta=rope_theta,
     )
-# Default: CUDA C++ backend
+if backend != "cuda":
+    raise ValueError(f"Unsupported backend: {backend}")
+# Default: CUDA C++ backend
 _apply_rope(
     q, k, q_rope, k_rope, indptr, offsets,
     rotary_dim, interleave, rope_scale, rope_theta,
 )
```
🤖 Fix all issues with AI agents
In `@flashinfer/cute_dsl/rope.py`:
- Around line 1566-1569: The __all__ export list is unsorted and triggers Ruff
RUF022; reorder the entries in the __all__ list alphabetically so
"RopeKernelInterleavedVec" appears before "RopeKernelNonInterleavedVec" (i.e.,
sort the list containing RopeKernelNonInterleavedVec and
RopeKernelInterleavedVec) to satisfy the linter.
- Around line 1322-1330: Add explicit vector-alignment assertions for head_dim
and rotary_dim to match the kernel's 8-element vector and single-pair thread
assumptions: ensure head_dim is a multiple of 8 (replace or augment the existing
head_dim % 2 check) and, depending on the layout mode, ensure rotary_dim is
aligned so the kernel won't straddle halves — specifically assert rotary_dim % 8
== 0 when using interleaved vectors, and assert (rotary_dim // 2) % 8 == 0 in
the non-interleaved/half-pair case; update the validation near the existing
rotary_dim/head_dim checks (the block referencing rotary_dim and head_dim) and
include clear assertion messages referencing rotary_dim and head_dim.
In `@tests/attention/test_rope.py`:
- Around line 47-53: The cute-dsl test gating only checks availability and
inplace support but misses GPU compute-capability guards; update the two
CuTe-DSL skip blocks (the branches checking backend == "cute-dsl" that call
is_cute_dsl_available() and pytest.skip()) to also call the appropriate
flashinfer.utils helpers (e.g., get_compute_capability() and the relevant
predicate like is_sm90a_supported() or a project-specific is_sm_supported(...)
helper) and skip with pytest.skip(...) when the current GPU SM is unsupported;
ensure you import the helpers from flashinfer.utils at the top of the test file
and apply the same change to the other CuTe-DSL skip site mentioned in the
review.
🧹 Nitpick comments (1)
flashinfer/cute_dsl/rope.py (1)
1435-1445: Avoid per-sequence GPU syncs when building pos_ids.

The Python loop with `.item()` on CUDA tensors forces a sync per sequence. A vectorized construction keeps this on-GPU and scales better for large batches.

Vectorized alternative:

```diff
-    # Create output tensor
-    pos_ids = torch.empty(nnz, dtype=torch.int32, device=device)
-
-    # For each sequence, compute positions
-    # This is a simple CPU loop - could be optimized with a Triton kernel if needed
-    for i in range(batch_size):
-        start = indptr[i].item()
-        end = indptr[i + 1].item()
-        offset = offsets[i].item()
-        seq_len = end - start
-        if seq_len > 0:
-            pos_ids[start:end] = (
-                torch.arange(seq_len, dtype=torch.int32, device=device) + offset
-            )
+    lengths = (indptr[1:] - indptr[:-1]).to(torch.int32)
+    seq_ids = torch.repeat_interleave(
+        torch.arange(batch_size, device=device, dtype=torch.int32), lengths
+    )
+    base = indptr[:-1].to(torch.int32)
+    pos_ids = (
+        torch.arange(nnz, device=device, dtype=torch.int32)
+        - base[seq_ids]
+        + offsets.to(torch.int32)[seq_ids]
+    )
```
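The equivalence of the loop and the vectorized construction can be checked with a small pure-Python sketch (a torch-free mirror of the logic above; function names are illustrative, not from the PR):

```python
def pos_ids_loop(indptr, offsets):
    # Reference construction: positions = offset + arange(seq_len), per sequence.
    out = [0] * indptr[-1]
    for i in range(len(indptr) - 1):
        start, end = indptr[i], indptr[i + 1]
        for j in range(end - start):
            out[start + j] = offsets[i] + j
    return out

def pos_ids_vectorized(indptr, offsets):
    # Vectorized formulation: global index - sequence base + per-sequence offset.
    nnz = indptr[-1]
    lengths = [indptr[i + 1] - indptr[i] for i in range(len(indptr) - 1)]
    seq_ids = [i for i, n in enumerate(lengths) for _ in range(n)]
    return [g - indptr[seq_ids[g]] + offsets[seq_ids[g]] for g in range(nnz)]

indptr = [0, 3, 3, 7]   # three sequences, the middle one empty
offsets = [10, 0, 5]
assert pos_ids_loop(indptr, offsets) == pos_ids_vectorized(indptr, offsets) == [10, 11, 12, 5, 6, 7, 8]
```

Empty sequences fall out naturally in both versions, since `repeat_interleave` simply contributes zero entries for them.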
tests/attention/test_rope.py
Outdated
```python
# Skip cute-dsl backend if not available or for inplace operations
if backend == "cute-dsl":
    if not is_cute_dsl_available():
        pytest.skip("CuTe-DSL not available")
    if inplace:
        pytest.skip("CuTe-DSL backend does not support inplace operations")
```
Add compute-capability skips for CuTe‑DSL tests.
The cute‑dsl parametrization only checks availability/inplace; on unsupported SMs these tests can still run and fail. Please gate cute‑dsl runs with flashinfer.utils compute‑capability helpers in both skip blocks.
Suggested guard (adjust helper to the actual CuTe‑DSL support matrix)
```diff
 import flashinfer
 from flashinfer.cute_dsl.utils import is_cute_dsl_available
+from flashinfer.utils import get_compute_capability, is_sm90a_supported
 ...
     if backend == "cute-dsl":
         if not is_cute_dsl_available():
             pytest.skip("CuTe-DSL not available")
+        cc = get_compute_capability()
+        if not is_sm90a_supported(cc):
+            pytest.skip(f"CuTe-DSL backend not supported on {cc}")
         if inplace:
             pytest.skip("CuTe-DSL backend does not support inplace operations")
```

As per coding guidelines: tests/**/*.py: Test files must use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, etc.) to skip tests on unsupported GPU architectures.
Also applies to: 177-183
b777566 to 628c961 (Compare)

/bot run
Actionable comments posted: 11
🤖 Fix all issues with AI agents
In `@benchmarks/bench_rope_workloads.py`:
- Around line 544-546: The calculation speedup = cuda_ms / cute_ms can raise
ZeroDivisionError if cute_ms is 0; modify the code around where speedup is
computed (variables cuda_ms and cute_ms) to guard against zero or near-zero
cute_ms (e.g., check if cute_ms == 0 or abs(cute_ms) < epsilon) and handle it
deterministically (return float('inf'), None, or a default value and/or log a
warning) so division is never attempted when cute_ms is zero; update any
downstream uses of speedup to account for the chosen sentinel.
In `@flashinfer/rope.py`:
- Around line 874-892: The backend switch silently defaults to the CUDA C++
backend when an unknown backend string is passed; update the backend-selection
logic (the branch handling backend, including the "cute-dsl" case and the
default CUDA path) to explicitly validate supported values and raise a
ValueError for unsupported strings instead of falling through; for example,
check backend against the allowed set (e.g., "cute-dsl", "cuda" or whichever
exact names your APIs expect) before importing/dispatching and raise
ValueError("Unsupported backend: {backend}") if it’s not recognized; apply the
same explicit validation pattern to the other backend-switching APIs mentioned
(the branches around apply_rope_with_indptr_cute_dsl, and the other blocks at
the indicated ranges).
In `@flashinfer/rope/__init__.py`:
- Around line 1-105: Duplicate module conflict: remove the old top-level module
file flashinfer/rope.py so Python imports the new package flashinfer.rope (the
package's __init__.py imports symbols from .rope like apply_rope,
apply_rope_pos_ids, apply_llama31_rope, rope_quantize_fp8, etc.). Delete
flashinfer/rope.py from the repo and CI, ensure no other files import the old
top-level module path, and run the typecheck/test suite to confirm the Duplicate
module error is resolved.
In `@flashinfer/rope/custom_ops.py`:
- Around line 268-279: The fake op _fake_apply_rope_pos_ids_cos_sin_cache has
the wrong parameter list (it currently takes cos_cache and sin_cache
separately); change its signature to match the real op by replacing the two
params with a single cos_sin_cache parameter so it accepts (q, k, q_rope,
k_rope, cos_sin_cache, pos_ids, interleave), update any internal references to
use cos_sin_cache, and ensure the `@register_fake_op` name remains
"flashinfer::apply_rope_pos_ids_cos_sin_cache" to match the real implementation.
- Around line 287-323: The Python wrapper _rope_quantize has the parameters
cos_sin_cache and pos_ids placed before the output tensors, but the underlying
C++ expects those two after the outputs; update the function signature of
_rope_quantize to list q_rope_out, k_rope_out, q_nope_out, k_nope_out before
cos_sin_cache and pos_ids (so the order is q_rope_in, k_rope_in, q_nope_in,
k_nope_in, q_rope_out, k_rope_out, q_nope_out, k_nope_out, cos_sin_cache,
pos_ids, quant_scale_q, quant_scale_kv, interleave, enable_pdl) and then adjust
the call to get_rope_module().rope_quantize to use matching parameter order (you
can then remove the current manual reordering in that call); this change affects
the _rope_quantize function signature and its invocation to match the C++
argument order.
In `@flashinfer/rope/kernels/kernels.py`:
- Around line 71-88: The constructor (__init__) assumes 8-element vector
alignment (elems_per_thread=8) and half-rotary pairs, which breaks when head_dim
% 8 != 0 or rotary_dim % 16 != 0 and can make bdx==0 leading to bdy
divide-by-zero; add explicit validation at the start of __init__ to raise a
clear ValueError if head_dim % 8 != 0 or rotary_dim % 16 != 0, and guard the
bdx/bdy calculation by ensuring bdx = max(1, head_dim // self.elems_per_thread)
and computing num_threads and bdy accordingly (e.g., num_threads = max(128, bdx)
then bdy = max(1, num_threads // bdx)); apply the same validation/guards in the
other non-interleaved kernel variants that use elems_per_thread, bdx,
num_threads, bdy, head_dim, rotary_dim, and half_rotary.
In `@flashinfer/rope/kernels/wrappers.py`:
- Around line 147-169: Add vector-alignment and rotary-dim guards to the
CuTe-DSL wrappers: in apply_rope_cute_dsl and apply_rope_with_indptr_cute_dsl
validate that head_dim is 8-element aligned (e.g., assert head_dim % 8 == 0) and
preserve the existing rotary_dim checks by asserting rotary_dim <= head_dim and
rotary_dim % 2 == 0; additionally, when running in non-interleaved mode enforce
rotary_dim is divisible by 16 (e.g., assert rotary_dim % 16 == 0 when
interleaving flag indicates non-interleaved). Mirror these checks in both the
indptr/cos-sin-cache wrappers so both apply_rope_cute_dsl and
apply_rope_with_indptr_cute_dsl perform the same alignment and rotary-dimension
validation.
- Around line 23-34: The exported CuTe-DSL wrapper functions in this module (the
functions listed in __all__ and the wrapper functions around the compiled
kernels) are missing the `@flashinfer_api` decorator; import flashinfer_api and
add `@flashinfer_api` above each exported wrapper function (including the
functions in the ranges mentioned: 90-103, 312-326, 455-469, 495-503) so they
participate in API logging and crash-safe input capture; ensure you place the
decorator directly above the wrapper function definitions that call the compiled
helpers (_get_compiled_kernel, _get_compiled_kernel_seq_heads,
_get_compiled_kernel_with_indptr, _get_compiled_cos_sin_cache_kernel,
_get_compiled_cos_sin_cache_kernel_seq_heads) and that the import for
flashinfer_api is added to the top-level imports.
In `@flashinfer/rope/rope.py`:
- Around line 402-406: The See Also cross-reference is incorrect: update the
reference in the docstring of apply_rope_pos_ids to point to its true inplace
counterpart apply_rope_pos_ids_inplace (replace apply_rope_inplace with
apply_rope_pos_ids_inplace) so the documentation correctly links
apply_rope_pos_ids to apply_rope_pos_ids_inplace.
In `@include/flashinfer/rope/kernels.cuh`:
- Around line 795-824: RopeQuantizeAppendPagedKVCacheKernel currently does
unchecked cast_load/cast_store in the K-nope and Q-nope paths (uses k_nope_in,
elem_offset, tx, vec_size, cast_load/cast_store and paged_kv_like.get_ckv_ptr /
get_k_ptr), which can read/write past no_rope_dim when no_rope_dim %
rope_chunk_size != 0; change these sections to mirror the non-paged kernel by
computing the per-chunk valid count (chunk_valid = max(0, min(vec_size,
no_rope_dim - (elem_offset + tx*vec_size)))) and use the existing partial-store
helper (scale_store_partial_chunk) or otherwise guard loads/stores by
chunk_valid so that only valid elements are read/written; keep the same
quant_scale_kv scaling before storing and apply the check for both K-nope and
Q-nope branches.
In `@include/flashinfer/rope/launchers.cuh`:
- Around line 675-679: The computation of smooth_a and smooth_b can divide by
zero when high_freq_factor == low_freq_factor; update both Llama‑3.1 launcher
functions to guard this by checking if fabsf(high_freq_factor - low_freq_factor)
< 1e-6f (or similar small epsilon) and in that case set smooth_a = 0.0f and
smooth_b = 0.0f (matching the CuTe‑DSL path); otherwise compute smooth_a and
smooth_b with the existing formulas using high_freq_factor and low_freq_factor.
Ensure you apply the same guard for the other occurrence of smooth_a/smooth_b
(the lines around 740-743) so no division by zero occurs.
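The epsilon guard described above can be sketched generically (the smooth_a/smooth_b expressions here are placeholders standing in for the launcher's existing formulas, not the real ones):

```python
def guarded_smooth_coeffs(low_freq_factor: float, high_freq_factor: float,
                          eps: float = 1e-6):
    # Degenerate case: identical factors would divide by zero below,
    # so fall back to zero coefficients (matching the CuTe-DSL path).
    if abs(high_freq_factor - low_freq_factor) < eps:
        return 0.0, 0.0
    denom = high_freq_factor - low_freq_factor
    smooth_a = 1.0 / denom              # placeholder for the real formula
    smooth_b = -low_freq_factor / denom  # placeholder for the real formula
    return smooth_a, smooth_b

assert guarded_smooth_coeffs(1.0, 1.0) == (0.0, 0.0)  # guard triggers
assert guarded_smooth_coeffs(1.0, 4.0) != (0.0, 0.0)  # normal path
```

The same branch structure translates directly to the CUDA launcher with `fabsf` in place of `abs`.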
🧹 Nitpick comments (3)
benchmarks/bench_rope_workloads.py (1)
195-505: Consider reducing duplication in `benchmark_api`.

This ~310-line function consists of a long if/elif chain where each branch defines nearly identical `run_cuda`/`run_cute` closures differing only in the API name and a few arguments. A data-driven approach (mapping API name → callable + arg builder) would cut this significantly and make adding new APIs a one-liner.

That said, for a benchmark script, explicitness has value, so this is optional.
flashinfer/rope/rope.py (1)
129-154: Invalid `backend` values are silently ignored, falling through to CUDA.

All backend dispatch points use `if backend == "cute-dsl": ... <fall through to CUDA>`. A typo like `backend="cutedsl"` or `backend="cute_dsl"` will silently use the CUDA path instead of raising an error. Consider validating the backend value up-front.

Proposed fix (apply once, e.g., as a helper, or at each entry point):

```diff
+_VALID_BACKENDS = {"cuda", "cute-dsl"}
+
+def _validate_backend(backend: str) -> None:
+    if backend not in _VALID_BACKENDS:
+        raise ValueError(
+            f"Unknown backend {backend!r}. Must be one of {_VALID_BACKENDS}."
+        )
```

Then call `_validate_backend(backend)` at the top of each public function.

include/flashinfer/rope/kernels.cuh (1)

254-257: `rope_chunk_size = rope_dim` makes `rope_chunks` always 1 — the chunking machinery is effectively a no-op for the RoPE dimension.

This is clearly intentional scaffolding for future flexibility (e.g., splitting large rope dimensions across blocks). Just noting that `rope_chunks = (rope_dim + rope_dim - 1) / rope_dim` is always 1, so `rope_chunk_idx` is always 0 and `elem_offset` for RoPE sections is always 0. No action needed.

Also applies to: 698-700
```python
        use_nonzero_offsets=use_nonzero_offsets,
    )
    speedup = cuda_ms / cute_ms
```
Potential division by zero if cute_ms is zero.
speedup = cuda_ms / cute_ms will raise ZeroDivisionError if cute_ms happens to be 0. While unlikely with real GPU timings, a guard would be cheap insurance.
Proposed fix
```diff
-    speedup = cuda_ms / cute_ms
+    speedup = cuda_ms / cute_ms if cute_ms > 0 else float("inf")
```
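The zero-guard pattern above can also be written as a standalone helper (`safe_speedup` is a hypothetical name; the benchmark would likely just inline the expression):

```python
def safe_speedup(cuda_ms: float, cute_ms: float, eps: float = 1e-12) -> float:
    # Return an infinite-speedup sentinel instead of dividing by ~zero.
    if abs(cute_ms) < eps:
        return float("inf")
    return cuda_ms / cute_ms

assert safe_speedup(4.0, 2.0) == 2.0
assert safe_speedup(1.0, 0.0) == float("inf")
```

Downstream formatting code then only needs one extra case for the `inf` sentinel rather than a try/except around every division.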
```python
        if backend == "cute-dsl":
            if not _is_cute_dsl_available():
                raise RuntimeError(
                    "CuTe-DSL backend is not available. Please install CuTe-DSL."
                )
            from .cute_dsl.rope import apply_rope_with_indptr_cute_dsl

            return apply_rope_with_indptr_cute_dsl(
                q,
                k,
                indptr,
                offsets,
                rotary_dim=rotary_dim,
                interleave=interleave,
                rope_scale=rope_scale,
                rope_theta=rope_theta,
            )

        # Default: CUDA C++ backend
```
Validate unsupported backend values explicitly.
Unknown backend strings currently fall through to CUDA silently, which can hide configuration errors. Consider raising a ValueError in each backend‑switching API.
✅ Proposed pattern (apply to each backend‑switching API)
```diff
@@
     if backend == "cute-dsl":
         if not _is_cute_dsl_available():
             raise RuntimeError(
                 "CuTe-DSL backend is not available. Please install CuTe-DSL."
             )
         from .cute_dsl.rope import apply_rope_with_indptr_cute_dsl
@@
         return apply_rope_with_indptr_cute_dsl(
             q,
             k,
             indptr,
             offsets,
             rotary_dim=rotary_dim,
             interleave=interleave,
             rope_scale=rope_scale,
             rope_theta=rope_theta,
         )
+
+    if backend != "cuda":
+        raise ValueError(f"Unsupported backend: {backend}")
```
🧰 Tools
🪛 Ruff (0.15.0)
[warning] 876-878: Avoid specifying long messages outside the exception class
(TRY003)
```python
"""
Copyright (c) 2024-2026 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

FlashInfer RoPE (Rotary Positional Embeddings) Module
=====================================================

This module provides efficient implementations of RoPE for LLM inference.
It supports both CUDA C++ and CuTe-DSL backends.

Public APIs
-----------

Standard RoPE:
    apply_rope : Apply RoPE using indptr/offsets (batched sequences)
    apply_rope_inplace : Apply RoPE inplace using indptr/offsets
    apply_rope_pos_ids : Apply RoPE using explicit position IDs
    apply_rope_pos_ids_inplace : Apply RoPE inplace using position IDs

Llama 3.1 Style RoPE:
    apply_llama31_rope : Apply Llama 3.1 RoPE with adaptive frequency scaling
    apply_llama31_rope_inplace : Apply Llama 3.1 RoPE inplace
    apply_llama31_rope_pos_ids : Apply Llama 3.1 RoPE using position IDs
    apply_llama31_rope_pos_ids_inplace : Apply Llama 3.1 RoPE inplace using position IDs

RoPE with Precomputed cos/sin Cache (vLLM/SGLang compatible):
    apply_rope_with_cos_sin_cache : Apply RoPE with precomputed cos/sin
    apply_rope_with_cos_sin_cache_inplace : Apply RoPE with cos/sin cache inplace

Combined RoPE + Quantize Operations:
    rope_quantize_fp8 : Apply RoPE and quantize to FP8
    mla_rope_quantize_fp8 : Alias for rope_quantize_fp8
    rope_quantize_fp8_append_paged_kv_cache : RoPE + quantize + append to paged cache

Backend Support
---------------
All APIs support a ``backend`` parameter:
- ``"cuda"`` (default): CUDA C++ backend with JIT compilation
- ``"cute-dsl"``: CuTe-DSL Python-based backend (requires CuTe-DSL installation)

Example
-------
>>> import torch
>>> import flashinfer
>>>
>>> # Basic RoPE with position IDs
>>> q = torch.randn(1024, 32, 128, dtype=torch.float16, device="cuda")
>>> k = torch.randn(1024, 8, 128, dtype=torch.float16, device="cuda")
>>> pos_ids = torch.arange(1024, dtype=torch.int32, device="cuda")
>>>
>>> q_rope, k_rope = flashinfer.apply_rope_pos_ids(q, k, pos_ids)
"""

from .rope import (
    # Standard RoPE with indptr/offsets
    apply_rope,
    apply_rope_inplace,
    # Standard RoPE with position IDs
    apply_rope_pos_ids,
    apply_rope_pos_ids_inplace,
    # Llama 3.1 style RoPE with indptr/offsets
    apply_llama31_rope,
    apply_llama31_rope_inplace,
    # Llama 3.1 style RoPE with position IDs
    apply_llama31_rope_pos_ids,
    apply_llama31_rope_pos_ids_inplace,
    # RoPE with cos/sin cache
    apply_rope_with_cos_sin_cache,
    apply_rope_with_cos_sin_cache_inplace,
    # RoPE + Quantize
    rope_quantize_fp8,
    mla_rope_quantize_fp8,
    rope_quantize_fp8_append_paged_kv_cache,
)

__all__ = [
    # Standard RoPE
    "apply_rope",
    "apply_rope_inplace",
    "apply_rope_pos_ids",
    "apply_rope_pos_ids_inplace",
    # Llama 3.1 RoPE
    "apply_llama31_rope",
    "apply_llama31_rope_inplace",
    "apply_llama31_rope_pos_ids",
    "apply_llama31_rope_pos_ids_inplace",
    # RoPE with cos/sin cache
    "apply_rope_with_cos_sin_cache",
    "apply_rope_with_cos_sin_cache_inplace",
    # RoPE + Quantize
    "rope_quantize_fp8",
    "mla_rope_quantize_fp8",
    "rope_quantize_fp8_append_paged_kv_cache",
]
```
🧩 Analysis chain
🏁 Script executed:
```shell
#!/bin/bash
# Verify whether both flashinfer/rope.py and flashinfer/rope/__init__.py exist
echo "=== Checking for duplicate rope module ==="
fd -t f 'rope.py' flashinfer/ --max-depth 1
echo "---"
fd -t f '__init__.py' flashinfer/rope/ --max-depth 1
echo "---"
echo "=== All rope-related files ==="
fd 'rope' flashinfer/ --max-depth 2
```

Repository: flashinfer-ai/flashinfer
Length of output: 279
Critical: Remove duplicate module flashinfer/rope.py — it conflicts with the new package structure.
Both flashinfer/rope.py (old module file) and flashinfer/rope/__init__.py (new package) exist simultaneously. Python cannot resolve both a module file and a package directory with the same name, which causes the Mypy pipeline error: Duplicate module named "flashinfer.rope" (also at "flashinfer/rope.py").
Delete the old flashinfer/rope.py file — it has been superseded by the package at flashinfer/rope/ where the implementation now lives in flashinfer/rope/rope.py (which is imported here).
🧰 Tools
🪛 GitHub Actions: pre-commit
[error] 1-3: Mypy error: Duplicate module named "flashinfer.rope" (also at "flashinfer/rope.py"). See mypy docs for mapping-file-paths-to-modules.
🪛 Ruff (0.15.0)
[warning] 87-105: __all__ is not sorted
Apply an isort-style sorting to __all__
(RUF022)
```python
@register_fake_op("flashinfer::apply_rope_pos_ids_cos_sin_cache")
def _fake_apply_rope_pos_ids_cos_sin_cache(
    q: torch.Tensor,
    k: torch.Tensor,
    q_rope: torch.Tensor,
    k_rope: torch.Tensor,
    cos_cache: torch.Tensor,
    sin_cache: torch.Tensor,
    pos_ids: torch.Tensor,
    interleave: bool,
) -> None:
    pass
```
Fake op signature mismatch: _fake_apply_rope_pos_ids_cos_sin_cache has wrong parameters.
The real op (lines 247–265) accepts 7 parameters: (q, k, q_rope, k_rope, cos_sin_cache, pos_ids, interleave) with a single combined cos_sin_cache tensor. The fake op has 8 parameters, splitting it into cos_cache and sin_cache. This mismatch will cause failures during torch.compile or any tracing that relies on the fake op.
🐛 Proposed fix: align fake op signature with real op

```diff
 @register_fake_op("flashinfer::apply_rope_pos_ids_cos_sin_cache")
 def _fake_apply_rope_pos_ids_cos_sin_cache(
     q: torch.Tensor,
     k: torch.Tensor,
     q_rope: torch.Tensor,
     k_rope: torch.Tensor,
-    cos_cache: torch.Tensor,
-    sin_cache: torch.Tensor,
+    cos_sin_cache: torch.Tensor,
     pos_ids: torch.Tensor,
     interleave: bool,
 ) -> None:
     pass
```

🤖 Prompt for AI Agents
In `@flashinfer/rope/custom_ops.py` around lines 268 - 279, The fake op
_fake_apply_rope_pos_ids_cos_sin_cache has the wrong parameter list (it
currently takes cos_cache and sin_cache separately); change its signature to
match the real op by replacing the two params with a single cos_sin_cache
parameter so it accepts (q, k, q_rope, k_rope, cos_sin_cache, pos_ids,
interleave), update any internal references to use cos_sin_cache, and ensure the
`@register_fake_op` name remains "flashinfer::apply_rope_pos_ids_cos_sin_cache" to
match the real implementation.
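This kind of mismatch is easy to catch mechanically. A minimal sketch (hypothetical stand-ins, not the flashinfer registration machinery) of a signature-parity check that could sit next to the `register_fake_op` call:

```python
import inspect

# Stand-in real/fake ops: a fake op registered for tracing must mirror
# the real op's parameter list, or torch.compile's fake-tensor pass
# feeds it the wrong arguments.
def real_op(q, k, q_rope, k_rope, cos_sin_cache, pos_ids, interleave):
    pass

def fake_op(q, k, q_rope, k_rope, cos_sin_cache, pos_ids, interleave):
    pass

def signatures_match(a, b):
    # Cheap guard comparing parameter names in order.
    return list(inspect.signature(a).parameters) == list(inspect.signature(b).parameters)
```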
```python
@register_custom_op(
    "flashinfer::rope_quantize",
    mutates_args=("q_rope_out", "k_rope_out", "q_nope_out", "k_nope_out"),
)
def _rope_quantize(
    q_rope_in: torch.Tensor,
    k_rope_in: torch.Tensor,
    q_nope_in: torch.Tensor,
    k_nope_in: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    pos_ids: torch.Tensor,
    q_rope_out: torch.Tensor,
    k_rope_out: torch.Tensor,
    q_nope_out: torch.Tensor,
    k_nope_out: torch.Tensor,
    quant_scale_q: float,
    quant_scale_kv: float,
    interleave: bool,
    enable_pdl: bool,
) -> None:
    """Combined RoPE application and FP8 quantization."""
    get_rope_module().rope_quantize(
        q_rope_in,
        k_rope_in,
        q_nope_in,
        k_nope_in,
        q_rope_out,
        k_rope_out,
        q_nope_out,
        k_nope_out,
        cos_sin_cache,
        pos_ids,
        quant_scale_q,
        quant_scale_kv,
        interleave,
        enable_pdl,
    )
```
🧩 Analysis chain
🏁 Script executed:

```shell
#!/bin/bash
# Find C++ rope_quantize function signature
echo "=== Searching for rope_quantize in csrc/rope.cu ==="
rg -n 'rope_quantize' csrc/rope.cu -A 10 -B 2
echo ""
echo "=== Searching for rope_quantize in all csrc files ==="
rg -n 'void rope_quantize|rope_quantize.*\(' csrc/ -A 5
echo ""
echo "=== Checking for TVM FFI or registry patterns ==="
rg -n 'TVM_REGISTER|rope_quantize' csrc/ -B 2 -A 5
```

Repository: flashinfer-ai/flashinfer
Length of output: 7371
Fix Python function signature to match the underlying C++ function argument order.
The Python signature declares cos_sin_cache and pos_ids before the output tensors, but the C++ function expects them after. The actual call correctly reorders arguments to match the C++ signature (line 308: outputs before cache/ids), but the Python signature is misleading. Update lines 291–306 to place cos_sin_cache and pos_ids after the output tensors (q_rope_out, k_rope_out, q_nope_out, k_nope_out) to match the C++ signature and remove the need for reordering in the call.
🤖 Prompt for AI Agents
In `@flashinfer/rope/custom_ops.py` around lines 287 - 323, The Python wrapper
_rope_quantize has the parameters cos_sin_cache and pos_ids placed before the
output tensors, but the underlying C++ expects those two after the outputs;
update the function signature of _rope_quantize to list q_rope_out, k_rope_out,
q_nope_out, k_nope_out before cos_sin_cache and pos_ids (so the order is
q_rope_in, k_rope_in, q_nope_in, k_nope_in, q_rope_out, k_rope_out, q_nope_out,
k_nope_out, cos_sin_cache, pos_ids, quant_scale_q, quant_scale_kv, interleave,
enable_pdl) and then adjust the call to get_rope_module().rope_quantize to use
matching parameter order (you can then remove the current manual reordering in
that call); this change affects the _rope_quantize function signature and its
invocation to match the C++ argument order.
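A hedged sketch of the post-fix shape (names are illustrative, and the FFI call is replaced by a tuple so the forwarding order is visible): once the Python parameter order matches the C++ order, arguments forward 1:1 with no manual shuffle.

```python
def rope_quantize_wrapper(
    q_rope_in, k_rope_in, q_nope_in, k_nope_in,
    q_rope_out, k_rope_out, q_nope_out, k_nope_out,  # outputs before cache/ids,
    cos_sin_cache, pos_ids,                          # matching the C++ signature
    quant_scale_q, quant_scale_kv, interleave, enable_pdl,
):
    # Stand-in for get_rope_module().rope_quantize: with the signatures
    # aligned, the arguments pass straight through in declaration order.
    return (q_rope_in, k_rope_in, q_nope_in, k_nope_in,
            q_rope_out, k_rope_out, q_nope_out, k_nope_out,
            cos_sin_cache, pos_ids,
            quant_scale_q, quant_scale_kv, interleave, enable_pdl)
```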
```python
import math
from typing import Optional, Tuple

import torch

from .compile import (
    _get_compiled_kernel,
    _get_compiled_kernel_seq_heads,
    _get_compiled_kernel_with_indptr,
    _get_compiled_cos_sin_cache_kernel,
    _get_compiled_cos_sin_cache_kernel_seq_heads,
)
```
Add @flashinfer_api to exported CuTe‑DSL wrapper functions.
These functions are exported in __all__; they should be logged consistently with the rest of the Python API surface.
✅ Proposed decorator + import

```diff
@@
-import torch
+import torch
+
+from ...api_logging import flashinfer_api
@@
-def apply_rope_cute_dsl(
+@flashinfer_api
+def apply_rope_cute_dsl(
@@
-def apply_rope_with_indptr_cute_dsl(
+@flashinfer_api
+def apply_rope_with_indptr_cute_dsl(
@@
-def apply_llama31_rope_with_indptr_cute_dsl(
+@flashinfer_api
+def apply_llama31_rope_with_indptr_cute_dsl(
@@
-def apply_rope_with_cos_sin_cache_cute_dsl(
+@flashinfer_api
+def apply_rope_with_cos_sin_cache_cute_dsl(
```

Also applies to: 90-103, 312-326, 455-469, 495-503
🤖 Prompt for AI Agents
In `@flashinfer/rope/kernels/wrappers.py` around lines 23 - 34, The exported
CuTe-DSL wrapper functions in this module (the functions listed in __all__ and
the wrapper functions around the compiled kernels) are missing the
`@flashinfer_api` decorator; import flashinfer_api and add `@flashinfer_api` above
each exported wrapper function (including the functions in the ranges mentioned:
90-103, 312-326, 455-469, 495-503) so they participate in API logging and
crash-safe input capture; ensure you place the decorator directly above the
wrapper function definitions that call the compiled helpers
(_get_compiled_kernel, _get_compiled_kernel_seq_heads,
_get_compiled_kernel_with_indptr, _get_compiled_cos_sin_cache_kernel,
_get_compiled_cos_sin_cache_kernel_seq_heads) and that the import for
flashinfer_api is added to the top-level imports.
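For readers unfamiliar with the decorator, here is a minimal stand-in for what an API-logging wrapper like `flashinfer_api` typically does (the real one also captures inputs for crash-safe replay; this sketch only records call names and is not the flashinfer implementation):

```python
import functools

CALL_LOG = []  # stand-in for the real logging sink

def flashinfer_api(fn):
    @functools.wraps(fn)  # preserve __name__/__doc__ for introspection
    def wrapper(*args, **kwargs):
        CALL_LOG.append(fn.__name__)  # real decorator also records inputs
        return fn(*args, **kwargs)
    return wrapper

@flashinfer_api
def apply_rope_cute_dsl(x):
    return x  # placeholder body

apply_rope_cute_dsl(1)
```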
```python
    # Validate inputs
    assert q.ndim == 3, f"q must be 3D, got {q.ndim}D"
    assert k.ndim == 3, f"k must be 3D, got {k.ndim}D"
    assert q.size(0) == k.size(0), "q and k must have same nnz"
    assert q.size(2) == k.size(2), "q and k must have same head_dim"
    assert q.is_cuda, "q must be on CUDA"
    assert k.is_cuda, "k must be on CUDA"
    assert pos_ids.is_cuda, "pos_ids must be on CUDA"

    nnz = q.size(0)
    num_qo_heads = q.size(1)
    num_kv_heads = k.size(1)
    head_dim = q.size(2)

    if rotary_dim is None:
        rotary_dim = head_dim

    assert rotary_dim <= head_dim, (
        f"rotary_dim must be <= head_dim, got {rotary_dim} > {head_dim}"
    )
    assert rotary_dim % 2 == 0, "rotary_dim must be even"
    assert head_dim % 2 == 0, "head_dim must be even"
```
Add vector‑alignment and rotary‑dim guards for CuTe‑DSL kernels.
The vectorized kernels assume 8‑element alignment; non‑interleaved mode also needs rotary_dim divisible by 16 to avoid cross‑half pairing. Also, apply_rope_with_indptr_cute_dsl lacks the rotary_dim <= head_dim / evenness checks present in apply_rope_cute_dsl.
🐛 Proposed guards (apply here and mirror in the indptr/cos‑sin‑cache wrappers)

```diff
@@
     assert rotary_dim <= head_dim, (
         f"rotary_dim must be <= head_dim, got {rotary_dim} > {head_dim}"
     )
     assert rotary_dim % 2 == 0, "rotary_dim must be even"
     assert head_dim % 2 == 0, "head_dim must be even"
+    if head_dim % 8 != 0:
+        raise ValueError("head_dim must be a multiple of 8 for CuTe-DSL RoPE kernels")
+    if (not interleave) and (rotary_dim % 16 != 0):
+        raise ValueError("rotary_dim must be a multiple of 16 in non-interleaved mode")
@@
     if rotary_dim is None:
         rotary_dim = head_dim
+    if rotary_dim > head_dim:
+        raise ValueError(f"rotary_dim must be <= head_dim, got {rotary_dim} > {head_dim}")
+    if rotary_dim % 2 != 0 or head_dim % 2 != 0:
+        raise ValueError("rotary_dim and head_dim must be even")
+    if head_dim % 8 != 0:
+        raise ValueError("head_dim must be a multiple of 8 for CuTe-DSL RoPE kernels")
+    if (not interleave) and (rotary_dim % 16 != 0):
+        raise ValueError("rotary_dim must be a multiple of 16 in non-interleaved mode")
```

Also applies to: 379-410
🤖 Prompt for AI Agents
In `@flashinfer/rope/kernels/wrappers.py` around lines 147 - 169, Add
vector-alignment and rotary-dim guards to the CuTe-DSL wrappers: in
apply_rope_cute_dsl and apply_rope_with_indptr_cute_dsl validate that head_dim
is 8-element aligned (e.g., assert head_dim % 8 == 0) and preserve the existing
rotary_dim checks by asserting rotary_dim <= head_dim and rotary_dim % 2 == 0;
additionally, when running in non-interleaved mode enforce rotary_dim is
divisible by 16 (e.g., assert rotary_dim % 16 == 0 when interleaving flag
indicates non-interleaved). Mirror these checks in both the indptr/cos-sin-cache
wrappers so both apply_rope_cute_dsl and apply_rope_with_indptr_cute_dsl perform
the same alignment and rotary-dimension validation.
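Collected into a standalone helper, the proposed guards could be sketched as follows (assumptions from the review, not verified against the kernels: vectorized loads move 8 elements per thread, and non-interleaved mode pairs element i with i + rotary_dim // 2):

```python
def validate_rope_dims(head_dim: int, rotary_dim: int, interleave: bool) -> None:
    """Raise ValueError for shapes the CuTe-DSL kernels cannot handle."""
    if rotary_dim > head_dim:
        raise ValueError(f"rotary_dim must be <= head_dim, got {rotary_dim} > {head_dim}")
    if rotary_dim % 2 or head_dim % 2:
        raise ValueError("rotary_dim and head_dim must be even")
    if head_dim % 8:
        # vectorized path assumes 8-element-aligned loads/stores
        raise ValueError("head_dim must be a multiple of 8 for CuTe-DSL RoPE kernels")
    if not interleave and rotary_dim % 16:
        # non-interleaved mode pairs i with i + rotary_dim // 2, so each
        # half must itself stay 8-element aligned
        raise ValueError("rotary_dim must be a multiple of 16 in non-interleaved mode")
```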
```python
    See Also
    --------
    apply_rope_inplace
    """
```
Incorrect See Also cross-reference.
Line 405 references apply_rope_inplace, but apply_rope_pos_ids's inplace counterpart is apply_rope_pos_ids_inplace.
Fix

```diff
 See Also
 --------
-apply_rope_inplace
+apply_rope_pos_ids_inplace
```

🤖 Prompt for AI Agents
In `@flashinfer/rope/rope.py` around lines 402 - 406, The See Also cross-reference
is incorrect: update the reference in the docstring of apply_rope_pos_ids to
point to its true inplace counterpart apply_rope_pos_ids_inplace (replace
apply_rope_inplace with apply_rope_pos_ids_inplace) so the documentation
correctly links apply_rope_pos_ids to apply_rope_pos_ids_inplace.
```cpp
} else if (by < k_nope_end) {
  // K Non-RoPE processing & cache append
  uint32_t k_head_idx = (by - k_rope_end) / no_rope_chunks;
  uint32_t nope_chunk_idx = (by - k_rope_end) % no_rope_chunks;
  uint32_t elem_offset = nope_chunk_idx * rope_chunk_size;

  DType* k_nope_in_ptr;
  if constexpr (IS_MLA) {
    k_nope_in_ptr = k_nope_in + idx * k_nope_in_stride + elem_offset;
  } else {
    k_nope_in_ptr = k_nope_in + get_elem_offset_impl(idx, k_head_idx, elem_offset,
                                                     k_nope_in_stride, k_nope_in_stride_h);
  }

  vec_t<float, vec_size> k_nope_vec;
  k_nope_vec.cast_load(k_nope_in_ptr + tx * vec_size);
#pragma unroll
  for (uint32_t i = 0; i < vec_size; ++i) {
    k_nope_vec[i] = k_nope_vec[i] * quant_scale_kv;
  }

  if constexpr (IS_MLA) {
    QuantType* ckv_ptr =
        paged_kv_like.get_ckv_ptr(page_iter, entry_idx, elem_offset + tx * vec_size);
    k_nope_vec.cast_store(ckv_ptr);
  } else {
    QuantType* k_ptr = paged_kv_like.get_k_ptr(page_iter, k_head_idx, entry_idx,
                                               rope_dim + elem_offset + tx * vec_size);
    k_nope_vec.cast_store(k_ptr);
  }
```
🧩 Analysis chain
🏁 Script executed:

```shell
#!/bin/bash
# First, verify the file exists and get basic stats
echo "=== File check ==="
ls -la include/flashinfer/rope/kernels.cuh 2>&1 || echo "File not found"
echo ""
echo "=== File line count ==="
wc -l include/flashinfer/rope/kernels.cuh
echo ""
echo "=== Find scale_store_partial_chunk usage ==="
rg -n "scale_store_partial_chunk" include/flashinfer/rope/kernels.cuh -C 3
```

Repository: flashinfer-ai/flashinfer
Length of output: 1009

🏁 Script executed:

```shell
#!/bin/bash
# Extract specific sections mentioned in the review
echo "=== Lines 795-824 (K-nope section) ==="
sed -n '795,824p' include/flashinfer/rope/kernels.cuh
echo ""
echo "=== Lines 851-873 (Q-nope section) ==="
sed -n '851,873p' include/flashinfer/rope/kernels.cuh
```

Repository: flashinfer-ai/flashinfer
Length of output: 2447

🏁 Script executed:

```shell
#!/bin/bash
# Look at RopeQuantizeKernel implementation for comparison
echo "=== Search for RopeQuantizeKernel function ==="
rg -n "RopeQuantizeKernel" include/flashinfer/rope/kernels.cuh | head -5
echo ""
echo "=== Extract lines around the RopeQuantizeKernel to see the pattern ==="
rg -n "scale_store_partial_chunk" include/flashinfer/rope/kernels.cuh -B 5 -A 5
```

Repository: flashinfer-ai/flashinfer
Length of output: 1426

🏁 Script executed:

```shell
#!/bin/bash
# Search for scale_store_partial_chunk definition
echo "=== Find scale_store_partial_chunk definition ==="
rg -n "scale_store_partial_chunk" include/ -t cuda -t cpp -t cuh --type-add 'cuda:*.cuh'
```

Repository: flashinfer-ai/flashinfer
Length of output: 162

🏁 Script executed:

```shell
#!/bin/bash
# Search for scale_store_partial_chunk definition without type filtering
echo "=== Find scale_store_partial_chunk definition ==="
rg -n "scale_store_partial_chunk" include/flashinfer/rope/ -B 2 -A 10
echo ""
echo "=== Search for function definition of scale_store_partial_chunk ==="
rg -n "template.*scale_store_partial_chunk|scale_store_partial_chunk.*{" include/flashinfer/
```

Repository: flashinfer-ai/flashinfer
Length of output: 2619

🏁 Script executed:

```shell
#!/bin/bash
# Get the function signature and definition
echo "=== Get RopeQuantizeAppendPagedKVCacheKernel function start ==="
rg -n "RopeQuantizeAppendPagedKVCacheKernel" include/flashinfer/rope/kernels.cuh -A 20
echo ""
echo "=== Check for any bounds checking in the paged kernel nope sections ==="
sed -n '700,880p' include/flashinfer/rope/kernels.cuh | grep -n "chunk_valid\|no_rope_dim\|elem_offset"
```

Repository: flashinfer-ai/flashinfer
Length of output: 3893

🏁 Script executed:

```shell
#!/bin/bash
# Look at the specific cast_load and cast_store calls in the paged kernel
echo "=== Examine K-nope section more closely ==="
sed -n '790,835p' include/flashinfer/rope/kernels.cuh
echo ""
echo "=== Examine Q-nope section more closely ==="
sed -n '845,880p' include/flashinfer/rope/kernels.cuh
```

Repository: flashinfer-ai/flashinfer
Length of output: 3478
Add bounds checking for non-RoPE sections in RopeQuantizeAppendPagedKVCacheKernel.
The K-nope (line 810) and Q-nope (line 866) sections perform cast_load/cast_store without checking if elem_offset + tx * vec_size < no_rope_dim. When no_rope_dim is not a multiple of rope_chunk_size (= rope_dim), the last nope chunk will have threads reading/writing beyond the valid boundary. The non-paged RopeQuantizeKernel correctly uses scale_store_partial_chunk with chunk_valid bounds for these sections (lines 345–349, 364–368). Apply the same bounds protection to the paged kernel's nope sections, or verify that the memory layout guarantees safe out-of-bounds access.
🤖 Prompt for AI Agents
In `@include/flashinfer/rope/kernels.cuh` around lines 795 - 824,
RopeQuantizeAppendPagedKVCacheKernel currently does unchecked
cast_load/cast_store in the K-nope and Q-nope paths (uses k_nope_in,
elem_offset, tx, vec_size, cast_load/cast_store and paged_kv_like.get_ckv_ptr /
get_k_ptr), which can read/write past no_rope_dim when no_rope_dim %
rope_chunk_size != 0; change these sections to mirror the non-paged kernel by
computing the per-chunk valid count (chunk_valid = max(0, min(vec_size,
no_rope_dim - (elem_offset + tx*vec_size)))) and use the existing partial-store
helper (scale_store_partial_chunk) or otherwise guard loads/stores by
chunk_valid so that only valid elements are read/written; keep the same
quant_scale_kv scaling before storing and apply the check for both K-nope and
Q-nope branches.
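The per-thread guard described above can be modeled in a few lines of Python (names mirror the kernel variables; the dimensions in the usage note are illustrative, not taken from the kernels):

```python
def chunk_valid(no_rope_dim: int, elem_offset: int, tx: int, vec_size: int) -> int:
    """How many of thread tx's vec_size elements fall inside no_rope_dim.

    Thread tx loads vec_size elements starting at elem_offset + tx * vec_size;
    the clamp caps the count at the valid boundary (and at 0 for threads
    whose whole chunk starts past it).
    """
    start = elem_offset + tx * vec_size
    return max(0, min(vec_size, no_rope_dim - start))
```

For example, with no_rope_dim = 450 and vec_size = 8, the thread whose chunk starts at offset 448 may only touch 2 elements; an unguarded cast_load/cast_store would access 6 elements past the valid boundary.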
```cpp
  float rope_rcp_scale = 1.0f / rope_scale;
  float rope_rcp_theta = 1.0f / rope_theta;
  float smooth_a = old_context_length / (2 * M_PI * high_freq_factor - 2 * M_PI * low_freq_factor);
  float smooth_b = -1.0f / (high_freq_factor / low_freq_factor - 1.0f);
```
Guard Llama‑3.1 smooth parameters when freq factors are equal.
high_freq_factor == low_freq_factor triggers division by zero and can yield NaNs; the CuTe‑DSL path already guards this. Please add the same protection here.
🐛 Proposed fix (apply in both Llama31 launcher functions)

```diff
-  float smooth_a = old_context_length / (2 * M_PI * high_freq_factor - 2 * M_PI * low_freq_factor);
-  float smooth_b = -1.0f / (high_freq_factor / low_freq_factor - 1.0f);
+  float smooth_a = 0.f;
+  float smooth_b = 0.f;
+  if (high_freq_factor != low_freq_factor) {
+    smooth_a = old_context_length / (2 * M_PI * (high_freq_factor - low_freq_factor));
+    smooth_b = -1.0f / ((high_freq_factor / low_freq_factor) - 1.0f);
+  }
```

Also applies to: 740-743
🤖 Prompt for AI Agents
In `@include/flashinfer/rope/launchers.cuh` around lines 675 - 679, The
computation of smooth_a and smooth_b can divide by zero when high_freq_factor ==
low_freq_factor; update both Llama‑3.1 launcher functions to guard this by
checking if fabsf(high_freq_factor - low_freq_factor) < 1e-6f (or similar small
epsilon) and in that case set smooth_a = 0.0f and smooth_b = 0.0f (matching the
CuTe‑DSL path); otherwise compute smooth_a and smooth_b with the existing
formulas using high_freq_factor and low_freq_factor. Ensure you apply the same
guard for the other occurrence of smooth_a/smooth_b (the lines around 740-743)
so no division by zero occurs.
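A host-side reference of the guarded computation (mirrors the proposed diff above; the epsilon comparison follows the AI prompt's suggestion, and the parameter values in the test are illustrative, not Llama-3.1 defaults):

```python
import math

def llama31_smooth_params(old_context_length, low_freq_factor, high_freq_factor, eps=1e-6):
    """Guarded Llama-3.1 frequency-smoothing coefficients."""
    if abs(high_freq_factor - low_freq_factor) < eps:
        # Degenerate case: no smoothing window, so disable smoothing
        # instead of dividing by zero and producing NaNs.
        return 0.0, 0.0
    smooth_a = old_context_length / (2 * math.pi * (high_freq_factor - low_freq_factor))
    smooth_b = -1.0 / (high_freq_factor / low_freq_factor - 1.0)
    return smooth_a, smooth_b
```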
```python
    cuda_times = bench_gpu_time(run_cuda)
    cuda_ms = np.median(cuda_times)

    cute_times = bench_gpu_time(run_cute)
    cute_ms = np.median(cute_times)

    return cuda_ms, cute_ms
```
Should be fine, but I recommend enable_cupti=True and dry_run_iters=10 and repeat_iters=30 for RoPe. CUPTI for more accurate kernel time measurements. Iteration count for faster benchmarking.
See bench_gpu_time definition
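For context, the dry-run/repeat pattern being suggested looks roughly like this CPU-side sketch (the real `bench_gpu_time` times with CUDA events or CUPTI; the parameter names here follow the comment above, and the helper name is a stand-in):

```python
import statistics
import time

def bench_time(fn, dry_run_iters=10, repeat_iters=30):
    """Median-of-repeats timing with a warm-up phase (CPU-side stand-in)."""
    for _ in range(dry_run_iters):  # warm-up: JIT compilation, caches, clocks
        fn()
    samples_ms = []
    for _ in range(repeat_iters):
        t0 = time.perf_counter()
        fn()
        samples_ms.append((time.perf_counter() - t0) * 1e3)
    # Median is robust to the occasional slow outlier iteration.
    return statistics.median(samples_ms)
```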
There was a problem hiding this comment.
The logical organization of new files in rope/kernels/ seems a bit strange. It appears that
compile.pycould be absorbed intowrappers.pyptx_ops.pycould probably be absorbed intohelpers.pyand then renamed asutils.py
These files are not too long per se, so I don't see the need to split out into multiple files.
```python
        Index type for indptr/offsets: "int32" or "int64". Default "int32".
    """
    dtype = get_cutlass_dtype(dtype_str)
    idtype = get_cutlass_idtype(idtype_str)
```
There was a problem hiding this comment.
If you take a very close look at the lifetime of these dtypes, you go from torch.dtype to str to cutlass data types. The intermediate string conversion is an unnecessary step that likely adds a small bit to overhead.
The overhead is generally small and could be 1 us, but RoPe kernels run fast so it might make a small difference.
[FAILED] Pipeline #43935806: 12/20 passed |
📌 Description
This PR refactors the RoPE API in general, and also adds CuTe-DSL backend option for the unfused RoPE APIs (flashinfer.rope, apply_rope, apply_rope_inplace, apply_rope_pos_ids, apply_rope_pos_ids_inplace, apply_llama31_rope, apply_llama31_rope_inplace, apply_llama31_rope_pos_ids, apply_llama31_rope_pos_ids_inplace, apply_rope_with_cos_sin_cache, apply_rope_with_cos_sin_cache_inplace)
Refactored structure:
🚗 Performance Analysis
Overall summary: good perf improvements for decode cases (especially llama31 related functions). Some problem sizes are a bit slower; can be investigated later (the default backend is still cuda for now).
Example benchmarking summary for B200 (benchmarking matrix averaged for MHA, MLA, GQA configurations, decode batch 1 - 2048, prefill seq 1k - 32k):
B300 benchmarking
B200 benchmarking
H100 benchmarking
A100 benchmarking
For source script, see benchmarks/bench_rope_workloads.py.
✅ Testing
pytest -x tests/attention/test_rope.py
🔍 Related Issues
#2491
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed pre-commit by running `pip install pre-commit` (or used your preferred method).
- Installed the hooks with `pre-commit install`.
- Ran `pre-commit run --all-files` and fixed any reported issues.
🧪 Tests
- Tests added/updated and passing (unittest, etc.).
Reviewer Notes
Summary by CodeRabbit
- New `backend` parameter ("cuda" or "cute-dsl") across all RoPE functions.
- Refactored `flashinfer.rope` module with comprehensive RoPE API consolidation.