Conversation
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough

Adds an SM90-targeted Gated Delta Rule (GDN) prefill: new CUDA kernels and launchers, Flat/CUTE/TMA primitives and helpers, Python JIT integration and runtime bindings, a benchmark driver, and tests with Python reference implementations.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User as Python Caller
    participant PyModule as flashinfer.gdn_prefill
    participant JIT as JIT / nvcc
    participant Launcher as gdn_prefill_launcher (host)
    participant Kernel as SM90 Delta-rule Kernel
    participant TMA as GPU TMA/SMEM
    User->>PyModule: chunk_gated_delta_rule(q,k,v,...)
    PyModule->>JIT: ensure SM90 module loaded (gen_gdn_prefill_sm90_module)
    PyModule->>Launcher: call gdn_prefill_launcher with pointers & flags
    Launcher->>Kernel: select template variant and launch kernel
    Kernel->>TMA: TMA loads/stores, MMAs, barriers (compute/store)
    Kernel-->>Launcher: kernel completes (device memory populated)
    Launcher-->>PyModule: return control
    PyModule-->>User: return output (and state if requested)
```
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Summary of Changes

Hello @guangyunh-nv, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request integrates the Gated Delta Rule (GDN) attention mechanism, crucial for modern large language models like Qwen-next. It delivers a high-performance CUDA kernel optimized for Hopper GPUs, exposed through a user-friendly Python interface. This enhancement significantly expands the library's capabilities for advanced attention architectures, ensuring both efficiency and accuracy through rigorous testing and benchmarking.
Code Review
This pull request introduces a new Gated Delta Rule (GDN) prefill kernel for FlashInfer, targeting the NVIDIA Hopper (SM90) architecture. The changes include a Python benchmark script (bench_gdn_prefill.py) to measure performance metrics like TFLOPS and memory bandwidth for the GDN kernel, along with new C++/CUDA files that implement the core GDN logic.

The C++ implementation leverages the CUTLASS and cute libraries for efficient GPU computation, including collective loads for the Q, K, V, alpha, and beta tensors using TMA, a blockwise matrix inverse for state updates, and collective stores for the output. The kernel supports both Grouped Query Attention (GQA) and Grouped Value Attention (GVA) configurations, handles initial state loading, and applies alpha/beta gating. A Python API (flashinfer/gdn_prefill.py) is added to expose the chunk_gated_delta_rule function, which orchestrates the kernel launch and manages tensor allocations. Unit tests (tests/gdn/) and reference implementations are also included to verify correctness, with specific tests for basic and non-full block scenarios, as well as chunked prefill.

Review comments highlighted a critical bug in the CHECK macro's stringification and argument passing, dead code in bench_gdn_prefill.py related to peak bandwidth calculation, and an unused use_qk_l2norm_in_kernel parameter in the chunk_gated_delta_rule Python API.
```cpp
#define CHECK(expr, msg)                                                                           \
  do {                                                                                             \
    if (!(expr)) {                                                                                 \
      std::string buffer(1024, '\0');                                                              \
      sprintf(buffer.data(), "Failed to check %s, %s at %s:%d\n", ##expr, msg __FILE__, __LINE__); \
      throw std::runtime_error(buffer.c_str());                                                    \
    }                                                                                              \
  } while (0)
```
The CHECK macro is implemented incorrectly. The use of ##expr is not standard for stringification and will likely cause a compilation error. It should be #expr. Additionally, the sprintf call has a mismatch between the number of format specifiers and the number of arguments provided, which leads to undefined behavior. The msg and __FILE__ arguments are also not passed correctly. This is a critical issue in a utility macro that could hide bugs or cause crashes.
Suggested change:

```diff
 #define CHECK(expr, msg)                                                                           \
   do {                                                                                             \
     if (!(expr)) {                                                                                 \
       std::string buffer(1024, '\0');                                                              \
-      sprintf(buffer.data(), "Failed to check %s, %s at %s:%d\n", ##expr, msg __FILE__, __LINE__); \
+      sprintf(buffer.data(), "Failed to check %s, %s at %s:%d\n", #expr, msg, __FILE__, __LINE__); \
       throw std::runtime_error(buffer.c_str());                                                    \
     }                                                                                              \
   } while (0)
```
```python
# Get device info for bandwidth calculation
props = torch.cuda.get_device_properties(0)
props.total_memory * 2 / 1e12  # Approximate peak bandwidth
```
```python
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.Tensor] = None,
    use_qk_l2norm_in_kernel: bool = False,
```
The parameter use_qk_l2norm_in_kernel is defined in the function signature and its docstring, but it is not used within the function body. The underlying custom op gdn_prefill is not called with this parameter. This can be misleading for users of the API. This parameter should either be implemented or removed to avoid confusion. The benchmark and test files perform L2 normalization on k in Python before calling this function, which suggests this parameter is indeed unused.
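As the comment notes, callers currently normalize `k` in Python themselves before invoking the API. A minimal, torch-free sketch of that per-head L2 normalization (the function name and `eps` value are illustrative assumptions, not the project's API):

```python
import math

def l2_normalize(vec, eps=1e-6):
    # Divide each component by the vector's L2 norm; eps guards against
    # division by zero for all-zero vectors.
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / (norm + eps) for x in vec]

# One key head of dimension 2, for illustration.
k_head = [3.0, 4.0]
k_unit = l2_normalize(k_head)  # approximately unit L2 norm
```

In practice, the benchmark and tests apply the same operation along the head dimension of the `k` tensor before calling `chunk_gated_delta_rule`; implementing `use_qk_l2norm_in_kernel` would move this step into the kernel.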
Actionable comments posted: 13
♻️ Duplicate comments (3)
csrc/flat/common.hpp (1)
11-18: Fix the CHECK macro: stringification and missing comma. This was previously flagged. The macro has two issues:

- `##expr` should be `#expr` for proper stringification
- Missing comma between `msg` and `__FILE__`

🔎 Proposed fix

```diff
 #define CHECK(expr, msg)                                                                           \
   do {                                                                                             \
     if (!(expr)) {                                                                                 \
       std::string buffer(1024, '\0');                                                              \
-      sprintf(buffer.data(), "Failed to check %s, %s at %s:%d\n", ##expr, msg __FILE__, __LINE__); \
+      sprintf(buffer.data(), "Failed to check %s, %s at %s:%d\n", #expr, msg, __FILE__, __LINE__); \
       throw std::runtime_error(buffer.c_str());                                                    \
     }                                                                                              \
   } while (0)
```

flashinfer/gdn_prefill.py (1)
89-89: Unused parameter `use_qk_l2norm_in_kernel`. This parameter is documented but not passed to the kernel. This was already flagged in a previous review.
benchmarks/bench_gdn_prefill.py (1)
199-201: Remove dead code. This expression computes a value but doesn't assign or use it. This was already flagged in a previous review.
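For what it's worth, `total_memory` is a capacity, not a bandwidth, so the dropped expression would not have produced a meaningful number anyway. If the benchmark later wants a theoretical peak, the usual formula multiplies the effective memory clock by the bus width; a hedged sketch (the function name and the DDR factor of 2 are assumptions, and real clock/bus-width values must come from the device):

```python
def peak_memory_bandwidth_gbs(mem_clock_mhz: float, bus_width_bits: int,
                              ddr_factor: int = 2) -> float:
    # bytes/s = clock (Hz) * bus width (bytes) * transfers per clock,
    # reported in GB/s (1e9 bytes per second).
    return mem_clock_mhz * 1e6 * (bus_width_bits / 8) * ddr_factor / 1e9
```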
🧹 Nitpick comments (21)
csrc/flat/math.hpp (1)
9-19: Consider documenting edge-case behavior for `n = 0`. The current implementation returns `1` for `next_power_of_two(0)` since `ceil_log2(0)` returns `0`. This may be intentional, but could be unexpected. Consider adding a brief comment or a `static_assert` if `n > 0` is a precondition.

csrc/flat/cute_ext.hpp (1)
8-8: Avoid `using namespace` in header files. `using namespace cute;` in a header pollutes the namespace of all translation units that include this header, potentially causing name collisions. Consider using explicit `cute::` prefixes or limiting the directive to a narrower scope.

csrc/flat/prefill/prefill_kernel_delta_rule_sm90.cuh (1)
119-133: Include status details in error messages for easier debugging. The generic error messages ("can_implement failed", "initialize failed", "run failed") don't provide insight into why the operation failed. Consider including the `cutlass::Status` value or its string representation.

🔎 Proposed improvement

```diff
   status = op.can_implement(arguments);
   if (status != cutlass::Status::kSuccess) {
-    throw std::runtime_error("can_implement failed");
+    throw std::runtime_error(std::string("can_implement failed: ") +
+                             cutlass::cutlassGetStatusString(status));
   }
   status = op.initialize(arguments, workspace.get(), stream);
   if (status != cutlass::Status::kSuccess) {
-    throw std::runtime_error("initialize failed");
+    throw std::runtime_error(std::string("initialize failed: ") +
+                             cutlass::cutlassGetStatusString(status));
   }
   status = op.run(stream);
   if (status != cutlass::Status::kSuccess) {
-    throw std::runtime_error("run failed");
+    throw std::runtime_error(std::string("run failed: ") +
+                             cutlass::cutlassGetStatusString(status));
   }
```

csrc/flat/ampere/collective/flat_collective_load.hpp (1)
10-10: Avoid `using namespace` in header files. Similar to `cute_ext.hpp`, this pollutes the namespace for all includers. Consider using explicit `cute::` prefixes.
csrc/flat/ampere/collective/flat_collective_inverse.hpp (1)

10-10: Avoid `using namespace` in header files. Consistent with other headers in this PR, this introduces namespace pollution for all includers.
flashinfer/jit/gdn.py (1)
25-37: Missing `supported_major_versions` to restrict compilation to SM90+. Per coding guidelines, JitSpec should specify `supported_major_versions` to restrict kernel compilation to supported GPU architectures. Since this is an SM90-specific module, it should explicitly declare this constraint.

🔎 Proposed fix

```diff
 def gen_gdn_prefill_sm90_module() -> JitSpec:
     return gen_jit_spec(
         name="gdn_prefill_launcher",
         sources=[
             jit_env.FLASHINFER_CSRC_DIR / "gdn_prefill_launcher.cu",
             jit_env.FLASHINFER_CSRC_DIR / "flat" / "prefill" / "prefill_kernel_delta_rule_sm90.cu",
         ],
-        extra_cuda_cflags=sm90a_nvcc_flags + ["-DFLAT_SM90A_ENABLED", "-std=c++20"],
+        extra_cuda_cflags=[*sm90a_nvcc_flags, "-DFLAT_SM90A_ENABLED", "-std=c++20"],
         extra_include_paths=[jit_env.FLASHINFER_CSRC_DIR],
+        supported_major_versions=[9, 10, 11, 12],
     )
```

Based on learnings, `supported_major_versions` should be specified in JitSpec for SM-specific kernels.

tests/gdn/conftest.py (1)
34-35: Consider removing commented-out code. The commented-out `multidist_randn` line appears to be a debug/development artifact. Consider removing it to keep the codebase clean.

csrc/flat/math_order_barrier.hpp (1)
31-34: Track the FIXME for persistent scheduler support. The destructor contains a FIXME indicating that the current implementation will have issues with a persistent scheduler. This should be tracked for future work.

Would you like me to open an issue to track the persistent scheduler compatibility work for `OrderedNamedBarriers`?

csrc/flat/debug.hpp (1)
22-37: `DPRINTF_W` and `DPRINTF_WG` have identical implementations. Both macros print identical output with the `[WG%d][W%d][T%-3d]` prefix and the same condition `IS_PRINT_BLOCK`. This appears to be unintentional duplication. If they're meant to serve different purposes (e.g., warp-level vs warp-group-level filtering), consider differentiating them.

Suggested differentiation: if `DPRINTF_WG` is intended to be the warp-group variant, perhaps filter to only print from warp-group leaders:

```diff
 #define DPRINTF_WG(fmt, ...)                                                                     \
-  if (IS_PRINT_BLOCK)                                                                            \
+  if (IS_PRINT_BLOCK && threadIdx.x % 128 < 32)                                                  \
     printf("%s:%d [WG%d][W%d][T%-3d] " fmt, __FILE__, __LINE__, threadIdx.x / 128,               \
            threadIdx.x / 32, threadIdx.x, ##__VA_ARGS__)
```

flashinfer/gdn_prefill.py (1)
176-182: Consider simplifying the output_state allocation logic. Lines 170-182 allocate `output_state` whether or not `output_final_state` is True. The conditional on line 176 (`elif not output_final_state and output_state is None`) could be merged with the first condition since both branches do the same allocation.

Simplified allocation:

```diff
-    # Allocate output_state if needed
-    if output_final_state and output_state is None:
-        output_state = torch.empty(
-            (num_seqs, num_sab_heads, head_size, head_size),
-            dtype=torch.float32,
-            device=q.device,
-        )
-    elif not output_final_state and output_state is None:
-        # Still need to allocate since kernel always writes state
+    # Allocate output_state if not provided (kernel always writes state)
+    if output_state is None:
         output_state = torch.empty(
             (num_seqs, num_sab_heads, head_size, head_size),
             dtype=torch.float32,
             device=q.device,
         )
```

tests/gdn/test_prefill_delta_rule.py (1)
23-28: Unused `block_size` parameter. The `block_size` parameter is defined but never used in `_test_prefill_kernel`. If it's intended for future use, consider adding a TODO comment. Otherwise, remove it.

benchmarks/bench_gdn_prefill.py (1)
25-56: Consider removing unused parameters from the `gdn_flops` signature. `num_k_heads` and `num_seqs` are not used in the FLOPs calculation. While keeping them for API consistency with `gdn_bytes` is reasonable, documenting why they're unused would be helpful.

Option 1: Add an underscore prefix to indicate intentionally unused parameters:

```diff
 def gdn_flops(
     total_seq_len: int,
     num_q_heads: int,
-    num_k_heads: int,
+    _num_k_heads: int,  # unused, kept for API consistency
     num_v_heads: int,
     head_size: int,
-    num_seqs: int,
+    _num_seqs: int,  # unused, state ops not counted in TFLOPS
 ) -> int:
```

csrc/gdn_prefill_launcher.cu (1)
42-68: Device properties are queried redundantly. `cudaGetDeviceProperties` is called both in `gdn_prefill` (lines 157-161) and again inside `gdn_prefill_launcher` (lines 43-46). Since `sm_count` is already passed as a parameter, consider also passing the device major version to avoid the redundant CUDA API call.

🔎 Suggested refactor: pass the device major version as a parameter to avoid querying device properties twice:

```diff
 void gdn_prefill_launcher(void* output, void* output_state, void* q, void* k, void* v,
                           void* input_state, void* alpha, void* beta, int64_t* cu_seqlens,
                           int64_t num_seqs, int64_t num_q_heads, int64_t num_k_heads,
                           int64_t num_v_heads, int64_t num_o_heads, int64_t head_size,
-                          int64_t packed_seq, float scale, int64_t sm_count, DLDataType dtype,
-                          cudaStream_t stream) {
+                          int64_t packed_seq, float scale, int64_t sm_count, int device_major,
+                          DLDataType dtype, cudaStream_t stream) {
   DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(dtype, DType, [&] {
-    int dev_id;
-    cudaGetDevice(&dev_id);
-    cudaDeviceProp device_properties;
-    cudaGetDeviceProperties(&device_properties, dev_id);
-
 #if defined(FLAT_SM90A_ENABLED)
-    if (device_properties.major == 9) {
+    if (device_major == 9) {
```

csrc/flat/hopper/collective/flat_common.hpp (1)
94-146: Complex byte reordering logic for 1-byte elements. The shuffle-based byte reordering (lines 99-143) implements a specific permutation pattern for converting accumulator layout to operand layout for 1-byte elements. The magic constants (`0x3021`, `0x2130`, `0x1054`, `0x5410`, `0x3276`, `0x7632`) encode the permutation.

This is intricate low-level code. Consider adding a comment explaining the mathematical transformation or referencing documentation for the pattern.
tests/gdn/reference_delta_rule.py (2)
110-111: `@torch.inference_mode` decorator should include parentheses. The decorator `@torch.inference_mode` on lines 110 and 331 is missing parentheses. While this works in recent PyTorch versions, the canonical form is `@torch.inference_mode()`.

🔎 Proposed fix

```diff
-@torch.inference_mode
+@torch.inference_mode()
 def blockwise_linear_attention(
```

The same fix applies at line 331 (`blockwise_delta_rule`).
138-138: Debug artifact: FIXME comment about the `KVs` list. Line 138 contains `KVs = []  # FIXME: kernel debug only`. This debug artifact is collected (line 222) but only used in the return value. Consider removing it if not needed for production testing.

csrc/flat/hopper/kernel/flat_options.hpp (1)
57-60: Unused function parameters. The `new_option` and `options_tuple` parameters are unused. Since this is a compile-time operation, only the types matter, but the parameters create unnecessary copies and may trigger compiler warnings.

🔎 Proposed fix

```diff
 template <auto kTag, typename Value, typename... Options>
-constexpr auto add_option(Option<kTag, Value> new_option, std::tuple<Options...> options_tuple) {
+constexpr auto add_option(Option<kTag, Value> /*new_option*/,
+                          std::tuple<Options...> /*options_tuple*/) {
   return add_option_t<kTag, Value, Options...>();
 }
```

csrc/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp (4)
20-23: Clarify the intent of this workaround macro. The condition `thread_idx > 8192` appears to be intentionally unreachable (typical warp groups have far fewer threads). If this is a no-op workaround to prevent compiler optimizations that cause performance issues, consider adding a brief comment explaining the mechanism and referencing any related bug/issue.
123-124: Remove extraneous semicolon. There's a stray semicolon after the `DummyStages` type alias.

🔎 Proposed fix

```diff
 using DummyStages = cutlass::gemm::collective::StageCount<2>;
-;
```
326-339: Remove commented-out code. The `BetaProcessor` struct is entirely commented out. If it's no longer needed, remove it. If it's planned for future use, consider tracking it in an issue rather than leaving dead code.
1217-1221: Simplify the `valid_seq_len` logic. The ternary condition can be simplified using `min`.

🔎 Proposed fix

```diff
 template <typename WorkDesc>
 CUTE_DEVICE int valid_seq_len(WorkDesc work_desc, int blk_idx) {
-  int remain_len = work_desc.seq_len - BlkSeqKV * blk_idx;
-  return remain_len <= BlkSeqKV ? remain_len : BlkSeqKV;
+  return min(work_desc.seq_len - BlkSeqKV * blk_idx, BlkSeqKV);
 }
```
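The suggested `min` form is easy to sanity-check in plain Python; a tiny model of the clamping (names mirror the C++ but this is an illustrative sketch, not the kernel code):

```python
def valid_seq_len(seq_len: int, blk_seq_kv: int, blk_idx: int) -> int:
    # Tokens remaining at this block index, clamped to one block's worth.
    return min(seq_len - blk_seq_kv * blk_idx, blk_seq_kv)
```

Both the original ternary and the `min` form agree on full blocks and on the partial tail block.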
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (30)
- benchmarks/bench_gdn_prefill.py
- csrc/flat/ampere/collective/flat_collective_inverse.hpp
- csrc/flat/ampere/collective/flat_collective_load.hpp
- csrc/flat/common.hpp
- csrc/flat/cute_ext.hpp
- csrc/flat/debug.hpp
- csrc/flat/hopper/collective/flat_collective_load.hpp
- csrc/flat/hopper/collective/flat_collective_store.hpp
- csrc/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp
- csrc/flat/hopper/collective/flat_common.hpp
- csrc/flat/hopper/collective/flat_named_barriers.hpp
- csrc/flat/hopper/device/device_universal.hpp
- csrc/flat/hopper/kernel/flat_kernel_builder_delta_rule.hpp
- csrc/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp
- csrc/flat/hopper/kernel/flat_options.hpp
- csrc/flat/hopper/kernel/flat_tile_scheduler.hpp
- csrc/flat/math.hpp
- csrc/flat/math_order_barrier.hpp
- csrc/flat/prefill/prefill_kernel_delta_rule_sm90.cu
- csrc/flat/prefill/prefill_kernel_delta_rule_sm90.cuh
- csrc/flat/type_traits.hpp
- csrc/flat/unused.hpp
- csrc/gdn_prefill_launcher.cu
- flashinfer/__init__.py
- flashinfer/aot.py
- flashinfer/gdn_prefill.py
- flashinfer/jit/gdn.py
- tests/gdn/conftest.py
- tests/gdn/reference_delta_rule.py
- tests/gdn/test_prefill_delta_rule.py
🧰 Additional context used
📓 Path-based instructions (6)
tests/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
`tests/**/*.py`: Test implementations should use `flashinfer.utils` functions (`get_compute_capability`, `is_sm90a_supported`, `is_sm100a_supported`, etc.) to skip tests on unsupported GPU architectures.

For testing with `mpirun` on multi-GPU systems, use the pattern: `mpirun -np <num_gpus> pytest tests/path/to/test.py::test_function`.

Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes; `tests/conftest.py` provides auto-skipping for OOM tests as a safety net but should not be relied upon.
Files:
- tests/gdn/conftest.py
- tests/gdn/reference_delta_rule.py
- tests/gdn/test_prefill_delta_rule.py
flashinfer/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
`flashinfer/**/*.py`: Use the `@functools.cache` decorator on Python API functions to implement module-level caching and avoid recompilation.

Use the `@flashinfer_api` decorator for debugging API calls; enable via the `FLASHINFER_LOGLEVEL` environment variable (0=off, 1=basic, 3=detailed, 5=with stats).
Files:
- flashinfer/jit/gdn.py
- flashinfer/__init__.py
- flashinfer/aot.py
- flashinfer/gdn_prefill.py
flashinfer/jit/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
`flashinfer/jit/**/*.py`: JIT module generators in `flashinfer/jit/` must follow the pattern: compute URI → create directory → (optional) render Jinja template → copy sources → return JitSpec.

Use the `gen_jit_spec()` function to return a properly configured JitSpec from module generators with appropriate `sources` and `extra_cuda_cflags`.

Specify `supported_major_versions` in JitSpec to restrict kernel compilation to supported GPU architectures (e.g., SM versions 9, 10, 11, 12 for Hopper/newer).
Files:
flashinfer/jit/gdn.py
flashinfer/__init__.py
📄 CodeRabbit inference engine (CLAUDE.md)
Export new operations in `flashinfer/__init__.py` to make them available as public API.
Files:
flashinfer/__init__.py
flashinfer/aot.py
📄 CodeRabbit inference engine (CLAUDE.md)
Register new operations in `flashinfer/aot.py` by calling the `gen_*_module()` function for AOT (Ahead-Of-Time) pre-compilation support.
Files:
flashinfer/aot.py
csrc/**/*.cu
📄 CodeRabbit inference engine (CLAUDE.md)
Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers.
Files:
- csrc/gdn_prefill_launcher.cu
- csrc/flat/prefill/prefill_kernel_delta_rule_sm90.cu
🧠 Learnings (11)
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0

The eleven learnings below share the provenance above; each is listed with the files it was applied to.

- Applies to `tests/**/*.py`: Test implementations should use `flashinfer.utils` functions (`get_compute_capability`, `is_sm90a_supported`, `is_sm100a_supported`, etc.) to skip tests on unsupported GPU architectures. Applied to: tests/gdn/conftest.py
- Applies to `include/**/*.cuh`: Torch headers MUST NOT be included in files within the `include/` directory; keep framework-agnostic CUDA kernels that accept raw pointers. Applied to: csrc/flat/common.hpp
- Applies to `include/**/*.cuh`: For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers. Applied to: csrc/flat/common.hpp
- Applies to `include/**/*.cuh`: Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly. Applied to: csrc/flat/common.hpp, csrc/flat/prefill/prefill_kernel_delta_rule_sm90.cuh, csrc/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp
- Applies to `flashinfer/jit/**/*.py`: Use `gen_jit_spec()` to return a properly configured JitSpec from module generators with appropriate `sources` and `extra_cuda_cflags`. Applied to: flashinfer/jit/gdn.py, flashinfer/aot.py
- Applies to `flashinfer/jit/**/*.py`: JIT module generators in `flashinfer/jit/` must follow the pattern: compute URI → create directory → (optional) render Jinja template → copy sources → return JitSpec. Applied to: flashinfer/jit/gdn.py, flashinfer/aot.py
- Applies to `flashinfer/aot.py`: Register new operations in `flashinfer/aot.py` by calling the `gen_*_module()` function for AOT (Ahead-Of-Time) pre-compilation support. Applied to: flashinfer/jit/gdn.py, flashinfer/aot.py, flashinfer/gdn_prefill.py
- Applies to `flashinfer/jit/**/*.py`: Specify `supported_major_versions` in JitSpec to restrict kernel compilation to supported GPU architectures (e.g., SM versions 9, 10, 11, 12 for Hopper/newer). Applied to: flashinfer/jit/gdn.py, flashinfer/aot.py
- Applies to `flashinfer/__init__.py`: Export new operations in `flashinfer/__init__.py` to make them available as public API. Applied to: flashinfer/__init__.py, flashinfer/aot.py, flashinfer/gdn_prefill.py
- Applies to `include/**/*.cuh`: Kernel code in `include/flashinfer/` is automatically picked up by JIT compilation on changes; no pip reinstall needed. Applied to: flashinfer/aot.py
- Use the `FLASHINFER_CUDA_ARCH_LIST` environment variable to specify target GPU architectures (e.g., '8.0 9.0a') and `FLASHINFER_NVCC_THREADS` to control parallel compilation threads. Applied to: csrc/gdn_prefill_launcher.cu
🧬 Code graph analysis (12)
tests/gdn/conftest.py (1)
flashinfer/utils.py (1)
is_sm90a_supported(531-533)
csrc/flat/cute_ext.hpp (1)
csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h (1)
layout(29-47)
csrc/flat/math_order_barrier.hpp (1)
csrc/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp (8)
void(279-321)void(495-500)void(503-526)void(529-547)void(550-568)void(572-589)void(592-985)void(988-1215)
flashinfer/aot.py (1)
flashinfer/jit/gdn.py (1)
gen_gdn_prefill_sm90_module(25-37)
csrc/flat/ampere/collective/flat_collective_load.hpp (1)
csrc/flat/hopper/collective/flat_collective_load.hpp (2)
char(17-27)to_string(17-17)
flashinfer/gdn_prefill.py (3)
flashinfer/api_logging.py (1)
flashinfer_api(464-565)flashinfer/jit/gdn.py (1)
gen_gdn_prefill_sm90_module(25-37)csrc/gdn_prefill_launcher.cu (2)
gdn_prefill(71-170)gdn_prefill(71-73)
csrc/flat/hopper/collective/flat_collective_store.hpp (2)
csrc/flat/cute_ext.hpp (2)
alignment_for_swizzle(33-35)alignment_for_swizzle(33-33)csrc/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp (2)
args(484-486)args(484-484)
csrc/flat/hopper/device/device_universal.hpp (3)
csrc/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp (8)
args(484-486)args(484-484)can_implement(411-428)can_implement(411-411)initialize_workspace(489-493)initialize_workspace(489-491)to_underlying_arguments(431-482)to_underlying_arguments(431-432)csrc/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp (8)
args(174-176)args(174-174)args(178-182)args(178-179)args(184-186)args(184-184)args(197-203)args(197-197)csrc/flat/hopper/collective/flat_collective_store.hpp (4)
initialize_workspace(129-133)initialize_workspace(129-131)to_underlying_arguments(104-119)to_underlying_arguments(104-105)
csrc/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp (3)
csrc/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp (15)
args(484-486)args(484-484)initialize_workspace(489-493)initialize_workspace(489-491)params(495-495)to_underlying_arguments(431-482)to_underlying_arguments(431-432)void(279-321)void(495-500)void(503-526)void(529-547)void(550-568)void(572-589)void(592-985)void(988-1215)csrc/flat/hopper/collective/flat_collective_store.hpp (8)
initialize_workspace(129-133)initialize_workspace(129-131)params(135-135)to_underlying_arguments(104-119)to_underlying_arguments(104-105)void(135-137)void(180-216)void(219-242)csrc/flat/hopper/kernel/flat_tile_scheduler.hpp (6)
params(69-70)params(70-70)params(98-98)params(98-98)to_underlying_arguments(73-96)to_underlying_arguments(73-76)
csrc/flat/ampere/collective/flat_collective_inverse.hpp (2)
csrc/flat/hopper/collective/flat_common.hpp (6)
auto(41-51)auto(55-66)auto(69-71)auto(74-78)auto(81-83)convert_c_layout_to_a_layout(74-74)csrc/flat/cute_ext.hpp (2)
select_tensor(20-30)select_tensor(20-20)
csrc/flat/hopper/kernel/flat_tile_scheduler.hpp (3)
csrc/flat/hopper/collective/flat_collective_store.hpp (2)
bool(165-177)params(135-135)csrc/flat/hopper/device/device_universal.hpp (4)
params(57-57)params(57-57)params(145-176)params(145-145)csrc/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp (3)
params(188-190)params(188-188)params(205-205)
tests/gdn/test_prefill_delta_rule.py (3)
tests/gdn/reference_delta_rule.py (2)
exclusive_cumsum(5-9)blockwise_delta_rule(332-478)flashinfer/gdn_prefill.py (2)
gdn_prefill(35-58)chunk_gated_delta_rule(79-200)tests/gdn/conftest.py (1)
qkv_factory(50-51)
🪛 Ruff (0.14.10)
flashinfer/jit/gdn.py
35-35: Consider iterable unpacking instead of concatenation
Replace with iterable unpacking
(RUF005)
flashinfer/gdn_prefill.py
89-89: Unused function argument: use_qk_l2norm_in_kernel
(ARG001)
benchmarks/bench_gdn_prefill.py
28-28: Unused function argument: num_k_heads
(ARG001)
31-31: Unused function argument: num_seqs
(ARG001)
tests/gdn/test_prefill_delta_rule.py
23-23: Unused function argument: block_size
(ARG001)
211-211: Unused function argument: block_size
(ARG001)
🔇 Additional comments (42)
csrc/flat/unused.hpp (1)
7-19: LGTM! Clean placeholder type implementation. The variadic constructor and assignment operator correctly discard all inputs, making this suitable as a default template parameter for optional components.
csrc/flat/type_traits.hpp (1)
9-35: LGTM! The type traits are well-designed. The `map_to_cutlass` specializations correctly map CUDA native types to their cutlass equivalents, and `first_non_void` properly handles the recursive case with a clear static assertion for the all-void error case.

csrc/flat/ampere/collective/flat_collective_load.hpp (1)
27-142: LGTM - Well-structured collective load implementation. The `CollectiveLoadVector` class correctly implements pipelined loads with proper:
- Pipeline acquire/commit semantics (lines 113, 132)
- Memory fence barriers (lines 114, 131)
- Tail handling with masking (lines 102-108)
- Optional vector processor support (lines 123-129)
csrc/flat/ampere/collective/flat_collective_inverse.hpp (1)
59-63: Precision limitation noted; verify it is acceptable for the intended use cases. The FIXME comment indicates precision issues with half-precision. The static_assert restricts the implementation to `half` types only. Ensure this limitation is acceptable for the GDN attention use case, or consider adding `float` support if higher precision is needed.

csrc/flat/hopper/device/device_universal.hpp (2)
56-92: LGTM for workspace and occupancy query logic. The `get_workspace_size`, `get_grid_shape`, and `maximum_active_blocks` implementations correctly delegate to the Kernel and handle dynamic shared memory configuration appropriately.
145-176: LGTM for the static run() entry point. The cluster launch path for SM90+ and fallback kernel launch for older architectures is well-structured. Error handling properly checks both CUDA errors and launch status.
flashinfer/__init__.py (1)
87-87: LGTM! Public API export follows project conventions. The new `chunk_gated_delta_rule` is correctly exported following the established `from .module import symbol as symbol` pattern. Based on learnings, this correctly makes the operation available as public API.
flashinfer/aot.py (2)
44-44: LGTM! Import follows project conventions. The import of `gen_gdn_prefill_sm90_module` is correctly placed with other JIT module imports.
537-539: LGTM! AOT registration correctly guarded by SM90 capability. The GDN prefill module is appropriately registered under the `add_misc` branch with the `has_sm90` check, following the established pattern for SM90-specific modules. Based on learnings, this correctly registers the operation for AOT pre-compilation support.
csrc/flat/hopper/collective/flat_named_barriers.hpp (1)
5-12: LGTM! Clean barrier ID catalog design. The protected `NumBarriersUsed = 4` allows derived classes (like `DeltaRuleNamedBarriers` mentioned in the summary) to safely extend barrier IDs.
Minor observation: index 3 appears unused (indices 0, 1, 2 are defined but `NumBarriersUsed = 4`). If intentional for future use or alignment, a brief comment could clarify this.
20-29: LGTM! CUDA_CHECK macro is correctly implemented. The error handling correctly captures the CUDA error name, code, file, and line number.
tests/gdn/conftest.py (2)
9-13: LGTM! Proper SM90a architecture guard. The autouse fixture correctly uses `is_sm90a_supported` from `flashinfer.utils` to skip tests on unsupported GPUs. Based on learnings, this follows the recommended pattern for GPU architecture-specific tests.
31-46: Tensors are generated on GPU via device context, not on CPU. The `gen_qkv` function is called within a `with torch.device("cuda"):` context (lines 46-48 and 237-238 in test_prefill_delta_rule.py). PyTorch's device context manager sets the default device for all tensor operations within that scope, including the `torch.distributions.Normal.sample()` and `torch.distributions.Uniform.sample()` calls used by `multidist_randu`. Tensors are correctly placed on GPU through implicit device context propagation, not left on CPU. No fix is needed.
Likely an incorrect or invalid review comment.
csrc/flat/math_order_barrier.hpp (2)
1-18: LGTM! Well-structured barrier coordination template. The `OrderedNamedBarriers` template provides a clean abstraction for multi-warp-group synchronization using named barriers. The conditional type alias for `NBId_t` correctly handles both reserved and plain barrier IDs.
36-78: Barrier synchronization logic is correct. The `ordered_or_wait` and `notify_next_blocked` methods correctly implement an ordered barrier pattern where:
- Each WG waits on its assigned barrier
- After completing work, each WG notifies all other WGs' barriers to advance the synchronization state
The detailed comments explaining the barrier state transitions are helpful for understanding the algorithm.
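For intuition, the wait-then-notify protocol described above can be emulated off-device. The following is a hedged host-side sketch using a Python condition variable; `OrderedBarrier` is an illustrative stand-in, not the CUDA named-barrier implementation:

```python
import threading

class OrderedBarrier:
    """CPU analog of the ordered-barrier pattern: each worker group
    waits for its turn, runs its step, then advances the turn counter."""

    def __init__(self, num_groups: int):
        self.num_groups = num_groups
        self.turn = 0
        self.cond = threading.Condition()

    def wait_turn(self, group_id: int) -> None:
        with self.cond:
            self.cond.wait_for(lambda: self.turn == group_id)

    def notify_next(self) -> None:
        with self.cond:
            self.turn = (self.turn + 1) % self.num_groups
            self.cond.notify_all()

order = []
bar = OrderedBarrier(3)

def worker(gid: int) -> None:
    bar.wait_turn(gid)
    order.append(gid)  # the "critical" ordered step
    bar.notify_next()

# start threads deliberately out of order; execution is still serialized 0, 1, 2
threads = [threading.Thread(target=worker, args=(g,)) for g in (2, 0, 1)]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert order == [0, 1, 2]
```

The GPU version replaces the condition variable with hardware named barriers, but the state machine (wait on your slot, then advance everyone else) is the same.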
csrc/flat/debug.hpp (1)
47-60: LGTM! TMA descriptor dump is useful for debugging. The `DPRINT_TMA_DESC` macro provides a clean way to dump 128 bytes (32 unsigned ints) of TMA descriptor data in a formatted hex layout, which is valuable for debugging TMA-related issues.
flashinfer/gdn_prefill.py (1)
30-75: LGTM! Module caching and op registration follow best practices. The `@functools.cache` decorator ensures the module is built once, and the custom op registration pattern is consistent with the codebase. Based on coding guidelines, this correctly implements module-level caching.
csrc/flat/hopper/kernel/flat_kernel_builder_delta_rule.hpp (2)
27-35: Persistent scheduler support is blocked with static_assert. The `static_assert(!kIsPersistent, "not implemented")` correctly guards against using unimplemented functionality. The commented-out `PersistentTileScheduler` line suggests this is planned for future work.
Consider removing the commented code (lines 34-35) if persistent scheduler support is not planned for the near term, or add a TODO comment linking to the tracking issue.
37-39: LGTM! Kernel type wiring is correct. The `Kernel` alias correctly composes `FlatKernelTmaWarpSpecializedDeltaRule` with the appropriate `CollectiveMainloop` and `TileScheduler` types.
csrc/flat/prefill/prefill_kernel_delta_rule_sm90.cu (2)
23-69: Consider simplifying the dispatch logic. The current implementation has 16 explicit branches with 2 "unreachable" cases that are indeed unreachable. This could be simplified using bit-packing or a dispatch table, but the current approach is clear and the compiler will likely optimize it.
The exhaustive enumeration ensures compile-time instantiation of all needed template combinations.
74-86: LGTM! Explicit instantiations cover required types. The explicit template instantiations for `half` and `nv_bfloat16` with `float` state correctly match the supported dtypes in the Python API.
benchmarks/bench_gdn_prefill.py (1)
245-250: LGTM! SM90 capability check is correct. The benchmark properly checks for SM90 support before running, providing a clear error message if the device doesn't meet requirements.
csrc/gdn_prefill_launcher.cu (3)
84-94: GQA/GVA validation looks correct. The branching logic correctly validates the head count relationships for both GQA (`num_q_heads >= num_v_heads`) and GVA (`num_q_heads < num_v_heads`) configurations with appropriate ratio checks.
153-155: Default scale calculation is correct. The default scale of `1.0 / sqrt(head_size)` is the standard attention scaling factor and is appropriately applied when the user passes `scale == 0.0`.
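The sentinel convention is simple enough to mirror in a host helper. A hedged sketch (`resolve_scale` is an illustrative name, not part of the launcher):

```python
import math

def resolve_scale(scale: float, head_size: int) -> float:
    """scale == 0.0 is treated as 'use the default 1/sqrt(head_size)'."""
    return scale if scale != 0.0 else 1.0 / math.sqrt(head_size)

assert resolve_scale(0.0, 64) == 0.125  # default kicks in
assert resolve_scale(2.0, 64) == 2.0    # explicit value wins
```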
99-104: This concern is unfounded. The `CHECK_SHAPE` macro calls `check_shape()`, which uses `TVM_FFI_ICHECK_EQ` to validate conditions. On failure, `TVM_FFI_ICHECK_EQ` throws a C++ exception that immediately unwinds the stack, preventing any code after the check from executing. Therefore, line 103 (`data_ptr()`) will never be reached if the shape validation fails. The code is safe from the described null pointer access risk.
Likely an incorrect or invalid review comment.
csrc/flat/hopper/collective/flat_collective_load.hpp (2)
29-41: Constructor stores references - ensure lifetime correctness. The constructor stores references to `tma_load`, `pipeline`, and `storage`. This is appropriate for device-side usage within a kernel's lifetime, but callers must ensure these objects outlive the `CollectiveLoadTma` instance.
87-105: Pipeline state is incremented by a single thread. The `++dst_pipe` on line 103 is inside the `if (lane_predicate == 1)` block, meaning only one thread increments the pipeline state. This is intentional for single-producer semantics, but verify that all threads in the warp maintain consistent pipeline state views if needed elsewhere.
csrc/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp (3)
380-386: Cluster/block synchronization handled correctly. The synchronization logic properly distinguishes between cluster-level (`cluster_arrive_relaxed`/`cluster_wait`) and single-block (`__syncthreads`) synchronization based on `ClusterShape` size.
206-220: Warp role enums are well-defined. The `WarpGroupRole` and `LdStWarpRole` enums clearly define the responsibilities of each warp/warp-group, enabling clean role-based dispatch in the kernel body.
464-467: Remove this concern: the dealloc vs alloc usage is intentional. Lines 448 and 467 use different register allocation strategies (alloc vs dealloc) because they operate on different register magnitudes. `StateMmaRegisterRequirement` is computed as `(total_registers - load_registers - aux_registers) / num_state_mma_warp_groups`, which typically yields values >= 128. In contrast, `AuxMmaRegisterRequirement` is fixed at `128 - load_registers` (88 or 104), which is always < 128. The kernel follows the CUTLASS pattern of using `warpgroup_reg_dealloc` for smaller allocations (< 128 registers) and `warpgroup_reg_alloc` for larger ones, as evidenced in related code (e.g., fmha_common.hpp). This asymmetry is correct and reflects different resource needs for the state versus auxiliary MMA warp groups.
csrc/flat/hopper/collective/flat_common.hpp (1)
10-29: GEMM accumulator reset logic is correct. The `gemm_reset_zero_acc` function correctly handles both 2x2x1 and 3x3x3 tensor ranks, iterating over the K dimension and resetting the accumulator scale to `One` after each GEMM operation.
csrc/flat/hopper/collective/flat_collective_store.hpp (3)
56-58: Hardcoded assumption about element size. Line 57 has `static_assert(sizeof(SmemElementO) == 2)`, limiting this implementation to 16-bit types (fp16/bf16). This is documented but may need updating if other types are supported in the future.
218-242: Tensormap tail handling uses correct synchronization. The `create_tensormap_for_tail` function:
- Copies the base tensormap (lines 226-231)
- Synchronizes the warp (line 232)
- Updates the global dimension (lines 234-238)
- Synchronizes again (line 239)
- Issues a release fence (line 241)
This follows the correct pattern for PTX tensormap manipulation.
164-177: `can_process` logic handles edge cases correctly. The function correctly identifies when TMA can be used:
- Intermediate full tiles (line 167-169)
- Last tile that's full (line 170, first condition)
- Last sequence where OOB can be handled by TMA (line 170, second condition)
tests/gdn/reference_delta_rule.py (3)
232-308: `delta_rule` function provides an element-wise reference implementation. This function implements the delta rule step-by-step for each token, serving as a clear reference for correctness verification. The implementation follows the mathematical formulation from the paper correctly.
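To make the recurrence concrete, here is a hedged pure-Python sketch of one common per-token formulation (decay the state by alpha, apply a beta-weighted rank-1 correction toward v, read out with q); the exact layout and normalization in reference_delta_rule.py may differ:

```python
def delta_rule_step(S, q, k, v, alpha, beta, scale=1.0):
    """One token of the gated delta rule on a d_k x d_v state matrix S
    (plain nested lists, no framework dependencies)."""
    d_k, d_v = len(S), len(S[0])
    # gate: decay the running state
    S = [[alpha * S[i][j] for j in range(d_v)] for i in range(d_k)]
    # current prediction of v from the state: p = S^T k
    p = [sum(S[i][j] * k[i] for i in range(d_k)) for j in range(d_v)]
    # delta update: S += beta * k (v - p)^T
    for i in range(d_k):
        for j in range(d_v):
            S[i][j] += beta * k[i] * (v[j] - p[j])
    # output: o = scale * S^T q
    o = [scale * sum(S[i][j] * q[i] for i in range(d_k)) for j in range(d_v)]
    return S, o

# from an empty state, a one-hot key stores v exactly and q == k reads it back
S, o = delta_rule_step([[0.0, 0.0], [0.0, 0.0]],
                       q=[1.0, 0.0], k=[1.0, 0.0], v=[3.0, 4.0],
                       alpha=1.0, beta=1.0)
assert o == [3.0, 4.0]
```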
360-368: GQA/GVA handling with `repeat_interleave` is correct. The head expansion logic correctly handles both GQA and GVA configurations, using `repeat_interleave` to broadcast heads appropriately.
12-29: `matmul` helper handles mixed precision correctly. The function correctly upcasts fp16/bf16 inputs to fp32 for computation, then downcasts the result. Note: bf16 inputs return fp32 output (line 25), while fp16 inputs return fp16 output (line 27). Verify this asymmetry is intentional.
csrc/flat/hopper/kernel/flat_tile_scheduler.hpp (1)
106-116: Runtime validation is available but not automatically enforced in the execution path. The code defines strict mathematical relationships for GQA and GVA modes through tag selection (lines 11-12). A `can_implement()` function validates these relationships at runtime (checking that head counts satisfy the GQA/GVA constraints), but this validation is not automatically called during kernel execution; callers must explicitly verify it succeeds before launching. Without this check, mismatched tag selection could allow the divisions at lines 109 and 113 to compute incorrect results or encounter division by zero.
Ensure any public API or usage examples require `can_implement()` validation before kernel launch.
9-13: Clean compile-time options system design. The `Option` struct with tag-based lookup and the `Tag` enum provide a flexible, type-safe mechanism for kernel configuration. The use of template metaprogramming for `find_option_t` and `add_option_t` is idiomatic for CUDA/CUTLASS codebases.
Also applies to: 62-77
csrc/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp (3)
274-324: AlphaProcessor implementation looks correct. The warp-level prefix sum using shuffle operations (`__shfl_up_sync`, `__shfl_sync`) for computing the cumulative log/product of alpha values is well-structured. The use of `log2f`/`exp2f` pairs and the epsilon (1e-10f) for numerical stability in line 293 are appropriate for this computation.
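The per-lane computation can be sanity-checked with a scalar emulation of the scan. A hedged Python sketch of a Hillis-Steele inclusive scan over log2(alpha) values, mirroring what the `__shfl_up_sync` steps compute per lane (`cumulative_alpha` is an illustrative name):

```python
import math

def cumulative_alpha(alphas, eps=1e-10):
    """Inclusive prefix product computed in log space:
    position i ends up with prod(alphas[0..i])."""
    logs = [math.log2(max(a, eps)) for a in alphas]
    n, offset = len(logs), 1
    while offset < n:
        # each "lane" adds the value offset positions behind it
        logs = [logs[i] + (logs[i - offset] if i >= offset else 0.0)
                for i in range(n)]
        offset *= 2
    return [2.0 ** x for x in logs]

out = cumulative_alpha([0.5, 0.5, 2.0, 1.0])
assert out == [0.5, 0.25, 0.5, 0.5]
```

The epsilon clamp plays the same role as the 1e-10f in the kernel: it keeps log2 of a (near-)zero gate finite.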
591-984: Complex compute function with solid warp-specialized structure. The `compute` function handles the delta rule main loop with proper pipeline synchronization, ordered math barriers for warpgroup coordination, and clear separation between Q@KV, QK@V, and KV update phases. The use of `CUTE_NO_UNROLL` for variable-iteration loops and `INLINE_LAMBDA` for hot-path lambdas is appropriate.
```cpp
template <typename TensorT>
CUTE_DEVICE void blockwise_diagonal_inversed_4x4_to_8x8(TensorT&& mat) {
  constexpr auto L =
      typename std::remove_const_t<std::remove_reference_t<TensorT>>::layout_type{};
  static_assert(rank(L) == 2);
  static_assert(size<0>(L) == 8);
  static_assert(size<1>(L) == 8);
  auto mat_NxN_2x2 = flat_divide(std::forward<TensorT>(mat), Shape<_4, _4>{});

  // FIXME: implement
}
```
🧩 Analysis chain
🏁 Script executed:
rg -n "blockwise_diagonal_inversed_4x4_to_8x8" --type cpp
Repository: flashinfer-ai/flashinfer
Length of output: 202
Remove unimplemented dead code.
blockwise_diagonal_inversed_4x4_to_8x8 is not called anywhere in the codebase and contains only a FIXME placeholder. Remove this function to avoid confusion.
🤖 Prompt for AI Agents
csrc/flat/ampere/collective/flat_collective_inverse.hpp around lines 164-174:
remove the unused, unimplemented function blockwise_diagonal_inversed_4x4_to_8x8
entirely; delete its template declaration, body and related local variable so
there is no dead code or FIXME left in the file, and run a quick grep to confirm
it’s not referenced elsewhere before committing.
```cpp
using CopyOpO_S2R = std::conditional_t<is_col_major, SM75_U16x8_LDSM_T, SM75_U32x4_LDSM_N>;
using CopyOpO_S2R = std::conditional_t<is_col_major, SM75_U16x8_LDSM_T, SM75_U32x4_LDSM_N>;
```
Duplicate type alias declaration.
CopyOpO_S2R is defined twice on consecutive lines. This appears to be a copy-paste error. The second definition shadows the first, though both are identical so it compiles without error.
🔎 Proposed fix
```diff
 using CopyOpD_S2R = std::conditional_t<is_col_major, SM75_U16x8_LDSM_T, SM75_U32x4_LDSM_N>;
 using CopyOpC_S2R = std::conditional_t<is_col_major, SM75_U32x4_LDSM_N, SM75_U16x8_LDSM_T>;
 using CopyOpA_S2R = std::conditional_t<is_col_major, SM75_U32x2_LDSM_N, SM75_U16x4_LDSM_T>;
 using CopyOpO_S2R = std::conditional_t<is_col_major, SM75_U16x8_LDSM_T, SM75_U32x4_LDSM_N>;
-using CopyOpO_S2R = std::conditional_t<is_col_major, SM75_U16x8_LDSM_T, SM75_U32x4_LDSM_N>;
```
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```cpp
using CopyOpO_S2R = std::conditional_t<is_col_major, SM75_U16x8_LDSM_T, SM75_U32x4_LDSM_N>;
```
🤖 Prompt for AI Agents
In csrc/flat/ampere/collective/flat_collective_inverse.hpp around lines 370-371
there is a duplicate type alias declaration for CopyOpO_S2R on consecutive
lines; remove the redundant second declaration (or the duplicate line) so the
alias is declared only once, keeping the existing conditional_t definition
intact and preserving surrounding formatting.
```cpp
template <int... Is, typename Tensor>
__forceinline__ __host__ __device__ constexpr auto select_tensor(Tensor&& t) {
  if constexpr (is_composed_layout<decltype(t.layout())>::value) {
    return make_tensor(
        std::forward<Tensor>(t).data(),
        make_composed_layout(std::forward<Tensor>(t).layout().layout_a(),
                             std::forward<Tensor>(t).layout().offset(),
                             select<Is...>(std::forward<Tensor>(t).layout().layout_b())));
  } else {
    return make_tensor(std::forward<Tensor>(t).data(), select<Is...>(t.layout()));
  }
}
```
Multiple std::forward calls on the same parameter may cause undefined behavior.
In the composed layout branch (lines 22-26), std::forward<Tensor>(t) is called multiple times on the same parameter. If t is an rvalue, the first forward may move from it, leaving subsequent accesses in an undefined state.
🔎 Proposed fix
```diff
 template <int... Is, typename Tensor>
 __forceinline__ __host__ __device__ constexpr auto select_tensor(Tensor&& t) {
   if constexpr (is_composed_layout<decltype(t.layout())>::value) {
+    auto& layout = t.layout();
     return make_tensor(
         std::forward<Tensor>(t).data(),
-        make_composed_layout(std::forward<Tensor>(t).layout().layout_a(),
-                             std::forward<Tensor>(t).layout().offset(),
-                             select<Is...>(std::forward<Tensor>(t).layout().layout_b())));
+        make_composed_layout(layout.layout_a(),
+                             layout.offset(),
+                             select<Is...>(layout.layout_b())));
   } else {
     return make_tensor(std::forward<Tensor>(t).data(), select<Is...>(t.layout()));
   }
 }
```
🤖 Prompt for AI Agents
In csrc/flat/cute_ext.hpp around lines 19 to 30, the composed-layout branch
calls std::forward<Tensor>(t) multiple times which can move from an rvalue more
than once and lead to undefined behavior; fix this by forwarding t exactly once
into a local variable (e.g., auto&& tmp = std::forward<Tensor>(t)) at the top of
the function and then use that local tmp for data(), layout(), layout_a(),
offset(), and layout_b(), ensuring all subsequent accesses use the single
forwarded object.
```cpp
template <class ProblemShape>
static bool can_implement(ProblemShape const& problem_size, Arguments const& args) {
  auto ratio = problem_size.num_q_heads > problem_size.num_v_heads
                   ? problem_size.num_q_heads / problem_size.num_v_heads
                   : problem_size.num_v_heads / problem_size.num_q_heads;

  constexpr bool IsGVAEnabled = find_option_t<Tag::kIsGVA, false_type, Options>::value;

  bool is_gqa_like = (problem_size.num_k_heads == problem_size.num_v_heads) &&
                     (problem_size.num_q_heads == ratio * problem_size.num_k_heads) &&
                     (problem_size.num_q_heads == ratio * problem_size.num_v_heads);

  bool is_gva_like = (problem_size.num_q_heads == problem_size.num_k_heads) &&
                     (problem_size.num_v_heads == ratio * problem_size.num_q_heads) &&
                     (problem_size.num_v_heads == ratio * problem_size.num_k_heads);
  return true && ((!IsGVAEnabled && is_gqa_like) || (IsGVAEnabled && is_gva_like)) &&
         (problem_size.head_size <= get<2>(TileShape{})) &&
         ((problem_size.head_size % Alignment) == 0);
}
```
Potential division by zero in ratio calculation.
If problem_size.num_v_heads is zero (or num_q_heads when it's smaller), the ratio calculation will produce undefined behavior. Although the caller likely validates head counts, adding a guard or assertion would make this safer.
🔎 Proposed fix
```diff
 template <class ProblemShape>
 static bool can_implement(ProblemShape const& problem_size, Arguments const& args) {
+  if (problem_size.num_q_heads == 0 || problem_size.num_v_heads == 0) {
+    return false;
+  }
   auto ratio = problem_size.num_q_heads > problem_size.num_v_heads
                    ? problem_size.num_q_heads / problem_size.num_v_heads
                    : problem_size.num_v_heads / problem_size.num_q_heads;
```
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```cpp
template <class ProblemShape>
static bool can_implement(ProblemShape const& problem_size, Arguments const& args) {
  if (problem_size.num_q_heads == 0 || problem_size.num_v_heads == 0) {
    return false;
  }
  auto ratio = problem_size.num_q_heads > problem_size.num_v_heads
                   ? problem_size.num_q_heads / problem_size.num_v_heads
                   : problem_size.num_v_heads / problem_size.num_q_heads;

  constexpr bool IsGVAEnabled = find_option_t<Tag::kIsGVA, false_type, Options>::value;

  bool is_gqa_like = (problem_size.num_k_heads == problem_size.num_v_heads) &&
                     (problem_size.num_q_heads == ratio * problem_size.num_k_heads) &&
                     (problem_size.num_q_heads == ratio * problem_size.num_v_heads);

  bool is_gva_like = (problem_size.num_q_heads == problem_size.num_k_heads) &&
                     (problem_size.num_v_heads == ratio * problem_size.num_q_heads) &&
                     (problem_size.num_v_heads == ratio * problem_size.num_k_heads);
  return true && ((!IsGVAEnabled && is_gqa_like) || (IsGVAEnabled && is_gva_like)) &&
         (problem_size.head_size <= get<2>(TileShape{})) &&
         ((problem_size.head_size % Alignment) == 0);
}
```
🤖 Prompt for AI Agents
In
csrc/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp
around lines 410-428, the ratio calculation can divide by zero when num_v_heads
or num_q_heads is zero; add a guard before computing ratio to return false (or
assert) if either problem_size.num_q_heads == 0 or problem_size.num_v_heads ==
0, then compute ratio safely (integer division is fine here) and proceed; ensure
the early-return/validation is placed before any use of ratio so the function
never performs a division by zero.
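The head-count constraints and the proposed zero guard are easy to prototype outside the template machinery. A hedged Python mirror of the checks (names are illustrative; `tile_k` stands in for `get<2>(TileShape{})` and `alignment` for `Alignment`):

```python
def can_implement(num_q_heads, num_k_heads, num_v_heads, head_size,
                  tile_k=128, alignment=8, is_gva=False):
    # guard the ratio computation against zero head counts
    if num_q_heads == 0 or num_v_heads == 0:
        return False
    ratio = (num_q_heads // num_v_heads if num_q_heads > num_v_heads
             else num_v_heads // num_q_heads)
    is_gqa_like = (num_k_heads == num_v_heads
                   and num_q_heads == ratio * num_k_heads)
    is_gva_like = (num_q_heads == num_k_heads
                   and num_v_heads == ratio * num_q_heads)
    heads_ok = is_gva_like if is_gva else is_gqa_like
    return bool(heads_ok and head_size <= tile_k
                and head_size % alignment == 0)

assert can_implement(4, 1, 1, 128)               # GQA: 4 q heads per kv head
assert can_implement(2, 2, 4, 128, is_gva=True)  # GVA: 2 v heads per q head
assert not can_implement(0, 1, 1, 128)           # zero guard rejects early
```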
```cpp
bool is_initialized(bool set = false) {
  static bool initialized = false;
  if (set) initialized = true;
  return initialized;
}
```
Critical bug: static local variable causes shared state across all instances.
The is_initialized function uses a static bool, meaning all instances of Universal<Kernel_> share the same initialization flag. If you create two Universal objects and initialize one, the second will incorrectly believe it's already initialized and skip cudaFuncSetAttribute.
🔎 Proposed fix: use instance member instead of static local
```diff
 private:
   /// Kernel API parameters object
   Params params_;
+  bool initialized_ = false;

-  bool is_initialized(bool set = false) {
-    static bool initialized = false;
-    if (set) initialized = true;
-    return initialized;
-  }
+  bool is_initialized() const { return initialized_; }
+  void set_initialized() { initialized_ = true; }
```
Then update usages in initialize():
```diff
   if (is_initialized()) return Status::kSuccess;
   // ... smem setup code ...
-  is_initialized(true);
+  set_initialized();
```
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In csrc/flat/hopper/device/device_universal.hpp around lines 30-34, the
is_initialized(bool set=false) function uses a static local variable causing
shared initialization state across all Universal<Kernel_> instances; replace the
static local with a non-static instance member (e.g., bool initialized_
initialized to false in the class) and update initialize() and any call sites to
use this member (use initialized_ = true when setting and return initialized_
when querying) so each object tracks its own initialization state.
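The hazard is language-agnostic: an init flag stored anywhere but on the instance is shared by every instance. A hedged Python analog of the proposed fix (`Universal` here is a toy stand-in, not the actual device wrapper):

```python
class Universal:
    def __init__(self):
        self._initialized = False  # per-instance, unlike a C++ static local

    def initialize(self) -> str:
        if self._initialized:
            return "already"
        # ... configure kernel attributes here (e.g. dynamic smem size) ...
        self._initialized = True
        return "configured"

a, b = Universal(), Universal()
assert a.initialize() == "configured"
assert b.initialize() == "configured"  # b is NOT fooled by a's init
assert a.initialize() == "already"
```

With the buggy static-local pattern, the second object would skip its own configuration step entirely, which is exactly the failure mode the review describes.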
```cpp
.ptr_Q = (T*)q, .dQ = {q_tok_stride, _1{}, q_head_stride},
.ptr_K = (T*)k, .dK = {k_tok_stride, _1{}, k_head_stride},
.ptr_V = (T*)v, .dV = {v_tok_stride, _1{}, v_head_stride},
.ptr_O = (T*)output, .dO = {o_tok_stride, _1{}, o_head_stride},
```
Casting away const from input pointers is unsafe.
The input pointers q, k, v are declared as TQKV const*, but the casts (T*)q, (T*)k, (T*)v strip the const qualifier. If the kernel unexpectedly writes to these buffers, this results in undefined behavior. If the API genuinely requires non-const pointers but guarantees read-only access, consider documenting this or using const_cast explicitly to make the intent clear.
🔎 Suggested improvement
If the underlying API truly requires mutable pointers but won't modify them:
```diff
-.ptr_Q = (T*)q, .dQ = {q_tok_stride, _1{}, q_head_stride},
-.ptr_K = (T*)k, .dK = {k_tok_stride, _1{}, k_head_stride},
-.ptr_V = (T*)v, .dV = {v_tok_stride, _1{}, v_head_stride},
+// SAFETY: cutlass API requires non-const but these are read-only
+.ptr_Q = const_cast<T*>(reinterpret_cast<T const*>(q)),
+.ptr_K = const_cast<T*>(reinterpret_cast<T const*>(k)),
+.ptr_V = const_cast<T*>(reinterpret_cast<T const*>(v)),
```
Better yet, if possible, update the Arguments type to accept T const* for input tensors.
Committable suggestion skipped: line range outside the PR's diff.
```python
@pytest.mark.parametrize("beta", [False, True])
@pytest.mark.parametrize("alpha", [False, True])
@pytest.mark.parametrize("scale", [1.0, "auto"])
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize(
    "num_q_heads, num_k_heads, num_v_heads",
    [(1, 1, 1), (4, 1, 1), (3, 3, 3), (6, 2, 2), (1, 1, 2), (2, 2, 4)],
)
@pytest.mark.parametrize("seq_lens", [[64], [128], [256], [256, 256], [64, 128, 512]])
@pytest.mark.parametrize("block_size", [64])
@pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
```
Missing SM90 architecture check per coding guidelines.
The coding guidelines specify using flashinfer.utils functions like is_sm90a_supported to skip tests on unsupported GPU architectures. This test requires SM90 but doesn't check for it.
Suggested fix
Add at the top of the test functions or as a module-level skip:
```python
from flashinfer.utils import is_sm90a_supported

# At module level or in each test:
@pytest.mark.skipif(not is_sm90a_supported(), reason="GDN requires SM90a")
```
As per coding guidelines, test implementations should use flashinfer.utils functions to skip tests on unsupported GPU architectures.
🤖 Prompt for AI Agents
In tests/gdn/test_prefill_delta_rule.py around lines 117 to 127, the
parametrized test lacks a check to skip execution on GPUs without SM90 support;
add an import from flashinfer.utils (is_sm90a_supported) and apply a
pytest.mark.skipif(not is_sm90a_supported(), reason="GDN requires SM90a") either
at the module top for all tests or directly on the affected test function(s) so
the test is skipped on unsupported architectures.
It was done in tests/gdn/conftest.py
@yzh119, you're absolutely right! I apologize for the incorrect comment.
I can see from tests/gdn/conftest.py that there is an autouse fixture skip_if_not_sm90a that automatically applies the SM90a check to all tests in the tests/gdn/ directory using is_sm90a_supported() from flashinfer.utils. This is actually a better approach than adding the decorator to each individual test function, as it:
- ✅ Uses the correct `flashinfer.utils.is_sm90a_supported()` function per the coding guidelines
- ✅ Applies the skip automatically to all GDN tests without repetition
- ✅ Keeps the test code cleaner and more maintainable
The implementation correctly follows the architecture check guidelines. Thank you for pointing this out!
🧠 Learnings used
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to tests/**/*.py : Test implementations should use `flashinfer.utils` functions (`get_compute_capability`, `is_sm90a_supported`, `is_sm100a_supported`, etc.) to skip tests on unsupported GPU architectures
Actionable comments posted: 1
🧹 Nitpick comments (1)
csrc/flat/prefill/prefill_kernel.hpp (1)
14-22: Add documentation for this public API function. Consider adding function-level documentation (e.g., Doxygen-style comments) that describes:
- The purpose of the delta rule prefill kernel
- Template parameter requirements and constraints
- Parameter descriptions (especially the distinction between num_v_heads and num_o_heads)
- Expected behavior and preconditions
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
csrc/flat/prefill/prefill_kernel.hpp
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (3)
csrc/flat/prefill/prefill_kernel.hpp (3)
1-6: LGTM! Standard header guard and includes are appropriate for this kernel launcher declaration.
7-10: LGTM! The forward declaration is a good practice to minimize compilation dependencies while allowing the launcher signature to reference the SM90 architecture tag.
20-21: These parameters serve different purposes in delta rule grouped attention and are intentional, not a copy-paste error. The `num_v_heads` parameter dimensions the value computation and controls attention grouping logic (GQA/GVA patterns), while `num_o_heads` dimensions the output and state tensor storage. In grouped attention scenarios, these can legitimately differ, making both parameters necessary. The codebase includes explicit documentation linking `num_o_heads` to `num_sab_heads` (state/alpha/beta), confirming the intentional design.
```cpp
namespace flat {

template <typename ArchTag,  // TODO: hide this
```
Resolve or remove the TODO comment before merging.
The TODO comment indicates the ArchTag template parameter should be hidden from the public API, suggesting the API design is not finalized. Consider one of the following approaches:
- Move ArchTag determination inside the function using type traits or runtime detection
- Use a non-template wrapper function that internally dispatches to the appropriate architecture
- If the current design is final, remove the TODO
🤖 Prompt for AI Agents
In csrc/flat/prefill/prefill_kernel.hpp around line 14, remove or resolve the
TODO about hiding the ArchTag template parameter: either (A) move architecture
selection inside the function by deriving ArchTag with type traits or runtime
detection and make the function non-templated on ArchTag, (B) provide a
non-template public wrapper that dispatches to template instantiations of the
implementation based on detected/derived architecture, or (C) if the template
parameter is intended to remain part of the public API, delete the TODO and add
a clarifying comment documenting the design decision; implement one of these
options and update comments/tests accordingly.
/bot run
[FAILED] Pipeline #41026587: 9/20 passed
/bot run
Actionable comments posted: 1
♻️ Duplicate comments (1)
flashinfer/gdn_prefill.py (1)
91-91: [Duplicate] The `use_qk_l2norm_in_kernel` parameter is unused. This parameter is defined in the function signature and documented, but it is not passed to the underlying `gdn_prefill` kernel call at line 186. Tests and benchmarks perform L2 normalization on `k` in Python before calling this function, confirming this parameter is not implemented.
🧹 Nitpick comments (1)
tests/gdn/test_prefill_delta_rule.py (1)
23-23: Unused `block_size` parameter in test helpers.

The `block_size` parameter is accepted by both `_test_prefill_kernel` (line 23) and `_test_chunked_prefill` (line 211) but is never used within the function bodies. This parameter is included in the test parametrization but serves no purpose in the current implementation.

🔎 Suggested fix

If `block_size` is not needed for validation or other purposes, remove it from the function signatures and test parametrization:

```diff
 def _test_prefill_kernel(
     qkv_factory,
     dtype: str,
     num_q_heads: int,
     num_k_heads: int,
     num_v_heads: int,
     head_size: int,
-    block_size: int,
     seq_lens: list[int],
     scale: float,
     alpha: bool,
     beta: bool,
     seed: int | None = None,
 ):
```

And remove from parametrization:

```diff
-@pytest.mark.parametrize("block_size", [64])
```

Apply the same changes to `_test_chunked_prefill` and `test_chunked_prefill`.

Also applies to: 211-211
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- flashinfer/gdn_prefill.py
- tests/gdn/test_prefill_delta_rule.py
🧰 Additional context used
📓 Path-based instructions (2)
flashinfer/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
flashinfer/**/*.py: Use the `@functools.cache` decorator on Python API functions to implement module-level caching and avoid recompilation.
Use the `@flashinfer_api` decorator for debugging API calls; enable via the `FLASHINFER_LOGLEVEL` environment variable (0=off, 1=basic, 3=detailed, 5=with stats).
Files:
flashinfer/gdn_prefill.py
tests/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
tests/**/*.py: Test implementations should use `flashinfer.utils` functions (`get_compute_capability`, `is_sm90a_supported`, `is_sm100a_supported`, etc.) to skip tests on unsupported GPU architectures.
For testing with `mpirun` on multi-GPU systems, use the pattern: `mpirun -np <num_gpus> pytest tests/path/to/test.py::test_function`.
Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes - `tests/conftest.py` provides auto-skipping for OOM tests as a safety net but should not be relied upon.
Files:
tests/gdn/test_prefill_delta_rule.py
🧠 Learnings (3)
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/aot.py : Register new operations in `flashinfer/aot.py` by calling the `gen_*_module()` function for AOT (Ahead-Of-Time) pre-compilation support
Applied to files:
flashinfer/gdn_prefill.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/__init__.py : Export new operations in `flashinfer/__init__.py` to make them available as public API
Applied to files:
flashinfer/gdn_prefill.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to tests/**/*.py : Test implementations should use `flashinfer.utils` functions (`get_compute_capability`, `is_sm90a_supported`, `is_sm100a_supported`, etc.) to skip tests on unsupported GPU architectures
Applied to files:
tests/gdn/test_prefill_delta_rule.py
🧬 Code graph analysis (1)
tests/gdn/test_prefill_delta_rule.py (3)
- tests/gdn/reference_delta_rule.py (2): `exclusive_cumsum` (5-9), `blockwise_delta_rule` (332-478)
- flashinfer/gdn_prefill.py (2): `gdn_prefill` (37-60), `chunk_gated_delta_rule` (81-202)
- tests/gdn/conftest.py (1): `qkv_factory` (50-51)
🪛 Ruff (0.14.10)
flashinfer/gdn_prefill.py
91-91: Unused function argument: use_qk_l2norm_in_kernel
(ARG001)
tests/gdn/test_prefill_delta_rule.py
23-23: Unused function argument: block_size
(ARG001)
211-211: Unused function argument: block_size
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
```python
scale (Optional[float]):
    Scale factor for the attention scores.
    If not provided, defaults to ``1 / sqrt(head_size)``. Default: ``None``.
```
Inconsistency between docstring and implementation for scale default.
The docstring (lines 118-120) states that if scale is not provided, it "defaults to 1 / sqrt(head_size)", but the implementation (line 196) passes 0.0 to the kernel when scale is None.
If the kernel interprets 0.0 as a sentinel value to apply the default scaling internally, this behavior should be documented in the docstring. Otherwise, either compute the default scale here or update the docstring to reflect the actual behavior.
🔎 Suggested clarification
Option 1: If the kernel handles the default, update the docstring:

```diff
 scale (Optional[float]):
-    Scale factor for the attention scores.
-    If not provided, defaults to ``1 / sqrt(head_size)``. Default: ``None``.
+    Scale factor for the attention scores.
+    If None or 0.0, the kernel applies the default scale of ``1 / sqrt(head_size)``. Default: ``None``.
```

Option 2: If the default should be computed here:

```diff
-    scale if scale is not None else 0.0,
+    scale if scale is not None else (1.0 / (head_size ** 0.5)),
```

Also applies to: 196-196
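Option 2 amounts to resolving the default on the Python side instead of passing the `0.0` sentinel through to the kernel. A minimal sketch, with a hypothetical `resolve_scale` helper mirroring the suggested `head_size ** -0.5` computation:

```python
# Hedched illustrative helper, not part of flashinfer's API: compute the
# documented default scale when the caller passes None.
def resolve_scale(scale, head_size):
    return scale if scale is not None else head_size ** -0.5
```

Either option works; the important thing is that the docstring and the effective behavior agree.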
🤖 Prompt for AI Agents
In flashinfer/gdn_prefill.py around lines 118-120 (docstring) and line 196 (call
site), the docstring claims scale defaults to 1/sqrt(head_size) but the code
passes 0.0 when scale is None; make them consistent by either computing the
default locally or documenting the 0.0 sentinel. Fix by one of two concise
changes: (A) if you want the module to compute the default, replace the None
branch so that when scale is None you compute scale = 1.0 / math.sqrt(head_size)
and pass that to the kernel (ensure head_size is available), or (B) if the
kernel internally treats 0.0 as "use default", update the docstring to state
that passing 0.0 signals the kernel to use 1/sqrt(head_size) internally; apply
the chosen change at line 196 and update the docstring accordingly.
[SUCCESS] Pipeline #41058613: 13/20 passed
Actionable comments posted: 8
♻️ Duplicate comments (12)
csrc/flat/hopper/device/device_universal.hpp (2)
45-49: Critical bug: static local variable causes shared state across all instances.

This issue was already flagged in a previous review. The `is_initialized` function uses a `static bool`, meaning all instances of `Universal<Kernel_>` share the same initialization flag.
158-160: Typo in comment: "to_underling_arguments" should be "to_underlying_arguments".

This issue was already flagged in a previous review.
csrc/flat/ampere/collective/flat_collective_inverse.hpp (2)
179-189: Remove unimplemented dead code.
`blockwise_diagonal_inversed_4x4_to_8x8` contains only a FIXME placeholder and is not called anywhere in the codebase. Remove this function to avoid confusion.

🤖 Prompt for AI Agents
csrc/flat/ampere/collective/flat_collective_inverse.hpp around lines 179-189: remove the unused, unimplemented function blockwise_diagonal_inversed_4x4_to_8x8 entirely; delete its template declaration, body and related local variable so there is no dead code or FIXME left in the file.
385-386: Duplicate type alias declaration.
`CopyOpO_S2R` is defined twice on consecutive lines. This appears to be a copy-paste error. The second definition shadows the first, though both are identical so it compiles without error.

🔎 Proposed fix

```diff
 using CopyOpD_S2R = std::conditional_t<is_col_major, SM75_U16x8_LDSM_T, SM75_U32x4_LDSM_N>;
 using CopyOpC_S2R = std::conditional_t<is_col_major, SM75_U32x4_LDSM_N, SM75_U16x8_LDSM_T>;
 using CopyOpA_S2R = std::conditional_t<is_col_major, SM75_U32x2_LDSM_N, SM75_U16x4_LDSM_T>;
 using CopyOpO_S2R = std::conditional_t<is_col_major, SM75_U16x8_LDSM_T, SM75_U32x4_LDSM_N>;
-using CopyOpO_S2R = std::conditional_t<is_col_major, SM75_U16x8_LDSM_T, SM75_U32x4_LDSM_N>;
```

csrc/flat/common.hpp (1)
26-33: Critical: Fix CHECK macro syntax errors.

Line 30 contains two syntax errors that will cause compilation failure:
- `##expr` should be `#expr` for proper stringification
- Missing comma between `msg` and `__FILE__`

🔎 Proposed fix
```diff
 #define CHECK(expr, msg)                                                                           \
   do {                                                                                             \
     if (!(expr)) {                                                                                 \
       std::string buffer(1024, '\0');                                                              \
-      sprintf(buffer.data(), "Failed to check %s, %s at %s:%d\n", ##expr, msg __FILE__, __LINE__); \
+      sprintf(buffer.data(), "Failed to check %s, %s at %s:%d\n", #expr, msg, __FILE__, __LINE__); \
       throw std::runtime_error(buffer.c_str());                                                    \
     }                                                                                              \
   } while (0)
```

flashinfer/gdn_prefill.py (2)
118-120: Clarify the scale default mechanism in the docstring.

The docstring states the default is `1/sqrt(head_size)`, but the implementation passes `0.0` to the kernel (line 196), which the kernel interprets as a sentinel to apply the default. Consider clarifying this in the docstring to avoid confusion.

🔎 Suggested clarification

```diff
 scale (Optional[float]):
-    Scale factor for the attention scores.
-    If not provided, defaults to ``1 / sqrt(head_size)``. Default: ``None``.
+    Scale factor for the attention scores.
+    If None (or 0.0), the kernel applies the default scale of ``1 / sqrt(head_size)``. Default: ``None``.
```
91-91: Unused parameter `use_qk_l2norm_in_kernel`.

This parameter is defined in the signature and documented, but it's never used in the function body or passed to the kernel. Consider either implementing the functionality or removing the parameter to avoid user confusion.
csrc/flat/prefill/prefill_kernel.hpp (1)
29-37: TODO: Consider tracking the API design decision.

The TODO comment indicates `ArchTag` should be hidden from the public API. If this is intentional for now, consider creating an issue to track this future improvement, or document why the current design was chosen.
119-124: Casting away `const` from input pointers.

The casts `(T*)q`, `(T*)k`, `(T*)v` strip the const qualifier from input pointers. If the underlying CUTLASS API requires mutable pointers but guarantees read-only access, consider adding a comment documenting this safety invariant.
117-118: Stray semicolon on line 118.

There's an extraneous semicolon after the `seq_idx` declaration.

🔎 Proposed fix

```diff
   int32_t seq_idx;
-  ;
   int32_t q_head_idx;
```
133-145: Out-of-bounds access if `seq_idx` is invalid.

Lines 133-134 access `problem_size.cu_seqlens[seq_idx]` and `problem_size.cu_seqlens[seq_idx + 1]` before validating whether the block should be scheduled. If `blockIdx.x` results in an invalid `seq_idx`, this causes undefined behavior. Consider moving the `scheduled` check before accessing `cu_seqlens`.
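The suggested reordering can be modeled in Python (illustrative only; `block_bounds` is a hypothetical stand-in for the device-side logic, not code from the PR):

```python
# Validate seq_idx against the batch BEFORE touching cu_seqlens, mirroring
# the review's advice to hoist the "scheduled" check above the indexing.
def block_bounds(cu_seqlens, seq_idx):
    if seq_idx < 0 or seq_idx + 1 >= len(cu_seqlens):
        return None  # block not scheduled; never index cu_seqlens
    return cu_seqlens[seq_idx], cu_seqlens[seq_idx + 1]
```

The same check-then-index ordering is what prevents the undefined behavior on the device side.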
199-201: Dead code: unused calculation result.

This calculation is not assigned to any variable or used. Remove the dead code.
🔎 Suggested fix
```diff
 # Get device info for bandwidth calculation
 props = torch.cuda.get_device_properties(0)
-props.total_memory * 2 / 1e12  # Approximate peak bandwidth
```
🧹 Nitpick comments (18)
csrc/flat/hopper/device/device_universal.hpp (2)
75-107: Consider extracting duplicated cudaFuncSetAttribute logic.

The dynamic shared memory configuration code (lines 82-92 and 128-138) is duplicated between `maximum_active_blocks()` and `initialize()`. Both blocks check `smem_size >= (48 << 10)` and call `cudaFuncSetAttribute` with identical error handling.

🔎 Proposed refactor: extract into a helper method
```diff
+ private:
+  static Status configure_smem_size() {
+    int smem_size = Kernel::SharedStorageSize;
+    if (smem_size >= (48 << 10)) {
+      CUTLASS_TRACE_HOST("  Setting smem size to " << smem_size);
+      cudaError_t result = cudaFuncSetAttribute(
+          device_kernel<Kernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
+      if (cudaSuccess != result) {
+        result = cudaGetLastError();  // to clear the error bit
+        CUTLASS_TRACE_HOST("  cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
+        return Status::kErrorInternal;
+      }
+    }
+    return Status::kSuccess;
+  }
```

Then use `configure_smem_size()` in both locations.

Also applies to: 124-140
160-191: Consider making params const in static run() method.

The `params` parameter is passed as a non-const reference but appears to be read-only within the method. Making it `const Params&` would better express intent and prevent accidental modifications.

🔎 Proposed change
```diff
-  static Status run(Params& params, cudaStream_t stream = nullptr) {
+  static Status run(Params const& params, cudaStream_t stream = nullptr) {
```
74-82: Clarify the precision concern.

The FIXME at Line 76 indicates "precision is not good due to half" but provides no guidance on acceptable error bounds or whether this is a known limitation. For production code, either:
- Document the expected precision characteristics and acceptable use cases, or
- Track this as a TODO for future FP32 support with a reference to an issue
csrc/flat/hopper/collective/flat_common.hpp (2)
25-50: Document the accumulator reset behavior.

The functions `gemm_reset_zero_acc` and `gemm_zero_acc` have subtle side effects on `atom.accumulate_`:
- `gemm_zero_acc` zeros the accumulator initially, then accumulates (adds) on subsequent k_blocks
- `gemm_reset_zero_acc` assumes the accumulator is already initialized and switches to accumulation after the first k_block

The relationship between these functions and their naming is not immediately clear. Consider:
- Adding function-level documentation explaining when to use each
- Clarifying that `atom.accumulate_` is modified as a side effect
atomstate will be changed toScaleOut::Oneafter these calls
23-23: Consider avoidingusing namespacein header files.While convenient,
using namespace cute;in a header file brings all cute symbols into any translation unit that includes this header, which can lead to name collisions and reduced readability.Consider using explicit
cute::qualifications or targeted using-declarations for specific symbols if the project style permits.flashinfer/jit/gdn.py (1)
35-35: Consider iterable unpacking for better style.While the current list concatenation works correctly, iterable unpacking is more idiomatic Python.
🔎 Proposed refactor
```diff
-        extra_cuda_cflags=sm90a_nvcc_flags + ["-DFLAT_SM90A_ENABLED", "-std=c++20"],
+        extra_cuda_cflags=[*sm90a_nvcc_flags, "-DFLAT_SM90A_ENABLED", "-std=c++20"],
```
47-49: Address persistent scheduler compatibility.

The FIXME comment indicates a known issue with persistent schedulers. The destructor is currently empty, which may cause resource leaks or incorrect barrier states when kernels are reused.
Would you like me to help design a solution for persistent scheduler compatibility, or should this be tracked in a separate issue?
csrc/flat/hopper/kernel/flat_tile_scheduler.hpp (1)
121-131: Consider adding assertions for head ratio invariants.

The code assumes `num_q_heads >= num_v_heads` for GQA and `num_v_heads >= num_q_heads` for GVA. If these invariants are violated, the divisions on lines 124 and 128 may produce incorrect results (e.g., zero divisor from integer division). Consider adding debug assertions to catch API misuse early.

🔎 Suggested defensive check
```diff
   if constexpr (std::is_same_v<GroupingTag, GQATag>) {
+    assert(params.num_q_heads >= params.num_v_heads && "GQA requires num_q_heads >= num_v_heads");
     seq_idx = blockIdx.x / params.num_q_heads;
     q_head_idx = blockIdx.x % params.num_q_heads;
     v_head_idx = q_head_idx / (params.num_q_heads / params.num_v_heads);
   } else if constexpr (std::is_same_v<GroupingTag, GVATag>) {
+    assert(params.num_v_heads >= params.num_q_heads && "GVA requires num_v_heads >= num_q_heads");
     seq_idx = blockIdx.x / params.num_v_heads;
     v_head_idx = blockIdx.x % params.num_v_heads;
     q_head_idx = v_head_idx / (params.num_v_heads / params.num_q_heads);
```
72-73: Consider documenting the 16-bit type constraint.

The `static_assert(sizeof(SmemElementO) == 2)` enforces that only 16-bit types (e.g., `half`, `bfloat16`) are supported. A brief comment explaining this constraint would help future maintainers understand the design choice.

```diff
+  // SW32 swizzle layout requires 16-bit elements (half/bfloat16)
   static_assert(sizeof(SmemElementO) == 2);
```
134-148: Consider including status codes in error messages.

The error messages are generic (e.g., "can_implement failed"). Including the `cutlass::Status` code would help debugging.

🔎 Suggested improvement
```diff
   status = op.can_implement(arguments);
   if (status != cutlass::Status::kSuccess) {
-    throw std::runtime_error("can_implement failed");
+    throw std::runtime_error("can_implement failed with status: " +
+                             std::to_string(static_cast<int>(status)));
   }
```
42-54: LGTM with a minor cleanup suggestion.

The kernel builder specialization is well-structured. The `static_assert` for persistent mode clearly indicates work-in-progress.
🔎 Suggested cleanup
using GroupingTag = std::conditional_t<kIsGVA, GVATag, GQATag>; using TileScheduler = flat::kernel::IndividualTileScheduler<GroupingTag>; - // using TileScheduler = std::conditional_t<kIsPersistent, flat::kernel::PersistentTileScheduler, - // flat::kernel::IndividualTileScheduler>; using Kernel = flat::kernel::FlatKernelTmaWarpSpecializedDeltaRule<CollectiveMainloop, TileScheduler, Options>;csrc/gdn_prefill_launcher.cu (2)
43-46: Redundant CUDA device property queries.
gdn_prefillalready queries device properties (lines 157-160) and passessm_countto the launcher. The device major version check here could use the same mechanism to avoid duplicate queries.Consider passing the device major version or making the check in
gdn_prefillbefore calling the launcher, or caching the result.
157-161: Missing CUDA error checks.

The `cudaGetDevice` and `cudaGetDeviceProperties` calls do not check return values. While failures are rare, unchecked errors could lead to undefined behavior if the device ID is invalid.

🔎 Suggested improvement
```diff
   int dev_id;
-  cudaGetDevice(&dev_id);
+  CUDA_CHECK(cudaGetDevice(&dev_id));
   cudaDeviceProp device_properties;
-  cudaGetDeviceProperties(&device_properties, dev_id);
+  CUDA_CHECK(cudaGetDeviceProperties(&device_properties, dev_id));
   int32_t sm_count = device_properties.multiProcessorCount;
```
98-98: Clarify cluster support limitation.

The comment "do not support cluster" could be more informative. Consider adding context about why cluster support is not implemented or when it might be added.
benchmarks/bench_gdn_prefill.py (1)
25-56: Unused parameters in `gdn_flops` function.

The parameters `num_k_heads` and `num_seqs` are unused. If they're kept for API consistency with `gdn_bytes`, add a leading underscore or comment explaining why they're present but unused.

🔎 Option 1: Prefix with underscore
```diff
 def gdn_flops(
     total_seq_len: int,
     num_q_heads: int,
-    num_k_heads: int,
+    _num_k_heads: int,
     num_v_heads: int,
     head_size: int,
-    num_seqs: int,
+    _num_seqs: int,
 ) -> int:
```

tests/gdn/test_prefill_delta_rule.py (1)
39-39: Unused `block_size` parameter.

The `block_size` parameter is passed to both `_test_prefill_kernel` and `_test_chunked_prefill` but never used. Either use it in the test logic or remove it from the function signatures and parametrization.
35-58: Register requirement calculation with magic numbers.

The register calculation uses hardcoded values (40, 128, 64*1024, 248, 8). Consider adding comments explaining the rationale for these constants, particularly:
- Why `load_registers` differs between debug and non-debug builds
tests/gdn/reference_delta_rule.py (1)
28-45: Clarify the asymmetric dtype return behavior.

The `matmul` function returns different dtypes depending on input:
- When either input is
bfloat16, it returnsfloat32(line 41)- When inputs are
float16, it returnsfloat16(line 43)This asymmetric behavior could be confusing. If this is intentional for numerical stability reasons, consider adding a comment explaining the rationale. Otherwise, consider making the return dtype consistent.
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (29)
- benchmarks/bench_gdn_prefill.py
- csrc/flat/ampere/collective/flat_collective_inverse.hpp
- csrc/flat/ampere/collective/flat_collective_load.hpp
- csrc/flat/common.hpp
- csrc/flat/cute_ext.hpp
- csrc/flat/debug.hpp
- csrc/flat/hopper/collective/flat_collective_load.hpp
- csrc/flat/hopper/collective/flat_collective_store.hpp
- csrc/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp
- csrc/flat/hopper/collective/flat_common.hpp
- csrc/flat/hopper/collective/flat_named_barriers.hpp
- csrc/flat/hopper/device/device_universal.hpp
- csrc/flat/hopper/kernel/flat_kernel_builder_delta_rule.hpp
- csrc/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp
- csrc/flat/hopper/kernel/flat_options.hpp
- csrc/flat/hopper/kernel/flat_tile_scheduler.hpp
- csrc/flat/math.hpp
- csrc/flat/math_order_barrier.hpp
- csrc/flat/prefill/prefill_kernel.hpp
- csrc/flat/prefill/prefill_kernel_delta_rule_sm90.cu
- csrc/flat/prefill/prefill_kernel_delta_rule_sm90.cuh
- csrc/flat/type_traits.hpp
- csrc/flat/unused.hpp
- csrc/gdn_prefill_launcher.cu
- flashinfer/gdn_prefill.py
- flashinfer/jit/gdn.py
- tests/gdn/conftest.py
- tests/gdn/reference_delta_rule.py
- tests/gdn/test_prefill_delta_rule.py
🚧 Files skipped from review as they are similar to previous changes (5)
- csrc/flat/unused.hpp
- tests/gdn/conftest.py
- csrc/flat/cute_ext.hpp
- csrc/flat/hopper/collective/flat_named_barriers.hpp
- csrc/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp
🧰 Additional context used
📓 Path-based instructions (4)
flashinfer/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
flashinfer/**/*.py: Use the `@functools.cache` decorator on Python API functions to implement module-level caching and avoid recompilation.
Use the `@flashinfer_api` decorator for debugging API calls; enable via the `FLASHINFER_LOGLEVEL` environment variable (0=off, 1=basic, 3=detailed, 5=with stats).
Files:
- flashinfer/gdn_prefill.py
- flashinfer/jit/gdn.py
csrc/**/*.cu
📄 CodeRabbit inference engine (CLAUDE.md)
Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers
Files:
- csrc/gdn_prefill_launcher.cu
- csrc/flat/prefill/prefill_kernel_delta_rule_sm90.cu
tests/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
tests/**/*.py: Test implementations should use `flashinfer.utils` functions (`get_compute_capability`, `is_sm90a_supported`, `is_sm100a_supported`, etc.) to skip tests on unsupported GPU architectures.
For testing with `mpirun` on multi-GPU systems, use the pattern: `mpirun -np <num_gpus> pytest tests/path/to/test.py::test_function`.
Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes - `tests/conftest.py` provides auto-skipping for OOM tests as a safety net but should not be relied upon.
Files:
- tests/gdn/reference_delta_rule.py
- tests/gdn/test_prefill_delta_rule.py
flashinfer/jit/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
flashinfer/jit/**/*.py: JIT module generators in `flashinfer/jit/` must follow the pattern: compute URI → create directory → (optional) render Jinja template → copy sources → return JitSpec.
Use the `gen_jit_spec()` function to return a properly configured JitSpec from module generators with appropriate `sources` and `extra_cuda_cflags`.
Specify `supported_major_versions` in JitSpec to restrict kernel compilation to supported GPU architectures (e.g., SM versions 9, 10, 11, 12 for Hopper/newer).
Files:
flashinfer/jit/gdn.py
🧠 Learnings (12)
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/aot.py : Register new operations in `flashinfer/aot.py` by calling the `gen_*_module()` function for AOT (Ahead-Of-Time) pre-compilation support
Applied to files:
- flashinfer/gdn_prefill.py
- flashinfer/jit/gdn.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/__init__.py : Export new operations in `flashinfer/__init__.py` to make them available as public API
Applied to files:
flashinfer/gdn_prefill.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Kernel code in `include/flashinfer/` is automatically picked up by JIT compilation on changes - no pip reinstall needed
Applied to files:
csrc/gdn_prefill_launcher.cu
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Use `FLASHINFER_CUDA_ARCH_LIST` environment variable to specify target GPU architectures (e.g., '8.0 9.0a') and `FLASHINFER_NVCC_THREADS` to control parallel compilation threads
Applied to files:
- csrc/gdn_prefill_launcher.cu
- csrc/flat/common.hpp
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers
Applied to files:
- csrc/flat/prefill/prefill_kernel.hpp
- csrc/flat/ampere/collective/flat_collective_inverse.hpp
- csrc/flat/common.hpp
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
- csrc/flat/ampere/collective/flat_collective_inverse.hpp
- csrc/flat/common.hpp
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly
Applied to files:
- csrc/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp
- csrc/flat/common.hpp
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Torch headers MUST NOT be included in files within the `include/` directory - keep framework-agnostic CUDA kernels that accept raw pointers
Applied to files:
csrc/flat/common.hpp
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to tests/**/*.py : Test implementations should use `flashinfer.utils` functions (`get_compute_capability`, `is_sm90a_supported`, `is_sm100a_supported`, etc.) to skip tests on unsupported GPU architectures
Applied to files:
tests/gdn/test_prefill_delta_rule.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : Use `gen_jit_spec()` function to return a properly configured JitSpec from module generators with appropriate `sources` and `extra_cuda_cflags`
Applied to files:
flashinfer/jit/gdn.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : JIT module generators in `flashinfer/jit/` must follow the pattern: compute URI → create directory → (optional) render Jinja template → copy sources → return JitSpec
Applied to files:
flashinfer/jit/gdn.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : Specify `supported_major_versions` in JitSpec to restrict kernel compilation to supported GPU architectures (e.g., SM versions 9, 10, 11, 12 for Hopper/newer)
Applied to files:
flashinfer/jit/gdn.py
🧬 Code graph analysis (8)
flashinfer/gdn_prefill.py (3)
flashinfer/api_logging.py (1)
`flashinfer_api` (464-565)
flashinfer/jit/gdn.py (1)
`gen_gdn_prefill_sm90_module` (25-37)
csrc/gdn_prefill_launcher.cu (2)
`gdn_prefill` (71-170), `gdn_prefill` (71-73)
csrc/flat/math_order_barrier.hpp (1)
csrc/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp (8)
`void` (294-336), `void` (510-515), `void` (518-541), `void` (544-562), `void` (565-583), `void` (587-604), `void` (607-1000), `void` (1003-1230)
csrc/flat/hopper/kernel/flat_tile_scheduler.hpp (3)
csrc/flat/hopper/collective/flat_collective_store.hpp (2)
`bool` (180-192), `params` (150-150)
csrc/flat/hopper/device/device_universal.hpp (4)
`params` (72-72), `params` (72-72), `params` (160-191), `params` (160-160)
csrc/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp (3)
`params` (203-205), `params` (203-203), `params` (220-220)
benchmarks/bench_gdn_prefill.py (2)
flashinfer/gdn_prefill.py (2)
`gdn_prefill` (37-60), `chunk_gated_delta_rule` (81-202)
flashinfer/testing/utils.py (1)
`bench_gpu_time` (1484-1631)
csrc/flat/hopper/device/device_universal.hpp (3)
csrc/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp (9)
`args` (499-501), `args` (499-499), `can_implement` (426-443), `can_implement` (426-426), `params` (510-510), `initialize_workspace` (504-508), `initialize_workspace` (504-506), `to_underlying_arguments` (446-497), `to_underlying_arguments` (446-447)
csrc/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp (12)
`args` (189-191), `args` (189-189), `args` (193-197), `args` (193-194), `args` (199-201), `args` (199-199), `args` (212-218), `args` (212-212), `params` (203-205), `params` (203-203), `params` (220-220), `block` (207-210)
csrc/flat/hopper/collective/flat_collective_store.hpp (5)
`params` (150-150), `initialize_workspace` (144-148), `initialize_workspace` (144-146), `to_underlying_arguments` (119-134), `to_underlying_arguments` (119-120)
csrc/flat/hopper/collective/flat_collective_store.hpp (3)
csrc/flat/cute_ext.hpp (2)
`alignment_for_swizzle` (48-50), `alignment_for_swizzle` (48-48)
void(85-115)void(119-171)void(180-189)void(192-275)void(278-360)void(363-476)include/flashinfer/attention/blackwell/kernel/fmha_tile_scheduler.hpp (1)
make_coord(94-97)
csrc/flat/ampere/collective/flat_collective_load.hpp (1)
csrc/flat/hopper/collective/flat_collective_load.hpp (2)
char(32-42)to_string(32-32)
flashinfer/jit/gdn.py (1)
flashinfer/jit/core.py (2)
JitSpec(216-397)gen_jit_spec(400-466)
🪛 Ruff (0.14.10)
flashinfer/gdn_prefill.py
91-91: Unused function argument: use_qk_l2norm_in_kernel
(ARG001)
benchmarks/bench_gdn_prefill.py
28-28: Unused function argument: num_k_heads
(ARG001)
31-31: Unused function argument: num_seqs
(ARG001)
tests/gdn/test_prefill_delta_rule.py
39-39: Unused function argument: block_size
(ARG001)
227-227: Unused function argument: block_size
(ARG001)
flashinfer/jit/gdn.py
35-35: Consider iterable unpacking instead of concatenation
Replace with iterable unpacking
(RUF005)
🔇 Additional comments (42)
csrc/flat/ampere/collective/flat_collective_inverse.hpp (5)
27-72: LGTM! Well-structured helper functions.

The `is_contiguous` template correctly handles both scalar and ranked layouts, and the SM80 namespace helpers properly convert fragment layouts and accumulator tensors for GEMM operations.
84-115: LGTM! Well-orchestrated multi-stage inverse.

The staged approach (8×8 → 16×16 → 32×32 → 64×64) with proper barrier synchronization ensures correct dependencies between block inverse computations.
117-171: LGTM! Efficient warp-shuffle based inversion.The implementation uses warp shuffle operations for efficient thread communication during Gauss-Jordan elimination. The conditional handling of garbage-filled regions is well-structured, and the FP32 intermediate computation helps mitigate precision concerns.
192-360: LGTM! Correct implementation of block matrix inversion.Both
blockwise_diagonal_inversed_8x8_to_16x16andblockwise_diagonal_inversed_16x16_to_32x32correctly implement the block inverse formula documented at lines 173-177. The tensor partitioning, tiled MMA operations, and type conversions are well-structured.
362-476: LGTM! Complex cross-warp reduction correctly implemented.Despite the duplicate typedef issue (flagged separately), the warp-level parallelization and cross-warp reduction logic is correct. The careful buffer reuse with barriers (line 458) and the split x==0/x==1 write-reduce pattern ensure correct results.
csrc/flat/hopper/collective/flat_common.hpp (3)
1-166: AI summary describes a different file. The AI-generated summary claims this file:
- "Introduces three public macros: FLAT_UNUSED_PARAMETER(x), CHECK(expr, msg), CUDA_CHECK(expr)"
- "Includes debug.hpp and provides centralized error handling"
However, the actual file contains GEMM/GMMA utility functions in the `flat::collective` namespace with no macros, no `debug.hpp` include, and no error-handling utilities. The summary appears to describe an entirely different file.
52-86: LGTM! The three `convert_to_gmma_rs` overloads correctly handle different MMA atom and TiledMMA types. The second overload appropriately defaults to `Major::K` for both operands when major layout parameters are not provided.
88-93: Verify divisibility assumption in layout conversion. Line 90 performs integer division `size<0>(c) / size(a)` without checking if `size<0>(c)` is divisible by `size(a)`. If this assumption is violated, the layout conversion will produce incorrect results silently. Consider:
- Adding a static_assert or runtime check to validate the divisibility
- Documenting the precondition that `size<0>(c)` must be a multiple of `size(a)`
csrc/flat/math_order_barrier.hpp (1)
27-97: Implementation looks solid. The ordered barrier coordination logic is well-documented with clear state-transition comments. The use of variadic templates for flexible WG-to-NB mapping is elegant.
csrc/flat/debug.hpp (1)
20-78: Debug infrastructure is well-designed. The conditional compilation approach ensures zero overhead when debugging is disabled. The variety of print macros (thread-filtered, warp-filtered, warp-group-filtered) provides flexible granularity for debugging CUDA kernels.
csrc/flat/common.hpp (1)
35-44: CUDA_CHECK macro looks correct. The macro properly captures the CUDA error, checks the result, formats an error message with error name and code, and throws a runtime exception. The implementation follows best practices for CUDA error handling.
csrc/flat/type_traits.hpp (2)
24-32: Type mapping utility looks correct. The `map_to_cutlass` trait provides a clean abstraction for mapping native CUDA types (half, nv_bfloat16) to their Cutlass equivalents. The specializations cover the expected type conversions.
34-49: first_non_void implementation is correct. The variadic template metafunction correctly:
- Guards against all-void parameter packs with a static_assert
- Returns the first non-void type via partial specialization
- Recursively skips void types using inheritance
This is a clean, idiomatic implementation of type selection at compile time.
csrc/flat/hopper/kernel/flat_tile_scheduler.hpp (1)
29-71: LGTM! The `WorkDesc` struct is well-designed with proper compile-time branching for GQA vs GVA grouping semantics. The accessor methods provide a clean abstraction over the internal head index representation.
csrc/flat/hopper/collective/flat_collective_store.hpp (3)
45-53: LGTM! The `smid()` helper correctly retrieves the streaming multiprocessor ID using inline PTX assembly, with proper `__CUDA_ARCH__` guarding for host/device compilation.
194-231: LGTM! The `step()` method correctly implements pipelined TMA stores with proper barrier synchronization. The tail-tile handling via modified tensormap is a good approach for handling variable-length sequences without padding.
233-257: LGTM! The tensormap creation for tail tiles is well-implemented. The parallel copy across warp lanes, followed by `__syncwarp()` barriers and PTX fence operations, correctly ensures the modified descriptor is visible before use.
24-76: LGTM! The compile-time option system is well-designed using standard template metaprogramming patterns. The recursive `find_option_impl` correctly handles tag lookup with default fallback, and `add_option_t` provides clean option composition.
77-92: LGTM! The `Tag` enum provides a comprehensive set of configuration options for the kernel builder. The inline comments explaining the purpose of each tag are helpful for understanding the kernel configuration semantics.
24-34: LGTM! The `ceil_log2` and `next_power_of_two` implementations are correct for their intended use case. The recursive `ceil_log2` correctly computes the ceiling of log₂(n), yielding the smallest exponent `e` such that `2^e >= n`.
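For intuition, the intended semantics can be sketched in Python (hypothetical helpers mirroring the C++ names, not the actual implementation):

```python
def ceil_log2(n: int) -> int:
    """Smallest e such that 2**e >= n, for n >= 1."""
    assert n >= 1
    return (n - 1).bit_length()


def next_power_of_two(n: int) -> int:
    """Smallest power of two >= n."""
    return 1 << ceil_log2(n)


# ceil_log2(1) -> 0, ceil_log2(5) -> 3, next_power_of_two(5) -> 8
```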
30-77: LGTM on the module setup! The module correctly uses `@functools.cache` for caching (per coding guidelines) and properly declares `mutates_args=("output", "output_state")` for the custom op registration.
178-184: LGTM! The comment correctly explains that output state must always be allocated because the kernel unconditionally writes to it. This is the right approach to ensure kernel correctness.
csrc/flat/prefill/prefill_kernel.hpp (1)
22-25: LGTM! The forward declaration of `cutlass::arch::Sm90` is a good practice to avoid including full CUTLASS headers in this interface header.
57-77: LGTM! The compile-time options construction using an IIFE with `decltype` is an elegant pattern for building the constexpr options tuple based on template parameters.
131-132: LGTM! The workspace allocation using `cutlass::device_memory::allocation` follows RAII semantics, ensuring proper cleanup when the function returns.
71-170: Well-structured validation and kernel invocation. The `gdn_prefill` function has comprehensive input validation including:
- GQA/GVA head count consistency checks
- Shape and dtype validation for all tensors
- Proper handling of optional tensors (input_state, alpha, beta)
- Default scale computation
The TVM-FFI integration follows the coding guidelines for framework bindings in `csrc/`.
44-100: Well-designed TMA collective load abstraction. The `CollectiveLoadTma` template cleanly separates Q vs K/V loading patterns with appropriate tile shapes and layouts. The compile-time branching via `if constexpr` ensures no runtime overhead.
102-178: Well-structured benchmark function. The benchmark follows best practices:
- Proper input tensor setup with L2-normalized keys for numerical stability
- Warmup call before timing
- Uses `bench_gpu_time` with CUPTI for accurate GPU kernel timing
- Calculates meaningful metrics (TFLOPs, TB/s)
csrc/flat/prefill/prefill_kernel_delta_rule_sm90.cu (2)
38-86: Exhaustive dispatch pattern for kernel configuration. The dispatch logic correctly covers all 16 combinations of the 4 boolean flags (is_gva, needs_beta, needs_alpha, init_state). The `#undef LAUNCH` at line 86 properly cleans up the macro. The "unreachable" exceptions are correct dead code since the if-else chains are exhaustive.
89-101: Explicit template instantiations provided. Both FP16 (`half`) and BF16 (`nv_bfloat16`) instantiations are provided, matching the dtype dispatch in `gdn_prefill_launcher`.
32-131: Well-structured test helper with proper validation. The `_test_prefill_kernel` function:
- Correctly skips unstable cases (no alpha and no beta)
- Properly seeds RNG for reproducibility
- L2-normalizes keys for numerical stability
- Uses appropriate tolerances per dtype
- Correctly transposes state to match reference implementation layout
220-377: Comprehensive chunked prefill test. The `_test_chunked_prefill` function correctly:
- Tests state continuity across chunks by passing `our_state1` to the second call
- Concatenates variable-length sequences for reference comparison
- Uses the same tolerance strategy as the basic test
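The state-continuity property being tested can be illustrated with a minimal pure-Python sketch of plain causal linear attention (without the delta-rule correction; names are illustrative only): running two chunks while carrying the state S reproduces the single-pass result exactly, because the same operations execute in the same order.

```python
def linear_attn(qs, ks, vs, state=None):
    """Causal linear attention: S += k outer v per step, then o = q @ S."""
    d = len(qs[0])
    S = [row[:] for row in state] if state else [[0.0] * d for _ in range(d)]
    outs = []
    for q, k, v in zip(qs, ks, vs):
        for i in range(d):
            for j in range(d):
                S[i][j] += k[i] * v[j]
        outs.append([sum(q[i] * S[i][j] for i in range(d)) for j in range(d)])
    return outs, S


qs = [[1.0, 2.0], [0.5, -1.0], [2.0, 0.0], [1.0, 1.0]]
ks = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [1.0, -1.0]]
vs = [[1.0, 1.0], [2.0, 0.0], [0.0, 3.0], [1.0, 2.0]]

full, _ = linear_attn(qs, ks, vs)
o1, s1 = linear_attn(qs[:2], ks[:2], vs[:2])            # first chunk
o2, _ = linear_attn(qs[2:], ks[2:], vs[2:], state=s1)   # second chunk carries state
assert o1 + o2 == full  # bit-exact: identical operations in identical order
```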
csrc/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp (2)
405-499: Warp-specialized kernel execution pattern. The kernel correctly implements warp-group specialization with:
- LdSt warp group handling Q/K/V loads, O stores, and alpha/beta loads via separate warps
- Math0/Math1 warp groups performing state computation with barrier coordination
- MathA auxiliary warp group for preprocessing
The register allocation pattern (`reg_dealloc` for LdSt/MathA, `reg_alloc` for Math0/Math1) balances registers between load/store and compute paths.
60-218: Well-organized kernel struct with clear abstractions. The `FlatKernelTmaWarpSpecializedDeltaRule` struct cleanly encapsulates:
- Pipeline types and state management
- Shared storage layout with proper alignment
- Problem shape and argument handling
- Static configuration (thread counts, register requirements)
csrc/flat/ampere/collective/flat_collective_load.hpp (3)
27-40: LGTM! The `LoadKindVector` enum and `to_string` helper follow the same pattern as the Hopper implementation and are correctly implemented.
42-108: LGTM! The `CollectiveLoadVector` template structure and `partition_SD` method are well-implemented. The tensor partitioning logic correctly handles the setup for collective loads with appropriate tail masking.
110-157: LGTM! The `step` method correctly implements the pipelined collective load with proper tail handling, memory fencing, and optional vector processing. The implementation follows CUDA best practices for shared memory synchronization.
21-25: LGTM! The `exclusive_cumsum` implementation is straightforward and correct for computing sequence offsets.
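The semantics of an exclusive cumulative sum over sequence lengths can be sketched in a few lines of plain Python (illustrative, not the torch-based reference implementation):

```python
def exclusive_cumsum(seq_lens):
    """Offset of each sequence: prefix sums shifted right with a leading 0."""
    offsets, total = [], 0
    for n in seq_lens:
        offsets.append(total)
        total += n
    return offsets


# exclusive_cumsum([3, 5, 2]) -> [0, 3, 8]
```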
48-74: LGTM! Both `LambdaQ` and `LambdaK` correctly implement decay factor computation with appropriate handling for partial blocks. The comment on line 61 helpfully highlights the crucial difference for tail blocks.
78-123: LGTM! The `_linear_attention` function correctly implements causal linear attention with optional decay factors and custom QK weights.
248-324: LGTM! The `delta_rule` function correctly implements the delta rule algorithm with proper handling for GQA and GVA configurations. The state update logic (lines 312-317) follows the mathematical formulation described in the comment.
348-493: LGTM! The `blockwise_delta_rule` function implements a complex blockwise delta rule algorithm with proper handling for:
- GQA/GVA configurations (lines 376-383)
- Tail block padding with appropriate defaults (lines 425-430: alpha=1.0, beta=0.0)
- Block-level state updates with decay factors
- Inter-block and intra-block attention contributions
The implementation is mathematically sound and handles edge cases appropriately.
```cpp
/*
 * Copyright (c) 2025 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/pipeline/sm90_pipeline.hpp"
```
🧩 Analysis chain
🏁 Script executed:
```bash
#!/bin/bash
# Check if there's a corresponding hopper directory and whether this file should be there
echo "=== Checking for hopper collective directories ==="
fd -t d "hopper.*collective" csrc/
echo -e "\n=== Checking for other files in ampere that use sm90_pipeline ==="
rg -n "sm90_pipeline" csrc/flat/ampere/ 2>/dev/null || echo "No matches found"
echo -e "\n=== List all ampere directories ==="
fd -t d "ampere" csrc/
echo -e "\n=== Check hopper directory structure ==="
fd -t d "hopper" csrc/ | head -20
```

Repository: flashinfer-ai/flashinfer
Length of output: 428
🏁 Script executed:
```bash
#!/bin/bash
# Check what's inside hopper directory
echo "=== Contents of csrc/flat/hopper/ ==="
find csrc/flat/hopper/ -type f -o -type d | sort
echo -e "\n=== Search for sm90_pipeline usage across entire codebase ==="
rg -l "sm90_pipeline" csrc/ --type cpp --type hpp | head -20
echo -e "\n=== Check if there's a hopper/collective subdirectory ==="
ls -la csrc/flat/hopper/ 2>/dev/null
```

Repository: flashinfer-ai/flashinfer
Length of output: 1295
🏁 Script executed:
```bash
#!/bin/bash
# Compare the two flat_collective_load.hpp files
echo "=== Ampere version (first 50 lines) ==="
head -50 csrc/flat/ampere/collective/flat_collective_load.hpp
echo -e "\n=== Hopper version (first 50 lines) ==="
head -50 csrc/flat/hopper/collective/flat_collective_load.hpp
echo -e "\n=== Check if files are identical ==="
if cmp -s csrc/flat/ampere/collective/flat_collective_load.hpp csrc/flat/hopper/collective/flat_collective_load.hpp; then
  echo "Files are IDENTICAL"
else
  echo "Files are DIFFERENT"
  echo "Ampere file size: $(wc -c < csrc/flat/ampere/collective/flat_collective_load.hpp)"
  echo "Hopper file size: $(wc -c < csrc/flat/hopper/collective/flat_collective_load.hpp)"
fi
```

Repository: flashinfer-ai/flashinfer
Length of output: 3301
🏁 Script executed:
```bash
#!/bin/bash
# Check all files in ampere/collective and their sm90_pipeline usage
echo "=== All files in ampere/collective ==="
ls -1 csrc/flat/ampere/collective/
echo -e "\n=== Check sm90_pipeline references in ampere/collective ==="
grep -n "sm90_pipeline" csrc/flat/ampere/collective/*.hpp
echo -e "\n=== Check sm90_pipeline references in hopper/collective ==="
grep -n "sm90_pipeline" csrc/flat/hopper/collective/*.hpp
echo -e "\n=== Check if other hopper files use sm90_pipeline ==="
grep -r "sm90_pipeline" csrc/flat/hopper/ --include="*.hpp"
```

Repository: flashinfer-ai/flashinfer
Length of output: 639
🏁 Script executed:
```bash
#!/bin/bash
# Read the full ampere collective_load file to check for architectural context
echo "=== Full ampere/collective/flat_collective_load.hpp ==="
cat csrc/flat/ampere/collective/flat_collective_load.hpp
echo -e "\n=== Check for any build guards or architecture checks ==="
grep -i "sm80\|ampere\|hopper\|sm90\|arch" csrc/flat/ampere/collective/flat_collective_load.hpp
```

Repository: flashinfer-ai/flashinfer
Length of output: 6182
Remove this file from csrc/flat/ampere/collective/ or add architecture guards.
This file in the ampere (SM80) directory imports cutlass/pipeline/sm90_pipeline.hpp on line 20 without any architecture guards. The SM90 pipeline is specific to Hopper (SM90+) architecture and should not be used in Ampere code without proper conditional compilation.
Either:
- Move this to the hopper directory (the correct implementation already exists at `csrc/flat/hopper/collective/flat_collective_load.hpp`), or
- Replace the SM90 pipeline with an architecture-appropriate pipeline for Ampere and add proper `#ifdef` guards if the ampere implementation is intentional.
🤖 Prompt for AI Agents
In csrc/flat/ampere/collective/flat_collective_load.hpp lines 1-20, this Ampere
directory file includes the Hopper-specific header
cutlass/pipeline/sm90_pipeline.hpp; remove or correct it. Either delete this
file from csrc/flat/ampere/collective/ if the implementation belongs in hopper,
or replace the sm90 include with the appropriate Ampere pipeline header and add
conditional compilation guards (e.g. #if defined(CUTLASS_SM80) ||
defined(CUTLASS_SM_80) / #elif defined(CUTLASS_SM90) ... / #endif or use an
equivalent project macro) so the SM90 pipeline is only included when building
for SM90/Hopper; alternatively move the file to csrc/flat/hopper/collective/ to
match the existing hopper implementation.
```cpp
#define DPRINTF_W(fmt, ...)                                                                        \
  if (IS_PRINT_BLOCK)                                                                              \
  printf("%s:%d [WG%d][W%d][T%-3d] " fmt, __FILE__, __LINE__, threadIdx.x / 128, threadIdx.x / 32, \
         threadIdx.x, ##__VA_ARGS__)
#define DPRINTF0_W(fmt, ...)                                                                       \
  if (IS_PRINT_BLOCK && threadIdx.x % 32 == 0)                                                     \
  printf("%s:%d [WG%d][W%d][T%-3d] " fmt, __FILE__, __LINE__, threadIdx.x / 128, threadIdx.x / 32, \
         threadIdx.x, ##__VA_ARGS__)
#define DPRINTF_WG(fmt, ...)                                                                       \
  if (IS_PRINT_BLOCK)                                                                              \
  printf("%s:%d [WG%d][W%d][T%-3d] " fmt, __FILE__, __LINE__, threadIdx.x / 128, threadIdx.x / 32, \
         threadIdx.x, ##__VA_ARGS__)
```
🧩 Analysis chain
🏁 Script executed:
```bash
cat -n csrc/flat/debug.hpp | sed -n '30,55p'
```

Repository: flashinfer-ai/flashinfer
Length of output: 1951
DPRINTF_W and DPRINTF_WG are identical and should differ.
Both macros have identical implementations despite their distinct naming. The "0" variants clearly establish a pattern where DPRINTF0_W filters at warp level (threadIdx.x % 32 == 0) and DPRINTF0_WG filters at warp-group level (threadIdx.x % 128 == 0). The non-"0" variants should follow the same convention: consider whether DPRINTF_WG should filter on warp-group boundaries (e.g., threadIdx.x % 128 == 0) or whether the current unconditional behavior is intentional.
🤖 Prompt for AI Agents
In csrc/flat/debug.hpp around lines 37 to 48, DPRINTF_W and DPRINTF_WG are
identical but should differ: keep DPRINTF_W as the current per-thread print
(only gated by IS_PRINT_BLOCK) and change DPRINTF_WG to only print on warp-group
boundaries by adding a condition like threadIdx.x % 128 == 0 in its if check
(i.e., if (IS_PRINT_BLOCK && threadIdx.x % 128 == 0)); keep the same printf
format and arguments to maintain consistent log output.
```cpp
#pragma once

#include "cute/tensor.hpp"
#include "cutlass/kernel_hardware_info.h"
```
🧩 Analysis chain
🏁 Script executed:
```bash
#!/bin/bash
# Check if kernel_hardware_info is used in this file
echo "=== Checking for kernel_hardware_info usage in flat_common.hpp ==="
rg -n 'kernel_hardware_info|KernelHardwareInfo' csrc/flat/hopper/collective/flat_common.hpp
echo ""
echo "=== Checking the file structure and includes ==="
head -30 csrc/flat/hopper/collective/flat_common.hpp
echo ""
echo "=== Checking for files that include flat_common.hpp ==="
rg -l 'flat_common\.hpp' --type=cpp --type=cu csrc/ | head -10
```

Repository: flashinfer-ai/flashinfer
Length of output: 1417
🏁 Script executed:
```bash
#!/bin/bash
# Check the entire file content
echo "=== Full content of flat_common.hpp ==="
cat csrc/flat/hopper/collective/flat_common.hpp
echo ""
echo "=== Files including flat_common.hpp ==="
rg -l 'flat_common\.hpp' csrc/ | head -20
echo ""
echo "=== Checking if kernel_hardware_info is used in files that include flat_common.hpp ==="
rg -l 'kernel_hardware_info' csrc/flat/hopper/collective/ | grep -v flat_common.hpp
```

Repository: flashinfer-ai/flashinfer
Length of output: 6884
Remove unused include.
The cutlass/kernel_hardware_info.h header is included but not used anywhere in this file. Remove it.
🤖 Prompt for AI Agents
In csrc/flat/hopper/collective/flat_common.hpp around line 19, remove the unused
include directive for "cutlass/kernel_hardware_info.h" — delete that include
line and ensure no other symbols from that header are referenced elsewhere in
this file; run a local build to confirm no missing dependencies after removal.
```cpp
if constexpr (sizeof(Element) == 1) {
  // 00 11 22 33 00 11 22 33  acc layout
  // 00 00 11 11 22 22 33 33  operand layout
  // BB AA AA BB AA BB BB AA  conflict-free exchange pattern
  // 16-bit exchange; so process two at a time potentially
  int tid = threadIdx.x % 4;
  auto values_u32 = recast<uint32_t>(operand);

  CUTE_UNROLL
  for (int n = 0; n < size<1>(values_u32); n++) {
    CUTE_UNROLL
    for (int k = 0; k < size<2>(values_u32); k++) {
      CUTE_UNROLL
      for (int ii = 0; ii < 8; ii += 4) {
        uint32_t values_tmp_0 = values_u32(ii / 2 + 0, n, k);
        uint32_t values_tmp_1 = values_u32(ii / 2 + 1, n, k);

        // step A:
        // t 1 v 0 -> t 0 v 1
        // t 2 v 0 -> t 1 v 0
        // t 0 v 1 -> t 2 v 0
        // t 3 v 1 -> t 3 v 1

        int v_to_send = tid == 1 || tid == 2 ? 0 : 1;
        int v_to_recv = v_to_send;
        int t_to_recv_from = (0x3021 >> (tid * 4)) & 0xF;

        uint32_t values_tmp_a = v_to_send == 0 ? values_tmp_0 : values_tmp_1;

        values_tmp_a = __shfl_sync(0xFFFFFFFF, values_tmp_a, t_to_recv_from, 4);

        // step B:
        // t 0 v 0 -> t 0 v 0
        // t 3 v 0 -> t 1 v 1
        // t 1 v 1 -> t 2 v 1
        // t 2 v 1 -> t 3 v 0

        v_to_send = 1 - v_to_send;
        v_to_recv = 1 - v_to_recv;
        t_to_recv_from = (0x2130 >> (tid * 4)) & 0xF;

        uint32_t values_tmp_b = v_to_send == 0 ? values_tmp_0 : values_tmp_1;

        values_tmp_b = __shfl_sync(0xFFFFFFFF, values_tmp_b, t_to_recv_from, 4);

        values_u32(ii / 2 + 0, n, k) =
            __byte_perm(values_tmp_a, values_tmp_b, v_to_send == 0 ? 0x1054 : 0x5410);
        values_u32(ii / 2 + 1, n, k) =
            __byte_perm(values_tmp_a, values_tmp_b, v_to_send == 0 ? 0x3276 : 0x7632);
      }
    }
  }
}
```
🛠️ Refactor suggestion | 🟠 Major
Document the purpose and assumptions of the 1-byte element shuffle logic.
This code implements a complex warp-shuffle-based data exchange for 1-byte elements, but lacks critical documentation:
- Why is this exchange needed? The comments describe the exchange pattern but not the underlying problem (e.g., avoiding bank conflicts, matching expected layout, etc.)
- Thread grouping assumption: Line 114 uses `threadIdx.x % 4`, assuming threads are processed in groups of 4. This assumption should be documented.
- Magic constants: The hex values `0x3021`, `0x2130`, `0x1054`, `0x5410`, `0x3276`, `0x7632` encode specific patterns. Consider:
  - Adding named constants with explanatory comments
  - Documenting the bit layout of these values
  - Explaining the relationship between the values and the exchange pattern
- Layout transformation: Document the input and output layouts explicitly (e.g., "transforms from ACC layout `00 11 22 33 00 11 22 33` to operand layout `00 00 11 11 22 22 33 33`")
🤖 Prompt for AI Agents
In csrc/flat/hopper/collective/flat_common.hpp around lines 109-161, the 1-byte
element warp-shuffle exchange logic lacks documentation for intent, assumptions
and magic constants; update the code by adding a brief header comment that
states the purpose (why this exchange is needed and the layout transformation
performed, e.g., transforms ACC layout "00 11 22 33 00 11 22 33" to operand
layout "00 00 11 11 22 22 33 33"), explicitly document the thread grouping
assumption (that threadIdx.x is grouped modulo 4 and why), and replace inline
hex magic numbers with named constants and short comments explaining their
bit/byte encoding and how they map to the described exchange steps (A/B) and
shuffle masks so future readers can reason about the pattern quickly.
```cpp
int lane_predicate = cute::elect_one_sync();
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
```
Unused variable block_rank_in_cluster.
block_rank_in_cluster is computed but never used in the function. Remove it or add a comment if it's reserved for future cluster support.
🔎 Suggested fix
```diff
  int lane_predicate = cute::elect_one_sync();
- uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
```
🤖 Prompt for AI Agents
In csrc/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp around
lines 249-250, the variable block_rank_in_cluster is computed but never used;
either remove the declaration entirely or explicitly mark it as intentionally
unused (e.g., add a comment "reserved for future cluster support" and a
(void)block_rank_in_cluster; or use an attribute/pragma to suppress unused
warnings) so the compiler doesn’t emit an unused-variable warning and intent is
clear.
```python
def gen_gdn_prefill_sm90_module() -> JitSpec:
    return gen_jit_spec(
        name="gdn_prefill_launcher",
        sources=[
            jit_env.FLASHINFER_CSRC_DIR / "gdn_prefill_launcher.cu",
            jit_env.FLASHINFER_CSRC_DIR
            / "flat"
            / "prefill"
            / "prefill_kernel_delta_rule_sm90.cu",
        ],
        extra_cuda_cflags=sm90a_nvcc_flags + ["-DFLAT_SM90A_ENABLED", "-std=c++20"],
        extra_include_paths=[jit_env.FLASHINFER_CSRC_DIR],
    )
```
🛠️ Refactor suggestion | 🟠 Major
Specify supported GPU architectures for the SM90 module.
The JitSpec should include supported_major_versions to restrict compilation to SM90 and newer architectures. This prevents attempting to compile SM90-specific code on incompatible GPUs.
🔎 Proposed fix

```diff
 def gen_gdn_prefill_sm90_module() -> JitSpec:
     return gen_jit_spec(
         name="gdn_prefill_launcher",
         sources=[
             jit_env.FLASHINFER_CSRC_DIR / "gdn_prefill_launcher.cu",
             jit_env.FLASHINFER_CSRC_DIR
             / "flat"
             / "prefill"
             / "prefill_kernel_delta_rule_sm90.cu",
         ],
         extra_cuda_cflags=sm90a_nvcc_flags + ["-DFLAT_SM90A_ENABLED", "-std=c++20"],
         extra_include_paths=[jit_env.FLASHINFER_CSRC_DIR],
+        supported_major_versions=[9, 10, 11, 12],
     )
```

Based on coding guidelines for flashinfer/jit/**/*.py.
🤖 Prompt for AI Agents
In flashinfer/jit/gdn.py around lines 25 to 37, the JitSpec for the SM90 module
is missing a supported_major_versions constraint; update the gen_jit_spec call
to include supported_major_versions set to only SM90 (e.g.
supported_major_versions=[90]) so the launcher is only compiled on SM90-or-newer
GPUs, and place this parameter alongside the existing extra_cuda_cflags and
extra_include_paths arguments.
```python
k = k.repeat_interleave(num_qo_heads // num_kv_heads, dim=1)
v = v.repeat_interleave(num_qo_heads // num_kv_heads, dim=1)

KVs = []  # FIXME: kernel debug only
```
🛠️ Refactor suggestion | 🟠 Major
Address the FIXME: Remove or make debug output optional.
Line 154 has a FIXME comment indicating KVs is for kernel debugging only. This debug list accumulates state for every block (line 238) and is returned as part of the function output (line 245), which could cause memory issues for long sequences.
Consider either:
- Removing the debug code if no longer needed, or
- Making it optional via a parameter (e.g., `return_debug_info=False`)
🔎 Proposed fix to make debug output optional
```diff
 @torch.inference_mode
 def blockwise_linear_attention(
     q: torch.Tensor,  # [total_seq_len, num_qo_heads, head_size]
     k: torch.Tensor,  # [total_seq_len, num_kv_heads, head_size]
     v: torch.Tensor,  # [total_seq_len, num_kv_heads, head_size]
     seq_lens: list[int],  # sequence length for each sequence
     block_size: int = 32,
     scale_factor=1.0,
     decay_factor: float
     | torch.Tensor = 1.0,  # float or tensor with num_elems == num_qo_heads
     decay_exponent_offset=0,
     kv_dtype: torch.dtype = torch.float32,
+    return_debug_info: bool = False,
 ) -> torch.Tensor:
     num_qo_heads = q.size(1)
     head_size = q.size(2)
     num_kv_heads = k.size(1)
     if scale_factor != 1.0:
         k = k * scale_factor
     if isinstance(decay_factor, float):
         decay_factor = torch.ones(num_qo_heads) * decay_factor
     decay_factor = decay_factor.to(q.device)
     assert decay_factor.numel() == num_qo_heads
     decay_factor = decay_factor.reshape(num_qo_heads, 1, 1)
     k = k.repeat_interleave(num_qo_heads // num_kv_heads, dim=1)
     v = v.repeat_interleave(num_qo_heads // num_kv_heads, dim=1)
-    KVs = []  # FIXME: kernel debug only
+    KVs = [] if return_debug_info else None
     kv = torch.zeros(
         (len(seq_lens), num_qo_heads, head_size, head_size),
         dtype=kv_dtype,
         device=q.device,
     )
     output = torch.zeros_like(q)
     # ... rest of the function ...
-    KVs.append(carried_kv.clone())
+    if return_debug_info:
+        KVs.append(carried_kv.clone())
     blk_offset += block_size
     # print(kv.shape, carried_kv.shape)
     kv[seq_idx, :, :] = carried_kv
-    return output, kv, KVs
+    if return_debug_info:
+        return output, kv, KVs
+    else:
+        return output, kv
```
+ return output, kvCommittable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In tests/gdn/reference_delta_rule.py around line 154 (and referencing behavior
at lines 238 and 245), the KVs list is a persistent debug accumulator marked
FIXME and can grow unbounded; make debug output optional by adding a new
parameter (e.g., return_debug_info=False) to the function signature, initialize
and append to KVs only when return_debug_info is True, and conditionally include
KVs in the function return (when True return the existing result plus debug
data, otherwise return the original result unchanged); update any call
sites/tests to use the new parameter as needed.
```python
def identity_add_strict_lower_diagonal(m: torch.Tensor):
    SIZE = m.size(-1)
    assert m.size(-2) == SIZE
    with torch.device(m.device):
        m = m.clone()
        mask = torch.arange(SIZE).unsqueeze(1) <= torch.arange(SIZE)
        m[:, mask] = 0.0
        # m[mask.unsqueeze(0)] = 0.0
        m = m + torch.eye(SIZE).unsqueeze(0)
    return m
```
🛠️ Refactor suggestion | 🟠 Major
Fix the misleading function name.
The function identity_add_strict_lower_diagonal actually zeros the upper triangle (not the lower). Line 332 creates a mask where row <= col, which selects the upper triangular elements including the diagonal. Line 333 then zeros these elements.
The function effectively computes: strict_lower_triangle(m) + identity, so the name is backwards from the implementation.
Consider renaming to identity_add_strict_lower_triangle or keep_strict_lower_plus_identity for clarity.
🔎 Proposed fix
```diff
-def identity_add_strict_lower_diagonal(m: torch.Tensor):
+def identity_add_strict_lower_triangle(m: torch.Tensor):
     SIZE = m.size(-1)
     assert m.size(-2) == SIZE
     with torch.device(m.device):
         m = m.clone()
-        mask = torch.arange(SIZE).unsqueeze(1) <= torch.arange(SIZE)
+        # Keep only strict lower triangle (row > col), then add identity
+        mask = torch.arange(SIZE).unsqueeze(1) <= torch.arange(SIZE)  # upper triangle incl. diagonal
         m[:, mask] = 0.0
-        # m[mask.unsqueeze(0)] = 0.0
         m = m + torch.eye(SIZE).unsqueeze(0)
     return m
```

Then update the caller on line 446:

```diff
-    IKK = identity_add_strict_lower_diagonal(
+    IKK = identity_add_strict_lower_triangle(
         beta_HS1 * torch.exp(Gamma_HSS) * matmul(k_HSK, k_HSK.transpose(-2, -1))
     )  # NOTE: beta scale row-wise
```
) # NOTE: beta scale row-wiseCommittable suggestion skipped: line range outside the PR's diff.
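A quick way to see the orientation issue: the mask `row <= col` selects the upper triangle including the diagonal, so zeroing it and adding `I` yields `strict_lower_triangle(m) + I`. The check below uses NumPy as a stand-in for PyTorch; the masking logic is identical:

```python
import numpy as np

SIZE = 4
row = np.arange(SIZE)[:, None]
col = np.arange(SIZE)[None, :]
mask = row <= col  # True on and above the diagonal: the UPPER triangle
assert np.array_equal(mask, np.triu(np.ones((SIZE, SIZE), dtype=bool)))

m = np.arange(SIZE * SIZE, dtype=np.float64).reshape(SIZE, SIZE)
out = m.copy()
out[mask] = 0.0           # zero the upper triangle, diagonal included
out = out + np.eye(SIZE)  # then add the identity back
# What survives is the STRICT lower triangle of m, plus I.
assert np.array_equal(out, np.tril(m, -1) + np.eye(SIZE))
```

So the function computes "strict lower triangle plus identity", which is why the suggested rename reads correctly.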
|
@guangyunh-nv Is there blackwell sm100 support in the plan? 🎉 |
|
There is, but not by me. ETA is unclear tho. |
|
Do we have plans to support bfloat16 for parameters like input_state, g, and beta? Currently, they seem limited to float32, but in vLLM they are bf16. Is this a deliberate choice?

- `flashinfer/csrc/gdn_prefill_launcher.cu`, lines 134 to 136 (at `e4dee98`)
- `flashinfer/csrc/gdn_prefill_launcher.cu`, line 146 (at `e4dee98`)
- `flashinfer/csrc/gdn_prefill_launcher.cu`, line 102 (at `e4dee98`)
|
Hi @ZJY0516, I don't think that's a hard constraint; we can make it more flexible.
|
It is a minor change. The problem is that I am not sure it is wise to store the state in bf16. For prefilling, the internal states are all in FP32, which makes the long-range accumulation much more accurate. But for decoding, accumulating on BF16 at every recurrent step might impose severe roundoff error. I think this should be verified from the algorithm side first.
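The roundoff concern can be sketched numerically. Below, bfloat16 is emulated in NumPy by truncating a float32 to its top 16 bits (a crude stand-in for real bf16 rounding), and a long gated recurrence is run once in fp32 and once with the state rounded to bf16 each step. Everything here is illustrative, not the kernel's actual numerics:

```python
import numpy as np

def to_bf16(x):
    # Crude bfloat16 emulation: keep the top 16 bits of a float32
    # (sign + 8-bit exponent + 7-bit mantissa), dropping the rest.
    b = np.asarray(x, dtype=np.float32).view(np.uint32)
    return (b & np.uint32(0xFFFF0000)).view(np.float32)

rng = np.random.default_rng(0)
steps = 4096
decay = np.float32(0.999)  # per-step gate close to 1, as in long decodes
deltas = (0.01 * rng.standard_normal(steps)).astype(np.float32)

s_fp32 = np.float32(0.0)
s_bf16 = np.float32(0.0)
for d in deltas:
    s_fp32 = decay * s_fp32 + d           # accumulate the state in fp32
    s_bf16 = to_bf16(decay * s_bf16 + d)  # round the state to bf16 each step

rel_err = float(abs(s_bf16 - s_fp32)) / max(float(abs(s_fp32)), 1e-6)
# rel_err is typically orders of magnitude above fp32 roundoff, because the
# ~2^-8 per-step bf16 error compounds over the whole recurrence.
```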
…ing. (#2422) <!-- .github/pull_request_template.md --> ## 📌 Description This PR implements these features: 1. accelerate hopper's gdn prefill compilation time by split compilation 2. fix the docstring of gdn prefill kernel, instead of [N, H, K, V], it expects [N, H, V, K] ## 🔍 Related Issues #2276 ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes cc @guangyunh-nv <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes * **New Features** * Enhanced JIT module generation for GDN prefill kernels with template-driven compilation and separate kernel instantiation. * **Improvements** * JIT specification now intelligently handles C++ standard flags, applying defaults only when not already specified. * **Documentation** * Clarified final state memory layout description for GDN prefill operations. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
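To make the layout fix above concrete: a caller holding the final state in [N, H, K, V] order would need to swap the last two axes before matching the kernel's documented [N, H, V, K] expectation. The sketch below uses NumPy and hypothetical sizes; it illustrates only the index swap, not the actual kernel API:

```python
import numpy as np

# Hypothetical sizes, chosen only to illustrate the axis swap.
N, H, K, V = 2, 4, 16, 32
state_nhkv = np.arange(N * H * K * V, dtype=np.float32).reshape(N, H, K, V)

# [N, H, K, V] -> [N, H, V, K]: transpose the last two axes and re-pack.
state_nhvk = np.ascontiguousarray(state_nhkv.transpose(0, 1, 3, 2))

assert state_nhvk.shape == (N, H, V, K)
# Element (k, v) in the old layout is element (v, k) in the new one.
assert state_nhvk[0, 0, 3, 5] == state_nhkv[0, 0, 5, 3]
```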
📌 Description
This PR adds an implementation of the Gated Delta Rule (also known as Gated Delta Net) on the Hopper architecture to better support Qwen-next-like architectures.
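For readers new to the operator: per head, the recurrent (non-chunked) form of the gated delta rule maintains a [K, V] state that is decayed by a gate and corrected by a rank-1 "delta" update each step. The sketch below is a minimal NumPy recurrence under common conventions (unit-norm keys, scalar log-gate); the names and exact gating convention are assumptions, not this kernel's API:

```python
import numpy as np

def gated_delta_rule_step(S, k, v, g, beta):
    # One recurrent step of the gated delta rule (sketch, not the kernel API).
    # S: [K, V] state, k: [K] unit-norm key, v: [V] value,
    # g: scalar log-decay gate, beta: scalar write strength.
    S = np.exp(g) * S                      # gated decay of the state
    delta = v - k @ S                      # prediction error ("delta")
    return S + beta * np.outer(k, delta)   # rank-1 corrective write

K, V = 8, 8
rng = np.random.default_rng(0)
S = np.zeros((K, V), dtype=np.float32)
for _ in range(16):
    k = rng.standard_normal(K).astype(np.float32)
    k /= np.linalg.norm(k)                 # unit key keeps the update contractive
    v = rng.standard_normal(V).astype(np.float32)
    S = gated_delta_rule_step(S, k, v, g=-0.01, beta=1.0)

# With beta == 1 and a unit key, reading the state back with the last key
# reproduces the last value exactly (up to fp32 roundoff).
assert np.allclose(k @ S, v, atol=1e-4)
```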
🔍 Related Issues
#1690
🚀 Pull Request Checklist
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed.
- All tests are passing (`unittest`, etc.).

Reviewer Notes
Thanks @jiahanc for initiating the kernel integration and implementing the API.
Summary by CodeRabbit
New Features
Benchmarks & Tests