refactor: update fa3 codebase and fix hopper unittest [part 1] #2111
Conversation
Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds page-level K/V stride and page_size fields across host and device params, enforces K/V stride and stride_n consistency at runtime in paged paths, removes the block-sparse→vector-sparse conversion and its helpers/tests, refactors mainloops to use page-table/manual page-based K/V loads, and adds FP8 ragged/paged variants and tests.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Py as Python tests/bench
    participant Wrap as Wrapper (flashinfer/prefill.py / sparse.py)
    participant Host as Host planner (csrc/*, headers)
    participant Kernel as CUDA device mainloop
    Note over Py,Kernel: OLD — vector-sparse flow (removed)
    Py->>Wrap: reset_workspace_buffer(..., vector_sparse_buffers)
    Wrap->>Host: plan/run(..., vector_sparse_buffers)
    Host->>Kernel: invoke with BlockSparseIndexedGather tensors
    Kernel->>Kernel: block-sparse gather loads
    Note over Py,Kernel: NEW — paged/ragged + FP8
    Py->>Wrap: reset_workspace_buffer(..., paged_kv_indptr, paged_kv_indices, o_data_type?, fp8_scales?)
    Wrap->>Host: plan/run(..., paged_kv_indptr, paged_kv_indices, o_data_type, fp8_scales)
    Host->>Kernel: invoke with (K_ptr, V_ptr, kv_indices, k_page_stride, v_page_stride, k_stride_n, v_stride_n, page_size, ...)
    Kernel->>Kernel: compute page_idx via divmod(page_size)
    Kernel->>Kernel: load_kv_tile / cp_async_zfill using page strides and stride_n
    Kernel->>Kernel: proceed to MMA/epilogue (sync)
```
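The `divmod(page_size)` addressing step in the new flow can be modeled in plain Python (an illustrative host-side sketch, not the kernel code; the device performs the same computation with `uint_fastdiv`):

```python
def kv_token_offset(kv_idx, page_size, kv_indices, page_stride, stride_n):
    """Map a flat KV token index to an element offset into the paged
    K/V pool: page-table lookup, then in-page addressing."""
    page_iter, entry_idx = divmod(kv_idx, page_size)
    page_idx = kv_indices[page_iter]  # gather through the page table
    return page_idx * page_stride + entry_idx * stride_n

# Example: page_size=4 with three scattered pages [7, 2, 5]
off = kv_token_offset(5, 4, [7, 2, 5], page_stride=1024, stride_n=256)  # → 2304
```

Token 5 lands in the second table entry (page 2) at in-page position 1, so the offset is `2 * 1024 + 1 * 256`.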
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60–90 minutes
Pre-merge checks: ❌ 1 failed (warning), ✅ 4 passed.
Summary of Changes (Gemini Code Assist)

This pull request focuses on enhancing the efficiency and correctness of the Flash Attention v3 (FA3) implementation, particularly for paged key-value (KV) caches with page sizes greater than one. By integrating page offset calculations directly into the kernel and optimizing KV offset handling with prefetching and shuffling, the codebase becomes more streamlined and performant. A critical bug affecting Hopper unittests has also been resolved, ensuring robust operation on the target architecture. These changes collectively contribute to a more optimized and reliable sparse attention mechanism.
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
|
/bot run
Code Review
This pull request is a significant refactoring of the FA3 codebase. It removes the standalone block_sparse_indices_to_vector_sparse_offsets function and moves the page offset calculation directly into the CUDA kernel, which is a great simplification. The changes also include an optimization for kv_offset calculation using prefetching and shuffling, which should improve performance. The code removal across C++, Python, and header files is consistent and clean. I've found a couple of minor areas for code improvement to reduce redundancy, but overall the changes look solid and well-implemented.
```cpp
int d_idx = get<1>(coord);
int kv_idx = kv_base_idx + kv_offset;
bool guard = kv_idx < kv_len && kv_offset < valid_tile_size;
```
The guard condition can be simplified. The check `kv_idx < kv_len` is redundant when `use_predicate` is true, as it's already implied by `kv_offset < valid_tile_size`. When `use_predicate` is false, `valid_tile_size` is `CTA_KV`, and `kv_offset` is always less than `CTA_KV`, so the guard is not needed for non-last tiles anyway. You can simplify this to just `kv_offset < valid_tile_size`.
```cpp
bool guard = kv_offset < valid_tile_size;
```
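The claimed equivalence can be checked exhaustively with a small Python model (assuming, as the comment states, that `valid_tile_size` is `min(kv_len - tile_base, CTA_KV)` when predication is on, and that `use_predicate` is only ever false on full, non-last tiles):

```python
CTA_KV = 96  # tile size used in the snippet's comments

def guards_match(kv_len, tile_base, use_predicate):
    """Compare the full guard against the simplified one for every
    kv_offset in a tile."""
    valid_tile_size = min(kv_len - tile_base, CTA_KV) if use_predicate else CTA_KV
    for kv_offset in range(CTA_KV):
        kv_idx = tile_base + kv_offset
        full = (kv_idx < kv_len) and (kv_offset < valid_tile_size)
        simplified = kv_offset < valid_tile_size
        if full != simplified:
            return False
    return True

# Check all tiles of all lengths up to two tiles; use_predicate=False is
# only exercised on full (non-last) tiles, matching the kernel's usage.
ok = all(
    guards_match(kv_len, tb, use_predicate=pred)
    for kv_len in range(1, 2 * CTA_KV + 1)
    for tb in range(0, kv_len, CTA_KV)
    for pred in ([True] if tb + CTA_KV > kv_len else [True, False])
)
```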
Actionable comments posted: 2
🧹 Nitpick comments (2)
flashinfer/prefill.py (1)
2109-2156: Paged KV run argument rewiring is reasonable; verify trtllm cum_seq_lens_kv semantics

Using `_paged_kv_indptr_buf`/`_paged_kv_indices_buf` directly in `run_args` keeps the Python wrapper aligned with the new paged-KV FFI signature, and `_qo_indptr_buf` is a natural fit for `cum_seq_lens_q`. The only subtle point is that `_paged_kv_indptr_buf` is in units of pages, while trtllm paged attention APIs traditionally expect `cum_seq_lens_kv` in tokens; if the trtllm-gen backend actually consumes those trailing args as cum-token lengths, it may need `cumsum(seq_lens)` instead of raw page indptr. Worth double-checking against the current trtllm kernel contract.

tests/attention/test_batch_prefill_kernels.py (1)
147-157: Good coverage of preallocated LSE path; consider also checking LSE values

Using `lse_buffer = torch.empty_like(lse)` and rerunning with `out=o_buffer, lse=lse_buffer` now exercises the buffered LSE write path, which should catch the Hopper regression. To fully validate it, you may also want to assert `torch.testing.assert_close(lse, lse_buffer, ...)` alongside the existing `o` vs `o_buffer` check.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (15)

- csrc/batch_prefill_sm90.cu (1 hunks)
- csrc/batch_prefill_sm90_customize_config.jinja (1 hunks)
- csrc/flashinfer_page_binding.cu (0 hunks)
- csrc/page.cu (0 hunks)
- flashinfer/page.py (0 hunks)
- flashinfer/prefill.py (3 hunks)
- flashinfer/sparse.py (2 hunks)
- include/flashinfer/attention/hopper/default_params.cuh (1 hunks)
- include/flashinfer/attention/hopper/prefill_sm90.cuh (1 hunks)
- include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (8 hunks)
- include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh (1 hunks)
- include/flashinfer/attention/hopper/sparse_mainloop.cuh (8 hunks)
- include/flashinfer/page.cuh (0 hunks)
- tests/attention/test_batch_prefill_kernels.py (1 hunks)
- tests/utils/test_block_sparse_indices_to_vector_sparse_offsets.py (0 hunks)
💤 Files with no reviewable changes (5)
- csrc/flashinfer_page_binding.cu
- csrc/page.cu
- flashinfer/page.py
- tests/utils/test_block_sparse_indices_to_vector_sparse_offsets.py
- include/flashinfer/page.cuh
🧰 Additional context used
🧬 Code graph analysis (3)
csrc/batch_prefill_sm90.cu (1)
- csrc/trtllm_fused_moe_kernel_launcher.cu (2): `TVM_FFI_ICHECK_EQ` (167-171), `TVM_FFI_ICHECK_EQ` (283-286)

tests/attention/test_batch_prefill_kernels.py (1)
- flashinfer/prefill.py (6): `run` (1924-1936), `run` (1939-1951), `run` (1953-2166), `run` (2768-2778), `run` (2781-2791), `run` (2793-2939)

flashinfer/prefill.py (1)
- flashinfer/page.py (1): `get_seq_lens` (176-199)
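As context for the `get_seq_lens` dependency above: the derivation from page-level metadata to token-level lengths can be sketched as follows (a simplified model; the real helper in flashinfer/page.py operates on tensors and its exact signature may differ):

```python
def seq_lens_from_paged_metadata(kv_indptr, last_page_len, page_size):
    """Token count per request: all pages but the last are full; the
    last page contributes last_page_len tokens."""
    lens = []
    for i in range(len(kv_indptr) - 1):
        num_pages = kv_indptr[i + 1] - kv_indptr[i]
        lens.append(0 if num_pages == 0
                    else (num_pages - 1) * page_size + last_page_len[i])
    return lens

# Two requests: 3 pages ending with 2 tokens, 2 pages ending with 4 tokens
lens = seq_lens_from_paged_metadata([0, 3, 5], [2, 4], page_size=4)  # [10, 8]
```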
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (8)
include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (1)

102-108: LGTM: Correct handling of separate K and V strides. This implementation correctly supports different memory layouts for K and V:

- Parameterized design: the `load_kv_tile` lambda (lines 232-267) accepts `stride_n` and `page_stride` as parameters rather than hardcoding them
- Separate calls: K and V loads pass their respective strides:
  - K: `load_kv_tile(k_base_ptr, k_stride_n, k_page_stride, ...)` (line 275)
  - V: `load_kv_tile(v_base_ptr, v_stride_n, v_page_stride, ...)` (line 298)
- Flexible addressing: line 259 computes offsets using the passed-in parameters

This is the correct pattern for page-based sparse loading and avoids the stride assumption issue present in `sparse_mainloop.cuh`.

Also applies to: 118-124, 232-267
include/flashinfer/attention/hopper/sparse_mainloop.cuh (1)

110-112: Stride equality is already validated on the host side; v_page_stride is intentionally passed through for API consistency.

The `v_page_stride` parameter, while unused in the non-quantized `sparse_mainloop.cuh` kernel, is not a bug. An assertion in `csrc/batch_prefill_sm90.cu` line 235 validates that K and V page strides are equal at runtime, and the comment in the sparse mainloop (line 281) explicitly documents this assumption. The `prefetch_kv_offset` lambda correctly reuses the same offset computation for both K and V loads.

The parameter exists for API consistency with the quantized variant (`mainloop_sparse_load.cuh`), which does use `v_page_stride` separately. If API unification across quantized and non-quantized paths is intentional, no action is needed.

include/flashinfer/attention/hopper/default_params.cuh (1)
157-160: k_page_stride / v_page_stride fields look consistent

Adding explicit page-stride fields after `nnz_qo` matches the other Hopper paged params structs and keeps types/ordering coherent with the new sparse mainloop arguments.

include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh (1)
337-344: FP8 sparse mainloop argument rewiring looks correct

Switching K/V from `get_gmem_layout` to explicit `{k_stride_n, k_page_stride, v_stride_n, v_page_stride, kv_indices, page_size}` matches the updated sparse mainloop API and keeps Q/O layout handling unchanged.

flashinfer/prefill.py (1)
36-36: Importing `get_seq_lens` is appropriate

This import matches the later use of `get_seq_lens` in `BatchPrefillWithPagedKVCacheWrapper.plan` to derive KV sequence lengths from paged metadata.

include/flashinfer/attention/hopper/prefill_sm90.cuh (1)
382-386: Sparse prefill mainloop now correctly receives KV indices and paging metadata

Passing `kv_indices`, `window_left`, `k_page_stride`, `v_page_stride`, and `page_size` into `SparseCollectiveMainloop::to_underlying_arguments` lines up with the new paged sparse mainloop contract and keeps Q/K/V layouts unchanged.

csrc/batch_prefill_sm90_customize_config.jinja (1)
107-111: PagedParams gains explicit K/V page strides in the right place

Adding `k_page_stride`/`v_page_stride` after `nnz_qo` keeps this JIT-generated PagedParams struct aligned with the Hopper default params and with how batch_prefill_sm90.cu now fills these fields from `paged_{k,v}_cache.stride(0)`.

csrc/batch_prefill_sm90.cu (1)
221-238: Page-stride wiring and K/V stride consistency checks make sense

Recording `k_page_stride`/`v_page_stride` from `stride(0)` in both layouts and then asserting that K/V share the same page stride and `stride_n` is a good guardrail for the sparse paged mainloop; it will surface mis-laid-out KV caches early with clear error messages rather than letting the kernel access mismatched layouts.
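A host-side model of that guardrail in plain Python (illustrative only — the actual checks are `TVM_FFI_ICHECK_EQ` assertions in `batch_prefill_sm90.cu`; the row-major NHD-style layout assumed here is for the sketch):

```python
def row_major_strides(shape):
    """Element strides of a contiguous row-major array of this shape."""
    strides, acc = [], 1
    for dim in reversed(shape):
        strides.append(acc)
        acc *= dim
    return tuple(reversed(strides))

def check_kv_layout(k_shape, v_shape):
    """Reject K/V caches whose page stride (stride 0) or per-token
    stride (stride 1) differ, mirroring the host-side guardrail."""
    ks, vs = row_major_strides(k_shape), row_major_strides(v_shape)
    if ks[0] != vs[0]:
        raise ValueError("K and V page strides differ")
    if ks[1] != vs[1]:
        raise ValueError("K and V per-token strides (stride_n) differ")
    return True

# (num_pages, page_size, num_heads, head_dim) — identical layouts pass
check_kv_layout((8, 4, 2, 64), (8, 4, 2, 64))
```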
```cuda
int64_t my_kv_offset[2];  // Rolling buffer: page_idx * page_stride + entry_idx * stride_n

// Group organization based on partition strategy
constexpr int NUM_KV_PER_ITER = decltype(size<1>(tKcK))::value;   // e.g., 12
constexpr int KV_STRIDE = CTA_KV / NUM_KV_PER_ITER;               // 96/12 = 8
constexpr int NUM_GROUPS = KV_STRIDE;                             // 8 groups (one per lane)
constexpr int THREADS_PER_GROUP = NUM_COPY_THREADS / NUM_GROUPS;  // 128/8 = 16
constexpr int NUM_ITERS_PER_GROUP = NUM_KV_PER_ITER;              // 12 iterations per group

int group_id = thread_idx / THREADS_PER_GROUP;         // 0-7
int thread_in_group = thread_idx % THREADS_PER_GROUP;  // 0-15

// Prefetch: compute page_idx * page_stride + entry_idx * stride_n
// NOTE: Assumes K and V have same strides (asserted on host side)
auto prefetch_kv_offset = [&](int kv_tile_idx, bool use_predicate) {
  int kv_base_idx = kv_tile_idx * CTA_KV;
  int buf_idx = kv_tile_idx % 2;

  int kv_idx_read = kv_base_idx + group_id + thread_in_group * KV_STRIDE;
  bool valid_read =
      thread_in_group < NUM_ITERS_PER_GROUP && (!use_predicate || kv_idx_read < kv_len);

  if (valid_read) {
    // Use divmod to find page and offset within page
    uint32_t page_iter, entry_idx;
    mainloop_params.page_size.divmod(kv_idx_read, page_iter, entry_idx);
    IdType page_idx = kv_indices_ptr[page_iter];
    // Pre-compute: page_idx * page_stride + entry_idx * stride_n
    my_kv_offset[buf_idx] = page_idx * k_page_stride + entry_idx * k_stride_n;
  } else {
    my_kv_offset[buf_idx] = 0;
  }
};
```
Prefetch logic assumes K and V have identical strides.

The `prefetch_kv_offset` lambda computes `my_kv_offset` using only K strides (`k_page_stride` and `k_stride_n` on line 296), but this offset is later reused for both K and V loads in `load_kv_with_gather`. This hardcodes the assumption that K and V have identical memory layouts.

Compare with `mainloop_sparse_load.cuh` (lines 232-267), which correctly uses separate stride parameters in its `load_kv_tile` lambda, allowing K and V to have different layouts.

Consider refactoring to either:

- Option 1: Compute separate offsets for K and V if they can differ
- Option 2: Use a single set of stride parameters if layouts must be identical
🤖 Prompt for AI Agents
In include/flashinfer/attention/hopper/sparse_mainloop.cuh around lines 268-300,
the prefetch lambda computes my_kv_offset using only K strides but the same
offset is later used for both K and V loads, incorrectly assuming identical K/V
layouts; fix by computing distinct offsets for K and V (or enforce
identical-layout at compile/runtime). Update the lambda to accept/use separate
stride parameters (e.g., k_page_stride/k_stride_n and v_page_stride/v_stride_n)
and write into two rolling buffers (my_kv_offset_k[2] and my_kv_offset_v[2]) so
load_kv_with_gather can use the correct offset for each tensor, or alternatively
add a clear static_assert/runtime check and comment that K and V must share
strides and keep single offset.
```cuda
auto load_kv_with_gather = [&](auto&& tXsX, auto&& tXcX, DTypeKV* base_ptr, int kv_tile_idx,
                               int stage_idx, bool use_predicate) {
  using Vec = AlignmentTypeKV;
  constexpr int VecSize = sizeof(Vec) / sizeof(DTypeKV);

  int kv_base_idx = kv_tile_idx * CTA_KV;
  int buf_idx = kv_tile_idx % 2;

  auto dst = recast<Vec>(flatten(tXsX(_, _, _, stage_idx)));
  auto c = flatten(tXcX(_, _, _, kv_tile_idx));

  constexpr unsigned FULL_MASK = 0xffffffff;

  // Load using FA3-style shuffle with pre-computed offsets
  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < size(dst); ++i) {
    auto coord = c(VecSize * i);
    int kv_offset = get<0>(coord);
    int d_idx = get<1>(coord);
    int kv_idx = kv_base_idx + kv_offset;
    bool guard = !use_predicate || kv_idx < kv_len;

    // Shuffle the pre-computed offset (page_idx * page_stride + entry_idx * stride_n)
    int src_thread = group_id * THREADS_PER_GROUP + kv_offset / KV_STRIDE;
    int64_t base_offset = __shfl_sync(FULL_MASK, my_kv_offset[buf_idx], src_thread);

    // Final address: base_ptr + base_offset + d_idx
    // where base_offset = page_idx * page_stride + entry_idx * stride_n
    Vec const* src_ptr = reinterpret_cast<Vec const*>(base_ptr + base_offset + d_idx);
    cutlass::arch::cp_async_zfill<sizeof(Vec), cutlass::arch::CacheOperation::Global>(
        &dst(i), src_ptr, guard);
  }
};
```
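The producer/consumer indexing in the two lambdas can be checked with a small host-side model, using the constants from the code comments (CTA_KV=96, NUM_KV_PER_ITER=12, 128 copy threads). The `kv_offset % KV_STRIDE == group_id` partitioning assumed below is inferred from the `src_thread` formula, not stated in the diff:

```python
NUM_COPY_THREADS, CTA_KV, NUM_KV_PER_ITER = 128, 96, 12
KV_STRIDE = CTA_KV // NUM_KV_PER_ITER               # 8
THREADS_PER_GROUP = NUM_COPY_THREADS // KV_STRIDE   # 16

# Producer side (prefetch_kv_offset): which KV row each thread precomputes
holder = {}
for tid in range(NUM_COPY_THREADS):
    group_id, thread_in_group = divmod(tid, THREADS_PER_GROUP)
    if thread_in_group < NUM_KV_PER_ITER:  # the valid_read condition
        holder[group_id + thread_in_group * KV_STRIDE] = tid

# Consumer side (load_kv_with_gather): lane to shuffle the offset from
def src_thread(kv_offset):
    group_id = kv_offset % KV_STRIDE  # assumed coordinate partitioning
    return group_id * THREADS_PER_GROUP + kv_offset // KV_STRIDE
```

Under these constants every row 0..95 is precomputed by exactly one lane, and `src_thread` recovers that lane, so a single `__shfl_sync` per vector retrieves the right offset.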
load_kv_with_gather reuses K offsets for V loads.

The `load_kv_with_gather` helper shuffles and reuses `my_kv_offset` (computed using K strides in `prefetch_kv_offset`) for both K and V loads:

- Line 341: `load_kv_with_gather(..., K_ptr_base, ...)`
- Line 367: `load_kv_with_gather(..., V_ptr_base, ...)`

This shuffle-based optimization is effective for performance but requires K and V to have identical page strides and per-token strides. If this constraint is enforced elsewhere, add an assertion or comment clarifying why separate `v_page_stride` parameters exist but are unused.

For reference, `mainloop_sparse_load.cuh` avoids this issue by passing stride parameters explicitly to its `load_kv_tile` helper.
🤖 Prompt for AI Agents
In include/flashinfer/attention/hopper/sparse_mainloop.cuh around lines 303 to
335, load_kv_with_gather reuses the K offsets (my_kv_offset) for V loads which
is only valid if K and V have identical page and per-token strides; update the
code to either (A) assert at runtime (or static_assert / debug check) that
v_page_stride == k_page_stride and per-token strides match and add a clear
comment explaining why v_page_stride parameter is unused, or (B) change the
caller/implementation so V uses its own computed offsets (compute a separate
my_v_offset in prefetch_v_offset and shuffle that for V loads) so K and V can
have different strides—pick one approach and apply consistently (add the
assertion/comment if you choose A; implement separate offset computation and use
it in the shuffle and cp_async_zfill calls if you choose B).
Actionable comments posted: 2
🧹 Nitpick comments (2)
include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh (1)

476-591: LGTM! The new ragged KV dispatch functions are correctly implemented:

- Uses TMA load for contiguous ragged memory (consistent with single prefill)
- Proper layout construction and scheduler setup
- Head dimension dispatch covers all supported values (64, 128, 256)

Minor: the comment on line 497 ("NOTE(Zihao): nnz was useless here, we can just pass 0") reads as a debug/TODO note. Consider removing or rephrasing if the implementation is finalized.
flashinfer/prefill.py (1)

416-472: Consider aligning FP8 detection with tensor dtype check.

The FP8 detection here uses `scale_q is not None` (line 421), while other places in the codebase use `is_float8(q)`. This could lead to inconsistency if:

- FP8 input is provided without scale tensors
- Non-FP8 input is accidentally provided with scale tensors

Consider using `is_float8(q)` for consistency, or add a validation that ensures FP8 inputs always have scale tensors.

```diff
-        # Check if FP8 by presence of scale tensors
-        is_fp8 = scale_q is not None
+        # Check if FP8 by tensor dtype
+        is_fp8 = is_float8(q)
+        if is_fp8 and scale_q is None:
+            raise ValueError("FP8 inputs require scale_q, scale_k, scale_v tensors")
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- csrc/batch_prefill_fp8_ragged_sm90_kernel_inst.jinja (1 hunks)
- csrc/batch_prefill_fp8_sm90.cu (3 hunks)
- flashinfer/prefill.py (21 hunks)
- include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (8 hunks)
- include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh (2 hunks)
- tests/attention/test_hopper_fp8_attention.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/prefill.py (2)
- flashinfer/page.py (1): `get_seq_lens` (176-199)
- flashinfer/utils.py (3): `canonicalize_torch_dtype` (240-248), `check_shape_dtype_device` (519-537), `is_float8` (157-158)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (17)
tests/attention/test_hopper_fp8_attention.py (3)

186-280: LGTM! The test function is well-structured, following the established pattern for FP8 testing. It correctly:

- Creates variable-length sequences for batch prefill
- Generates FP16 reference output
- Quantizes inputs to FP8
- Compares MSE between FP16 and FP8 paths

283-403: LGTM! The paged KV cache test is correctly implemented:

- Proper page allocation and indptr/indices construction
- Appropriate reshape-quantize-reshape pattern for paged KV tensors
- Consistent with the ragged test structure
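The page allocation and indptr/indices construction such a test exercises can be sketched as follows (an illustrative sequential allocator; the test's actual variable names and allocation order may differ):

```python
import math

def build_page_table(seq_lens, page_size):
    """Allocate pages sequentially per request and build the page-unit
    indptr, the flat page indices, and each request's last-page length."""
    kv_indptr, kv_indices, last_page_len = [0], [], []
    next_page = 0
    for n in seq_lens:
        num_pages = math.ceil(n / page_size)
        kv_indices.extend(range(next_page, next_page + num_pages))
        next_page += num_pages
        kv_indptr.append(len(kv_indices))
        last_page_len.append(n - (num_pages - 1) * page_size if num_pages else 0)
    return kv_indptr, kv_indices, last_page_len

# Two requests of 5 and 8 tokens with page_size 4
table = build_page_table([5, 8], page_size=4)
```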
406-426: LGTM! The `__main__` block updates provide convenient local test execution with a reasonable subset of parameters for quick validation.

include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh (1)
337-344: LGTM! The parameter changes correctly pass stride and page-size information directly to the sparse mainloop, aligning with the PR's objective of moving page offset calculation into the kernel.
csrc/batch_prefill_fp8_sm90.cu (2)
86-173: LGTM! The `BatchPrefillWithRaggedKVCacheSM90Run` implementation is well-structured:

- Proper plan info initialization and LSE validation
- Correct layout-aware stride handling for NHD/HND
- Appropriate static assertions for FP8 constraints
- Consistent error handling pattern

231-243: LGTM! The page stride handling is correct. Using `stride(0)` consistently retrieves the stride between pages regardless of the internal layout (NHD or HND), which is the intended behavior for sparse paged KV cache addressing.

flashinfer/prefill.py (5)
1566-1567: LGTM! The `o_data_type` parameter addition is well-implemented with proper canonicalization and caching for use in the run method.

2092-2102: LGTM! The output allocation correctly uses the cached output data type with a safe fallback to `q.dtype` for backward compatibility.

2950-2959: LGTM! The FP8 handling correctly bypasses the FP16 conversion for the FA3 backend while maintaining backward compatibility with the FA2 backend (which still shows a deprecation warning and converts to FP16).

3001-3003: LGTM! The FP8 scale tensor extension follows the established pattern from the paged path.
include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (6)
102-111: LGTM! Clear page-based addressing structure. The new stride and page fields provide a clean interface for page-based K/V tensor addressing, replacing the previous layout-based approach.

118-127: LGTM! Efficient use of fast division. Using `uint_fastdiv` for `page_size` enables efficient divmod operations in the hot path.

134-137: LGTM! Parameter forwarding is correct. All new stride and page parameters are properly forwarded from Arguments to Params.

212-231: LGTM! Clean setup for page-based loading. The coordinate tensor partitioning and parameter extraction properly prepare for the manual K/V loading path.

273-372: LGTM! Proper tile loading sequence and synchronization. The tile loading pattern correctly applies predication only to the last tile while intermediate tiles load without bounds checking. Pipeline synchronization, V transpose coordination, and barrier usage are all properly structured.
232-266: Verify bounds check for page table access before accessing kv_indices_ptr.

The lambda correctly implements page-based addressing for K/V tiles, but the code at line 257 lacks validation that `page_iter` is within the bounds of `kv_indices_ptr` before array access. When `page_iter` is computed via `divmod(kv_idx, page_size)`, the result could potentially exceed the allocated size of the page table if:

- The page table was sized based on an incorrect upper bound for KV entries
- Concurrent modifications affect the array size
- Off-by-one errors exist in the page table allocation logic

To resolve this:

- Add an assertion or bounds check, e.g. `CUTE_ASSERT(page_iter < num_pages)`, before line 257, or
- Verify that the page table allocation logic guarantees sufficient capacity for all possible `page_iter` values derived from valid `kv_idx` values
- Document the invariant that `kv_indices_ptr` must accommodate `ceil_div(kv_len, page_size)` entries

Regarding the guard condition at line 252 (`!use_predicate || kv_idx < kv_len`): this appears correct for intermediate vs. final tiles, but validate that `kv_len` accurately reflects the maximum valid index.
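The recommended invariant can be stated and checked concretely (plain Python, mirroring the `ceil_div(kv_len, page_size)` sizing rule from the comment above):

```python
import math

def max_page_iter(kv_len, page_size):
    """Largest page-table index reachable from any valid kv_idx < kv_len."""
    return (kv_len - 1) // page_size

def required_pages(kv_len, page_size):
    """Minimum kv_indices capacity: ceil_div(kv_len, page_size)."""
    return math.ceil(kv_len / page_size)

# Every valid kv_idx stays in bounds for a table of required_pages entries
ok = all(max_page_iter(n, p) < required_pages(n, p)
         for n in range(1, 200) for p in (1, 3, 4, 16))
```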
```jinja
#include <flashinfer/attention/hopper/quantization/prefill_sm90.cuh>
#include "batch_prefill_sm90_config.inc"

namespace flashinfer {

{% for same_scheduler_for_all_heads in ["true", "false"] %}
template cudaError_t BatchFP8PrefillWithRaggedKVCacheDispatched
    <{{ head_dim_qk }},
     {{ mask_mode }},
     /*USE_SLIDING_WINDOW=*/{{ use_sliding_window }},
     /*SAME_SCHEDULER_FOR_ALL_HEADS=*/{{ same_scheduler_for_all_heads }},
     {{ variant_name }}, RaggedParams>(RaggedParams& params, bool enable_pdl, cudaStream_t stream);
{% endfor %}

};  // namespace flashinfer
```
Fix namespace closing syntax.

Line 15 uses `};` to close the namespace, but namespaces should be closed with just `}` (no semicolon).

```diff
-};  // namespace flashinfer
+}  // namespace flashinfer
```
+} // namespace flashinfer📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| #include <flashinfer/attention/hopper/quantization/prefill_sm90.cuh> | |
| #include "batch_prefill_sm90_config.inc" | |
| namespace flashinfer { | |
| {% for same_scheduler_for_all_heads in ["true", "false"] %} | |
| template cudaError_t BatchFP8PrefillWithRaggedKVCacheDispatched | |
| <{{ head_dim_qk }}, | |
| {{ mask_mode }}, | |
| /*USE_SLIDING_WINDOW=*/{{ use_sliding_window }}, | |
| /*SAME_SCHEDULER_FOR_ALL_HEADS=*/{{ same_scheduler_for_all_heads }}, | |
| {{ variant_name }}, RaggedParams>(RaggedParams& params, bool enable_pdl, cudaStream_t stream); | |
| {% endfor %} | |
| }; // namespace flashinfer | |
| #include <flashinfer/attention/hopper/quantization/prefill_sm90.cuh> | |
| #include "batch_prefill_sm90_config.inc" | |
| namespace flashinfer { | |
| {% for same_scheduler_for_all_heads in ["true", "false"] %} | |
| template cudaError_t BatchFP8PrefillWithRaggedKVCacheDispatched | |
| <{{ head_dim_qk }}, | |
| {{ mask_mode }}, | |
| /*USE_SLIDING_WINDOW=*/{{ use_sliding_window }}, | |
| /*SAME_SCHEDULER_FOR_ALL_HEADS=*/{{ same_scheduler_for_all_heads }}, | |
| {{ variant_name }}, RaggedParams>(RaggedParams& params, bool enable_pdl, cudaStream_t stream); | |
| {% endfor %} | |
| } // namespace flashinfer |
🤖 Prompt for AI Agents
In csrc/batch_prefill_fp8_ragged_sm90_kernel_inst.jinja around lines 1 to 15,
the namespace is closed using "};" but C++ namespace blocks should be closed
with a plain "}" (no semicolon); remove the trailing semicolon after the closing
brace so the file ends with "}" to correctly close the flashinfer namespace.
```cpp
  }
};

int valid_last_kv_tile_size = std::min<int>(kv_len - kv_tile_idx * CTA_KV, CTA_KV);
```
Remove or utilize the unused variable.

The variable `valid_last_kv_tile_size` is computed but never used in the subsequent code. If the intention is to optimize the guard condition in `load_kv_tile` (as suggested in the past review comment), this value should be passed to the lambda. Otherwise, this line should be removed.

Apply this diff if the variable is not needed:

```diff
-  int valid_last_kv_tile_size = std::min<int>(kv_len - kv_tile_idx * CTA_KV, CTA_KV);
-
```

Or, if you want to use it for optimized bounds checking, update the `load_kv_tile` signature to accept it:

```diff
-  auto load_kv_tile = [&](DTypeKV* base_ptr, int64_t stride_n, int64_t page_stride, auto& tXsX,
-                          int tile_idx, int pipe_idx, bool use_predicate) {
+  auto load_kv_tile = [&](DTypeKV* base_ptr, int64_t stride_n, int64_t page_stride, auto& tXsX,
+                          int tile_idx, int pipe_idx, int valid_tile_size) {
```

And update the guard condition accordingly.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (1)
318-425: Critical: V loads reuse K's prefetch offsets with potentially different strides.
The V loading at lines 345, 363, and 390 reuses prefetch offsets computed for K (with `k_stride_n` and `k_page_stride`), but V should use `v_stride_n` and `v_page_stride`. This is evident from line 411, which explicitly prefetches V with `v_stride_n` and `v_page_stride`. If K and V have different strides or page strides, V will be loaded from incorrect addresses, causing data corruption.
The API explicitly provides separate stride parameters for K and V (Arguments and Params structs), suggesting they can differ. Either:
- Add prefetch calls for V before each V load (lines 345, 363, 390) using `v_stride_n` and `v_page_stride`, OR
- Document and assert that `k_stride_n == v_stride_n` and `k_page_stride == v_page_stride` must hold

Apply this pattern to fix the V loads:

```diff
 if (kv_tile_idx == swa_begin_kv_tile_idx) {
-  // first tile is the last tile, reuse kv_tile_idx prefetch for V
+  // first tile is the last tile, prefetch for V
+  prefetch_kv_offset(kv_tile_idx, v_stride_n, v_page_stride, true);
   pipeline_v.producer_acquire(smem_pipe_write);
   load_kv_with_prefetch(v_base_ptr, tVsV, kv_tile_idx, smem_pipe_write.index(), true);
 } else {
   // load second last k-tile and last v-tile
   // Prefetch for next K tile (kv_tile_idx - 1)
   prefetch_kv_offset(kv_tile_idx - 1, k_stride_n, k_page_stride, false);
-  // Load V using prefetch from last K load (kv_tile_idx)
+  // Prefetch and load V for kv_tile_idx
+  prefetch_kv_offset(kv_tile_idx, v_stride_n, v_page_stride, true);
   pipeline_v.producer_acquire(smem_pipe_write);
   load_kv_with_prefetch(v_base_ptr, tVsV, kv_tile_idx, smem_pipe_write.index(), true);

 for (; kv_tile_idx > swa_begin_kv_tile_idx; --kv_tile_idx) {
   // Prefetch for next K tile
   prefetch_kv_offset(kv_tile_idx - 1, k_stride_n, k_page_stride, false);
-  // Load V using prefetch from previous K prefetch
+  // Prefetch and load V for kv_tile_idx
+  prefetch_kv_offset(kv_tile_idx, v_stride_n, v_page_stride, false);
   pipeline_v.producer_acquire(smem_pipe_write);
   load_kv_with_prefetch(v_base_ptr, tVsV, kv_tile_idx, smem_pipe_write.index(), false);
```
♻️ Duplicate comments (1)
include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (1)
316-316: Remove the unused variable.
As noted in the previous review, `valid_last_kv_tile_size` is computed but never used.
🧹 Nitpick comments (4)
benchmarks/bench_hopper_fp8_attention.py (2)
216-216: Document or validate page_size divisibility assumption.
Line 216 assumes `seq_len` is perfectly divisible by `page_size`. While the current test cases satisfy this (seq_len ∈ {1024, 2048, 4096, 8192} with page_size=16), the function might be called with other parameters in the future.
Consider adding a validation check:

```diff
+ assert seq_len % page_size == 0, f"seq_len ({seq_len}) must be divisible by page_size ({page_size})"
  num_pages = batch_size * seq_len // page_size
```
250-251: Consider making workspace buffer size configurable.
The 256MB workspace buffer is hardcoded for both FP16 and FP8 wrappers. While sufficient for current benchmark sizes, this might be inadequate for larger workloads or future test expansions.
Consider either:
- Making workspace size a parameter with a reasonable default
- Adding a comment documenting the size assumption
- Having the wrappers handle workspace allocation internally if supported
This is a minor point since the current sizes work for the benchmarks being run.
Also applies to: 268-269
include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh (2)
331-344: Paged KV mainloop param wiring looks consistent.
The new argument list (k/v strides, page stride, kv_indices, page_size) lines up with a paged/sparse K/V mainloop and matches the scheduler/block_coord usage in this kernel. From this file's perspective the wiring looks correct; no blocking issues.
If `Params::page_size` is not already a 32-bit type, consider documenting or static-asserting the expected range to make the `uint32_t` cast here self-evident to future readers.
477-550: Ragged KV kernel-traits dispatch wiring looks correct; stale comment.
The ragged-KV kernel-traits dispatch correctly switches to `FP8CollectiveMainloop` and reuses the BatchPrefill schedulers/arguments in the same way as the paged path, with Q/K/V layouts built via `get_gmem_layout`, so the host→device params plumbing looks coherent.
The comment on Line 499 saying "nnz was useless here, we can just pass 0" now contradicts the actual `params.nnz_kv` argument; consider updating or removing this note to avoid confusion about whether the first dimension of the K/V layout is semantically meaningful.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- benchmarks/bench_hopper_fp8_attention.py (4 hunks)
- include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (7 hunks)
- include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
benchmarks/bench_hopper_fp8_attention.py (3)
flashinfer/testing/utils.py (2)
- bench_gpu_time (985-1046)
- attention_tflops_per_sec_with_actual_seq_lens (421-454)
benchmarks/bench_block_sparse_attention.py (1)
- flops (125-134)
benchmarks/bench_hopper_attention.py (3)
- flops (46-55)
- flops (107-116)
- flops (187-196)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (16)
benchmarks/bench_hopper_fp8_attention.py (7)
27-38: LGTM: Correct per-head symmetric quantization implementation.
The quantization logic correctly handles both FP8 formats with appropriate ranges, computes per-head scales by taking max over dimensions (0, 2), and includes defensive clamping to prevent division by zero.
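The per-head symmetric scheme described above can be sketched in a few lines of numpy (a stand-in for the benchmark's torch code; the e4m3 max of 448 and the helper name are assumptions for illustration, not the actual API):

```python
import numpy as np

def per_head_symmetric_quantize(x, fp8_max=448.0):
    # x: [seq_len, num_heads, head_dim]; one scale per head, max over dims (0, 2)
    amax = np.abs(x).max(axis=(0, 2))
    scale = np.maximum(amax / fp8_max, 1e-6)  # defensive clamp against all-zero heads
    x_q = np.clip(x / scale[None, :, None], -fp8_max, fp8_max)
    return x_q, scale

x = np.linspace(-2.0, 2.0, 2 * 3 * 4).reshape(2, 3, 4)
x_q, scale = per_head_symmetric_quantize(x)
```

Dequantizing (`x_q * scale`) recovers the input up to clamping, which is what makes the scale a pure per-head rescaling.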
41-108: LGTM: Well-structured FP8 single prefill benchmark.
The benchmark correctly creates FP16 baseline tensors, quantizes them to FP8 with per-head scales, measures both paths using median GPU time, and reports meaningful performance metrics with speedup calculations.
111-201: LGTM: Correct batch ragged prefill benchmark implementation.
The ragged batch benchmark properly constructs indptr arrays for batch boundaries, configures wrappers with appropriate data types, and correctly passes quantization scales to the FP8 execution path.
233-238: LGTM: Correct paged KV quantization strategy.
Flattening the paged KV cache for quantization and then reshaping back is the right approach to maintain per-head quantization semantics across all pages while preserving the paged memory layout.
240-247: LGTM: Correct indptr and page table setup.
The indptr arrays and page indices are correctly constructed:
- `qo_indptr` marks query batch boundaries (every `seq_len` tokens)
- `kv_indptr` marks page batch boundaries (every `seq_len // page_size` pages)
- `kv_indices` provides sequential page mapping
- `last_page_len` assumes full pages, which is appropriate for uniform benchmark workloads
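That setup can be sketched in plain Python (names mirror the benchmark; the concrete sizes are illustrative, not taken from it):

```python
batch_size, seq_len, page_size = 4, 64, 16
assert seq_len % page_size == 0, "uniform benchmark assumes full pages"
pages_per_seq = seq_len // page_size

# query boundaries: one entry per batch, every seq_len tokens
qo_indptr = [i * seq_len for i in range(batch_size + 1)]
# page boundaries: every pages_per_seq pages
kv_indptr = [i * pages_per_seq for i in range(batch_size + 1)]
# sequential page mapping into the page pool
kv_indices = list(range(batch_size * pages_per_seq))
# all pages full in the uniform case
last_page_len = [page_size] * batch_size
```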
330-336: Clarify status of skipped single prefill benchmarks.
The single prefill benchmarks are commented out due to "compilation issues." Given the PR objectives mention fixing a failing Hopper unittest, is this related?
Please clarify:
- Are these compilation issues expected to be resolved in this PR or a follow-up?
- Should this be tracked with a TODO or issue reference?
- Is this related to the unittest fixes mentioned in the PR description?
342-356: LGTM: Comprehensive benchmark coverage.
The test configurations provide good coverage across different:
- Head dimensions (128, 256)
- Batch sizes (16-128)
- Sequence lengths (1024-8192)
- Both ragged and paged KV cache layouts
The parameter combinations maintain roughly constant total token counts, which is sensible for comparing performance across configurations.
include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (7)
102-108: LGTM: Page-based KV cache parameters added.
The addition of separate stride and page_stride parameters for K and V tensors, along with page_size, correctly supports the refactored page-based KV loading scheme.
118-124: LGTM: Efficient fastdiv used for page_size.
Using `uint_fastdiv` for `page_size` enables efficient divmod operations in the kernel hot path.
134-137: LGTM: Parameter forwarding is correct.
All new page-based parameters are correctly forwarded from Arguments to Params.
212-220: LGTM: Manual K/V loading setup is complete.
All required parameters for page-based K/V loading are correctly extracted and prepared.
232-259: LGTM: Well-documented thread organization for FA3-style prefetch.
The rolling buffer prefetch scheme and detailed thread organization comments are helpful for understanding this complex optimization. The NUM_KV_PER_ITER calculations appear correct.
260-280: LGTM: Page-based offset prefetch is correctly implemented.
The divmod-based page addressing and rolling buffer management are correctly implemented. The offset computation properly combines page-level and entry-level addressing.
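The divmod addressing can be modeled host-side in Python (a hypothetical helper sketching the per-lane computation, not the kernel code itself):

```python
def kv_gmem_offset(kv_indices, n, page_size, page_stride, stride_n):
    # n is the logical KV position; split into (page index, in-page entry)
    page_idx, entry_idx = divmod(n, page_size)
    # page-level base via the page table, plus entry-level offset within the page
    return kv_indices[page_idx] * page_stride + entry_idx * stride_n

# With an identity page table and page_stride == page_size * stride_n,
# paged addressing degenerates to plain contiguous addressing.
page_size, stride_n = 16, 8
page_stride = page_size * stride_n
identity_table = list(range(10))
offsets = [kv_gmem_offset(identity_table, n, page_size, page_stride, stride_n)
           for n in range(32)]
```

A scattered page table (e.g. `[5, 2]`) then simply redirects each 16-entry block to a different page base.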
282-314: LGTM: Shuffle-based offset loading is correctly implemented.
The shuffle-based offset sharing and cp_async_zfill with guard correctly implement the FA3-style optimized loading pattern.
include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh (2)
462-476: CTA_KV=64 for HEAD_DIM=256 paged path seems reasonable; please benchmark.
Reducing `CTA_KV` from 128→64 for the sparse paged path (with the accompanying comment about the 64×64 FP8 transpose minimum) is a plausible trade-off to cut page-table lookups; launch shape and error handling remain consistent with other HEAD_DIM branches.
Please sanity-check perf/occupancy for HEAD_DIM=256 on Hopper (especially long-seq FA3 workloads) to ensure this smaller CTA_KV doesn't introduce regressions compared to the previous configuration.
552-592: New BatchFP8PrefillWithRaggedKVCacheDispatched entrypoint matches existing patterns.
This wrapper mirrors the single-batch FP8 dispatch: HEAD_DIM specializations, `USE_TMA_LOAD_KV=true` for ragged K/V, and the same error-reporting pattern as the paged variant. The trait choices (CTA_Q/CTA_KV/NUM_STAGES) are consistent with the non-ragged FP8 paths.
Once the ragged-KV tests are in place, it'd be good to run them for all HEAD_DIM (64/128/256) with large nnz_qo/nnz_kv configurations comparable to issue #1647 to confirm this new batch entrypoint behaves as expected on Hopper.
/bot run

[FAILED] Pipeline #39246617: 14/20 passed
Actionable comments posted: 0
🧹 Nitpick comments (1)
flashinfer/triton/kernels/cascade.py (1)
150-154: 64-bit iterator cast correctly fixes large-range indexing; minor optional cleanup.
Using `iter_i64 = iter.to(tl.int64)` for the `s_ptr`/`v_ptr` address arithmetic is the right way to avoid 32-bit offset overflow when `indptr` spans large index ranges, and should make this kernel robust for the large block counts described in the linked issue.
If you want to shave a tiny bit more overhead in very long ranges, you could optionally hoist the `indptr` loads and casts once per `pos` (e.g., load `start = tl.load(indptr + pos).to(tl.int64)`, `end = tl.load(indptr + pos + 1).to(tl.int64)` and iterate over that), but the current form is already fine and clearly correct.
📜 Review details
📒 Files selected for processing (1)
- flashinfer/triton/kernels/cascade.py (1 hunks)
/bot run

[FAILED] Pipeline #39277250: 14/20 passed
/bot run
Actionable comments posted: 0
🧹 Nitpick comments (1)
include/flashinfer/attention/hopper/epilogue.cuh (1)
200-204: Consider removing the unused `write_warp_idx` parameter.
The `write_warp_idx` variable is set but never used in `write_O` (lines 57-64) or `write_tiled` (lines 24-53). This appears to be leftover from the previous conditional barrier implementation that was removed in this refactor.
Apply this diff to remove the dead parameter:

```diff
 int write_warp_idx = NUM_WARPS - 1;
 TiledCopyO gmem_tiled_copy_O;
 write_O<NUM_COPY_THREADS>(epilogue_params.O_ptr, gmem_tiled_copy_O, epilogue_params.layout_O,
                           select<0, 1>(TileShape_PDV{}), sO, thread_idx, qo_tile_idx,
-                          qo_head_idx, qo_indptr, qo_len, write_warp_idx);
+                          qo_head_idx, qo_indptr, qo_len);
```

And update the `write_O` function signature:

```diff
 template <int NUM_COPY_THREADS, typename ElemO, typename TiledCopyO, typename LayoutO,
           typename TileShapeO, typename SMemO>
 __forceinline__ __device__ void write_O(ElemO* O, const TiledCopyO& tiled_copy_O,
                                         const LayoutO& layout_O, const TileShapeO& tile_shape_O,
                                         const SMemO& sO, int thread_idx, int qo_tile_idx,
-                                        int qo_head_idx, int qo_indptr, int qo_len,
-                                        int write_warp_idx) {
+                                        int qo_head_idx, int qo_indptr, int qo_len) {
   write_tiled<NUM_COPY_THREADS>(O, tiled_copy_O, layout_O, tile_shape_O, sO, thread_idx,
                                 qo_tile_idx, qo_head_idx, qo_indptr, qo_len);
 }
```
📜 Review details
📒 Files selected for processing (1)
- include/flashinfer/attention/hopper/epilogue.cuh (2 hunks)
🔇 Additional comments (2)
include/flashinfer/attention/hopper/epilogue.cuh (2)
197-199: LGTM! Correct barrier synchronization pattern.
The sync correctly pairs with the arrive at lines 171-172, using the same barrier ID and participant count. The fence at line 170 combined with this barrier sync ensures that shared memory writes (line 169) are visible to all threads before `write_O` reads from shared memory at line 202. This uniform synchronization approach is clearer than conditional barriers.
171-172: Based on my verification attempts, I encountered infrastructure limitations that prevent full code analysis. However, I can assess the review comment quality itself.The web search confirms CUTLASS NamedBarrier semantics require exact thread count matching for arrive/sync pairs, which validates the concern raised. The original review correctly identifies that:
- The arrive/sync counts must match (both use NUM_MMA_THREADS) ✓
- The barrier synchronization pattern is correct ✓
- LSE write between barriers is safe (global memory only) ✓
However, the critical remaining concern is whether removing `NUM_PRODUCER_THREADS` from the barrier count is correct. This requires verification that producer threads are either:
- No longer active during the epilogue phase, OR
- Do not participate in this specific synchronization
The original review comment appropriately flags this as needing verification before approval. Since I cannot access the repository to verify this specific detail, the review comment should be preserved with its verification request.
Verify the barrier participant count change is correct.
The barrier arrive uses `NUM_MMA_THREADS`, which correctly matches the sync at lines 197-199. However, per the AI summary, the previous implementation used `NUM_MMA_THREADS + Ktraits::NUM_PRODUCER_THREADS`. Confirm that producer threads are either no longer active during the epilogue phase or do not need to participate in this synchronization. CUTLASS NamedBarrier semantics require exact thread count matching for arrive/sync pairs.
[FAILED] Pipeline #39288360: 15/20 passed
📌 Description
This PR refactors the outdated FA3 codebase. Specifically, for page_size > 1, the page offset calculation is now performed inside the kernel, removing the need for a standalone call to block_sparse_indices_to_vector_sparse_offsets, and the kv_offset calculation is optimized with prefetching and shuffling.
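The shuffling part of that optimization can be modeled host-side: one lane per group computes an offset and every lane in the group reads it back, approximating CUDA's `__shfl_sync` broadcast (this `shuffle_broadcast` helper is a hypothetical Python model, not kernel code):

```python
def shuffle_broadcast(lane_values, src_lane, group_size=32):
    # model __shfl_sync: every lane in a warp-sized group reads src_lane's value
    out = []
    for g in range(0, len(lane_values), group_size):
        group = lane_values[g:g + group_size]
        out.extend([group[src_lane]] * len(group))
    return out

lanes = list(range(64))
broadcast = shuffle_broadcast(lanes, 0)  # each 32-lane group reads its lane 0
```

In the kernel, this lets a single lane do the page-table lookup while the rest of the warp reuses the result, instead of every lane issuing its own load.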
This PR also fixes the failed unittest on hopper.
However, the FA3 structure in our codebase is still badly outdated and lacks important features such as `IntraWGOverlap` and `RescaleOBeforeGemm`; we will follow up in a later PR.
🔍 Related Issues
This PR should fix #1647
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
pre-commitby runningpip install pre-commit(or used your preferred method).pre-commit install.pre-commit run --all-filesand fixed any reported issues.🧪 Tests
unittest, etc.).Reviewer Notes
Summary by CodeRabbit
Bug Fixes
Refactor
New Features
Tests
✏️ Tip: You can customize this high-level summary in your review settings.