Add cute dsl mla decode op #2743
Conversation
Integrates NVIDIA's CuTe DSL MLA decode kernels (FP16/FP8) for Blackwell SM100 as a new "cute-dsl" backend in trtllm_batch_decode_with_kv_cache_mla().

Key tensor layout insights documented in mla_decode.py:
- c_latent/c_rope kernel layout is [page_size, D, total_pages], not [total_tokens, D, 1] — the kernel indexes KV intra-page per physical page
- All fake tensor dimensions must be cute.sym_int() (not static Python ints) so cute.assume() receives CuTe Integer types in initialize_workspace()
- The lse fake tensor needs stride_order=(0,1,2) so stride[0]=1 is a compile-time constant
- Do NOT call .contiguous() after .permute() on q/lse/o tensors — it collapses to row-major, destroying the required non-standard strides
- Separate sym_kv_batch for the KV cache (=1, flat pool) vs the query batch (=B)

New files:
- flashinfer/cute_dsl/mla_helpers.py
- flashinfer/cute_dsl/mla_decode_fp16.py
- flashinfer/cute_dsl/mla_decode_fp8.py
- flashinfer/cute_dsl/mla_decode.py (compilation wrapper + public API)
- tests/attention/test_cute_dsl_mla_decode.py (14 tests, all passing)
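The permute/contiguous pitfall called out above can be illustrated without the kernel: a permute only reorders strides over the same storage, while .contiguous() rebuilds row-major strides for the permuted shape — exactly what destroys the layout the kernel expects. A minimal stride-arithmetic sketch (plain Python standing in for torch; the shapes are illustrative, assuming B=4, H=128, D=576):

```python
def row_major_strides(shape):
    # Strides (in elements) of a C-contiguous tensor of this shape.
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return tuple(strides)

def permute(shape, strides, dims):
    # torch.Tensor.permute: reorder dims, keep the underlying storage.
    return tuple(shape[d] for d in dims), tuple(strides[d] for d in dims)

# q as allocated: (B, q_len, H, D), row-major.
shape = (4, 1, 128, 576)
strides = row_major_strides(shape)       # (73728, 73728, 576, 1)

# The view the kernel wants, e.g. (H, D, q_len, B): non-standard strides.
k_shape, k_strides = permute(shape, strides, (2, 3, 1, 0))
print(k_strides)                         # (576, 1, 73728, 73728)

# A .contiguous() here would rebuild row-major strides for k_shape —
# a *different* memory layout that no longer aliases the original q.
print(row_major_strides(k_shape))        # (2304, 4, 4, 1)
```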
- Remove the unnecessary .contiguous() on the page_table transpose by changing the fake tensor stride_order from (1,0) to (0,1), matching the original kernel's convention of a non-contiguous permute(1,0)
- Use torch.full instead of torch.ones * val for block_split_kvs
- Remove the redundant .contiguous() on the workspace buffer slice
- Remove the redundant .to(int32).contiguous() when seq_lens is already int32
- Eliminate the output copy_ by writing kernel output directly into the caller's out tensor via a permute view (works for both q_len=1 and q_len>1)
- Fix the output allocation order from (B,H,q_len,D) to (B,q_len,H,D) so the permute back to the user layout is naturally contiguous, removing a .contiguous()
- Cache the split_kv and workspace_size computation via functools.cache
- Remove the tensor_api closure wrapper; call compiled_kernel directly
- Add a host-overhead benchmark script

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…e permutes

Accept contiguous row-major tensors and reinterpret layouts inside the kernel's __call__ via zero-cost cute.make_tensor + cute.make_layout, removing ~10 us of Python-side .permute() overhead per call.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…dsl_mla_decode

- Expose is_var_split_kv as a public parameter (default False) to control whether to use per-batch variable split_kv or a uniform scalar split_kv, avoiding a torch.full GPU kernel (~5 us) when not needed.
- Add a workspace_buffer size assertion to catch undersized buffers early.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Note: Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the review settings.
📝 Walkthrough

Adds a CuTe-DSL MLA decode implementation, static tile-scheduler helpers, dtype utilities, tests, and benchmarking/backend wiring; exposes a new "cute-dsl" backend.

Changes
Sequence Diagram(s)

sequenceDiagram
participant Client as PyTorch Caller
participant API as cute_dsl_mla_decode
participant Cache as Param/Kernel Cache
participant Compiler as CUTLASS Compiler
participant WS as Workspace Manager
participant Kernel as CUDA/CuTe Kernel
Client->>API: call(query, kv_cache, params...)
API->>API: validate & normalize inputs
API->>Cache: lookup split_kv & workspace_size for batch config
alt cached
Cache-->>API: return params
else
API->>Compiler: compute split_kv & workspace_size
Compiler-->>Cache: store params
Cache-->>API: return params
end
API->>Cache: lookup compiled kernel for config
alt cached
Cache-->>API: return kernel
else
API->>Compiler: compile kernel (symbolic shapes)
Compiler-->>Cache: store compiled kernel
Cache-->>API: return kernel
end
API->>WS: allocate workspace
WS-->>API: workspace ptr
API->>Kernel: launch with tensors & workspace
Kernel-->>API: produce output
API-->>Client: return formatted output
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed

❌ Failed checks (1 warning)

✅ Passed checks (2 passed)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@flashinfer-bot run

/bot run
@limin2021 is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww |
Actionable comments posted: 4
🧹 Nitpick comments (1)
flashinfer/cute_dsl/mla_decode.py (1)
45-54: Don't key the compile cache on unused dynamic dims.

num_heads and seq_len_q never feed KernelClass(...), and the fake tensors already model both dimensions with cute.sym_int(). Keeping them in a @functools.cache key recompiles the same kernel for every (H, q_len) pair.

🔧 Proposed cleanup
```diff
 def _get_compiled_mla_kernel(
     is_fp8: bool,
     page_size: int,
-    num_heads: int,
-    seq_len_q: int,
     is_persistent: bool,
     is_var_seq: bool,
     is_var_split_kv: bool,
 ) -> Tuple[Callable, object]:
@@
     tensor_api, kernel_cls = _get_compiled_mla_kernel(
         is_fp8=is_fp8,
         page_size=page_size,
-        num_heads=H,
-        seq_len_q=q_len,
         is_persistent=is_persistent,
         is_var_seq=is_var_seq,
         is_var_split_kv=is_var_split_kv,
     )
```

Also applies to: 85-93, 360-368
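The cache fragmentation this comment describes is easy to reproduce with functools.cache alone. A toy sketch (hypothetical names, a counter standing in for an expensive cute.compile()) shows why dropping the unused keys collapses many entries into one:

```python
import functools

compiles = {"count": 0}

@functools.cache
def get_kernel_fragmented(is_fp8, page_size, num_heads, seq_len_q):
    compiles["count"] += 1          # stands in for an expensive cute.compile()
    return ("kernel", is_fp8, page_size)

@functools.cache
def get_kernel(is_fp8, page_size):
    compiles["count"] += 1
    return ("kernel", is_fp8, page_size)

# Same compile-time config, varying only the symbolic dims:
for num_heads in (16, 32, 64, 128):
    get_kernel_fragmented(False, 32, num_heads, 1)
assert compiles["count"] == 4       # one recompile per distinct (H, q_len)

compiles["count"] = 0
for num_heads in (16, 32, 64, 128):
    get_kernel(False, 32)           # symbolic dims passed at call time instead
assert compiles["count"] == 1       # single cached kernel
```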
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/cute_dsl/mla_decode.py`:
- Around line 325-333: The slice can produce an undersized tensor silently;
before slicing workspace_buffer validate it is on the correct device, has dtype
torch.uint8, and has at least max(workspace_size, 1) bytes; if not, raise a
clear error (or allocate/resize) so the kernel never receives undersized scratch
memory. Locate calls around
BlackwellMultiHeadLatentAttentionForwardFP16.get_workspace_size(...) and the
variables workspace_buffer and workspace_bytes, check device equality
(workspace_buffer.device vs expected device from inputs), verify
workspace_buffer.dtype is torch.uint8, and assert workspace_buffer.numel() >=
max(workspace_size, 1) before performing the workspace_buffer[:workspace_size]
slice. Ensure error messages reference workspace_size and actual buffer size for
easier debugging.
In `@flashinfer/cute_dsl/mla_helpers.py`:
- Around line 264-287: The MLIR serialization round-trip omits the scheduler's
is_valid state so non-persistent schedulers become valid after deserialization;
update __extract_mlir_values__ to include self.is_valid (e.g., append a boolean
representation) and update __new_from_mlir_values__ to consume and restore that
boolean into the new object's is_valid, adjusting the assert(len(values)) and
the slicing offsets accordingly; ensure advance_to_next_work (which sets
is_valid=False for non-persistent schedulers) continues to work with the new
serialized field so deserialized instances preserve exhausted/non-exhausted
state.
In `@flashinfer/mla.py`:
- Around line 771-790: In the backend == "cute-dsl" branch in mla.py you must
reject unsupported knobs instead of silently dropping them: before calling
cute_dsl_mla_decode, validate the parameters sparse_mla_top_k, sinks, and
skip_softmax_threshold_scale_factor and raise a clear error (e.g., ValueError or
NotImplementedError) if any are set to non-default values; keep the existing
call to cute_dsl_mla_decode (passing query, kv_cache, workspace_buffer,
kv_lora_rank, qk_rope_head_dim, block_tables, seq_lens, max_seq_len,
softmax_scale, output_scale, out) but ensure unsupported options are checked and
rejected with a descriptive message referencing those option names.
- Around line 774-788: The cute-dsl path is receiving a wrongly
scaled/statically-cast value because bmm1_scale and bmm2_scale are converted to
Python floats (via .item()) after bmm1_scale was multiplied by log2e; fix by
passing tensor values (not .item()) to cute_dsl_mla_decode so
CUDA-graph/dynamic-tensor semantics are preserved and the backend sees the
correct numeric scale: for softmax_scale pass the original unmultiplied tensor
(or divide the current bmm1_scale by log2e) as a tensor rather than float, and
likewise stop calling .item()/float() for output_scale (bmm2_scale) — update the
cute_dsl_mla_decode call to use tensor softmax_scale=bmm1_scale_tensor and
output_scale=bmm2_scale_tensor (referencing cute_dsl_mla_decode, bmm1_scale,
bmm2_scale).
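The scale mismatch flagged here comes from exp2-based softmax kernels: bmm1_scale is pre-multiplied by log2(e) so the kernel can use exp2 instead of exp, and a backend that applies plain exp must see the unmultiplied value. A small numeric sketch of the correction (the head_dim-based scale is illustrative):

```python
import math

log2e = math.log2(math.e)

head_dim = 576
sm_scale = 1.0 / math.sqrt(head_dim)   # the "real" softmax scale
bmm1_scale = sm_scale * log2e          # what the exp2-based path was given

# exp2(bmm1_scale * s) == exp(sm_scale * s) for any score s:
s = 3.7
assert math.isclose(2.0 ** (bmm1_scale * s), math.exp(sm_scale * s))

# So a backend that applies exp directly must undo the factor:
softmax_scale = bmm1_scale / log2e
assert math.isclose(softmax_scale, sm_scale)
```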
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 9fd8c3ec-8099-47fb-891c-f262ddfb6d3c
📒 Files selected for processing (7)
- flashinfer/cute_dsl/__init__.py
- flashinfer/cute_dsl/mla_decode.py
- flashinfer/cute_dsl/mla_decode_fp16.py
- flashinfer/cute_dsl/mla_decode_fp8.py
- flashinfer/cute_dsl/mla_helpers.py
- flashinfer/mla.py
- tests/attention/test_cute_dsl_mla_decode.py
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances FlashInfer's capabilities by incorporating highly optimized CuTe DSL MLA decode kernels for Blackwell SM100 architectures. The integration focuses on improving efficiency and performance for multi-head latent attention operations, particularly by streamlining data handling and supporting mixed-precision computations. This update provides a specialized backend that leverages advanced GPU features, leading to faster and more efficient inference for large language models.

Highlights
Code Review
This pull request integrates NVIDIA's CuTe DSL MLA decode kernels for Blackwell SM100 GPUs into FlashInfer, adding support for both FP16 and FP8 data types. The changes are well-structured, including new kernel wrappers, helper files, and comprehensive tests. I've identified a minor maintainability issue where the kernel class for calculating split_kv and workspace_size is hardcoded, and I've provided a suggestion to make this dynamic based on the data type. Overall, this is a solid contribution that extends FlashInfer's capabilities to new hardware.
Note: Security Review did not run due to the size of the PR.
- Add get_split_kv_simplified() that computes split_kv without max_seq_len
- Remove is_var_split_kv from the public API (hardcode False), eliminating the per-call torch.full GPU kernel overhead
- Remove unused bench_cute_dsl_mla_host_overhead.py

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Actionable comments posted: 4
♻️ Duplicate comments (1)
flashinfer/cute_dsl/mla_decode.py (1)
304-309: ⚠️ Potential issue | 🟠 Major — Validate required_workspace before slicing.

When workspace_size == 0, an empty buffer satisfies the current assert, but workspace_buffer[: max(workspace_size, 1)] still produces a 0-length tensor. This path also still assumes same-device torch.uint8 storage, even though the check is phrased in bytes.

🧯 Suggested fix
```diff
-    assert workspace_buffer.numel() >= workspace_size, (
-        f"workspace_buffer too small: {workspace_buffer.numel()} bytes, "
-        f"need {workspace_size} bytes"
-    )
-    workspace_bytes = workspace_buffer[: max(workspace_size, 1)]
+    if workspace_buffer.device != query.device:
+        raise ValueError("workspace_buffer must be on the same device as query")
+    if workspace_buffer.dtype != torch.uint8 or workspace_buffer.dim() != 1:
+        raise ValueError("workspace_buffer must be a 1-D torch.uint8 tensor")
+    required_workspace = max(workspace_size, 1)
+    if workspace_buffer.numel() < required_workspace:
+        raise ValueError(
+            f"workspace_buffer too small: need {required_workspace} bytes, "
+            f"got {workspace_buffer.numel()}"
+        )
+    workspace_bytes = workspace_buffer[:required_workspace]
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/cute_dsl/mla_decode.py` around lines 304 - 309, Validate the requested workspace size and buffer properties before slicing: assert that workspace_size (the required_workspace) is non-negative and that workspace_buffer.numel() (taking element size into account) is >= workspace_size, and also assert workspace_buffer.dtype is torch.uint8 and on the expected device; handle workspace_size == 0 explicitly (e.g., set workspace_bytes = workspace_buffer[:0] or return an empty view) instead of using max(workspace_size, 1) so you don't assume a 1-byte element, and replace the existing slice workspace_buffer[: max(workspace_size, 1)] with a slice that respects the validated workspace_size and the buffer's dtype/device in functions/variables such as workspace_buffer and workspace_size.
🧹 Nitpick comments (1)
tests/attention/bench_cute_dsl_mla_host_overhead.py (1)
114-133: Profile the current decode path, not the removed permute-based path.

Most of these sections still time the old Python-side permutes/transposes and the is_var_split_kv=True launch shape, while cute_dsl_mla_decode() now uses row-major inputs and is_var_split_kv=False. The section breakdown will be misleading until this helper mirrors the current wrapper or is renamed as a legacy-path profiler.

Also applies to: 149-215
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/attention/bench_cute_dsl_mla_host_overhead.py` around lines 114 - 133, The profiling helpers (query_reshape, kv_reshape, page_table_transpose) are still timing the old permute/transpose path and the is_var_split_kv=True layout, but cute_dsl_mla_decode() now expects row-major inputs with is_var_split_kv=False; update these helpers (and the analogous blocks at 149-215) to produce and measure the same input layout and shapes used by cute_dsl_mla_decode(), or rename them to indicate they profile the legacy permute-based path; specifically modify query_reshape, kv_reshape, and page_table_transpose to return row-major tensors matching the decode wrapper (remove permute/.t() usage and use the same slicing/shape conventions as cute_dsl_mla_decode), and ensure any measure calls use is_var_split_kv=False so the profiler reflects the current decode path.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/cute_dsl/mla_decode.py`:
- Around line 295-296: Normalize the auxiliary index tensors before dispatch:
convert block_tables (used as page_table_k) to the expected integer dtype and
ensure it's contiguous, and likewise make page_table_fake, cache_seqs_fake and
the reused seq_lens contiguous (and cast to the wrapper-expected dtype, e.g.,
torch.int64) before passing them into the TVM/FFI call; update assignments
around page_table_k, page_table_fake, cache_seqs_fake and the seq_lens usage
(also the similar site at the block around lines 327-328) to use .to(dtype=...,
copy=False) and .contiguous() so mismatched dtypes or strided tensors are
rejected at the wrapper boundary instead of failing deep in the FFI.
- Around line 63-71: The cache for _get_compiled_mla_kernel is fragmenting on
num_heads and seq_len_q even though those axes are compiled as cute.sym_int()
and do not affect cute.compile(); remove num_heads and seq_len_q from the
`@functools.cache` function signature (and any cache keys) so the cache is keyed
only by true compile-time flags (e.g., is_fp8, page_size, is_persistent,
is_var_seq, is_var_split_kv), keep using cute.sym_int() inside the compiled
kernel to represent H and q_len as symbolic ints, and ensure callers still pass
num_heads/seq_len_q only when invoking the returned Callable rather than as
cache inputs.
In `@tests/attention/bench_cute_dsl_mla_host_overhead.py`:
- Around line 84-86: The profiler function profile_sections currently passes an
outdated fifth positional argument max_seq_len into the helper
_get_split_kv_and_workspace_size, causing a TypeError; remove max_seq_len from
the helper call(s) inside profile_sections (and the same redundant argument
usage later around the other call at lines ~137-140) so that
_get_split_kv_and_workspace_size is called with the current four parameters
only, leaving the rest of profile_sections (kv_cache, workspace_buffer,
kv_lora_rank, qk_rope_head_dim, block_tables, seq_lens, softmax_scale,
output_scale, num_iters) unchanged.
- Around line 88-96: Remove the stale unused imports causing Ruff F401 by
deleting the unused symbols from the import block: specifically drop Float32,
Int32 and cutlass (and also remove get_num_sm if it is not referenced
elsewhere). Keep only the actually used symbols such as
_get_compiled_mla_kernel, _get_split_kv_and_workspace_size, _LATENT_DIM,
_ROPE_DIM, _MMA_QK_TILER_MN, _MAX_ACTIVE_CLUSTERS, and
BlackwellMultiHeadLatentAttentionForwardFP16 so the import statement no longer
triggers unused-import lint errors.
---
Duplicate comments:
In `@flashinfer/cute_dsl/mla_decode.py`:
- Around line 304-309: Validate the requested workspace size and buffer
properties before slicing: assert that workspace_size (the required_workspace)
is non-negative and that workspace_buffer.numel() (taking element size into
account) is >= workspace_size, and also assert workspace_buffer.dtype is
torch.uint8 and on the expected device; handle workspace_size == 0 explicitly
(e.g., set workspace_bytes = workspace_buffer[:0] or return an empty view)
instead of using max(workspace_size, 1) so you don't assume a 1-byte element,
and replace the existing slice workspace_buffer[: max(workspace_size, 1)] with a
slice that respects the validated workspace_size and the buffer's dtype/device
in functions/variables such as workspace_buffer and workspace_size.
---
Nitpick comments:
In `@tests/attention/bench_cute_dsl_mla_host_overhead.py`:
- Around line 114-133: The profiling helpers (query_reshape, kv_reshape,
page_table_transpose) are still timing the old permute/transpose path and the
is_var_split_kv=True layout, but cute_dsl_mla_decode() now expects row-major
inputs with is_var_split_kv=False; update these helpers (and the analogous
blocks at 149-215) to produce and measure the same input layout and shapes used
by cute_dsl_mla_decode(), or rename them to indicate they profile the legacy
permute-based path; specifically modify query_reshape, kv_reshape, and
page_table_transpose to return row-major tensors matching the decode wrapper
(remove permute/.t() usage and use the same slicing/shape conventions as
cute_dsl_mla_decode), and ensure any measure calls use is_var_split_kv=False so
the profiler reflects the current decode path.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 100622e9-26bc-47bd-b2ed-b5366850c5d2
📒 Files selected for processing (4)
- flashinfer/cute_dsl/mla_decode.py
- flashinfer/cute_dsl/mla_decode_fp16.py
- flashinfer/cute_dsl/mla_decode_fp8.py
- tests/attention/bench_cute_dsl_mla_host_overhead.py
- Add torch_to_cutlass_dtype() in utils.py for torch.dtype -> cutlass dtype conversion
- Extend mla_decode_fp16.py can_implement() to accept BFloat16
- Refactor mla_decode.py to support float16/bfloat16/float8_e4m3fn via dtype-aware dispatch
- Add BFloat16 parametrization to the test_cute_dsl_mla_decode_fp16 test
- Add a backend parameter to the bench_trtllm_gen_mla.py benchmark
- Remove unused bench_cute_dsl_mla_host_overhead.py

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
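The new torch_to_cutlass_dtype() presumably amounts to a small lookup table plus an error for unsupported types. A hypothetical sketch of that shape (dtype names as strings here; the real utility maps torch.dtype objects to cutlass dtype classes, and the cutlass-side names are assumptions):

```python
# Hypothetical sketch of a torch -> cutlass dtype mapping; not the real
# flashinfer.cute_dsl.utils implementation.
_TORCH_TO_CUTLASS = {
    "float16": "Float16",
    "bfloat16": "BFloat16",
    "float8_e4m3fn": "Float8E4M3FN",
}

def torch_to_cutlass_dtype(name):
    try:
        return _TORCH_TO_CUTLASS[name]
    except KeyError:
        raise TypeError(f"unsupported dtype for cute-dsl MLA decode: {name}")

print(torch_to_cutlass_dtype("bfloat16"))   # BFloat16
```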
Actionable comments posted: 4
🧹 Nitpick comments (1)
tests/attention/test_cute_dsl_mla_decode.py (1)
311-326: Strengthen the public API test beyond a shape check.

This only proves that the new flashinfer.mla backend branch returns a tensor of the expected size. Since that path is part of the new surface in this PR, please compare out against torch_reference_mla (or the direct cute_dsl_mla_decode result) so scale/dtype wiring bugs do not slip through.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_cute_dsl_mla_decode.py` around lines 311 - 326, The test only asserts shape for the trtllm_batch_decode_with_kv_cache_mla call; instead compute a reference result (using torch_reference_mla or calling cute_dsl_mla_decode with the same inputs) and assert numeric equivalence: compare out to the reference tensor with torch.allclose (or torch.testing.assert_allclose) using sensible rtol/atol for the dtype to catch scale/dtype wiring bugs. Ensure you use the same inputs (query, kv_cache, workspace_buffer, block_tables, seq_lens, max_seq_len, bmm1_scale, bmm2_scale) and keep the existing shape check if desired.
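A reference comparison along the lines the review suggests boils down to torch.allclose's criterion with dtype-dependent tolerances. A plain-Python sketch (the tolerance values are illustrative, and would typically be looser for fp8 than for fp16/bf16):

```python
def allclose(a, b, rtol, atol):
    # The criterion torch.allclose uses, element-wise on flat lists:
    # |a - b| <= atol + rtol * |b|
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))

# Per-dtype tolerances one might use when diffing against torch_reference_mla:
TOLS = {
    "float16": (1e-3, 1e-3),
    "bfloat16": (1e-2, 1e-2),
    "float8_e4m3fn": (5e-2, 5e-2),
}

out = [0.1234, -0.5678]   # kernel output (illustrative values)
ref = [0.1236, -0.5674]   # reference output
rtol, atol = TOLS["bfloat16"]
assert allclose(out, ref, rtol, atol)
```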
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmarks/bench_trtllm_gen_mla.py`:
- Around line 150-154: The benchmark currently excludes torch.bfloat16 for the
"cute-dsl" backend (setting dtypes in the args.backend conditional), but
cute_dsl/mla_decode.py now supports bfloat16; update the dtypes assignment so
that torch.bfloat16 is included in the dtypes list for the "cute-dsl" branch
(and revise the inline comment to remove the incorrect "only supports float16"
note). Locate the args.backend conditional that sets dtypes in
bench_trtllm_gen_mla.py and add torch.bfloat16 alongside the existing dtypes for
the "cute-dsl" case.
In `@flashinfer/cute_dsl/mla_decode.py`:
- Around line 307-308: Remove the unused local is_fp8 in mla_decode.py: delete
the line that sets is_fp8 = q_dtype == torch.float8_e4m3fn since q_dtype already
controls the dtype dispatch and is_fp8 is never referenced; leave q_dtype =
query.dtype intact and run pre-commit to ensure Ruff F841 no longer flags the
dead variable.
- Around line 296-305: Replace the runtime-conditional asserts in the checks for
input dtypes and shapes with explicit exceptions so they cannot be removed with
python -O: validate that query.dtype is in _SUPPORTED_DTYPES and raise a
TypeError with a clear message if not; check kv_cache.dtype equals query.dtype
and raise a TypeError if it does not; validate shapes B, q_len, H, D_qk and that
D_qk == kv_lora_rank + qk_rope_head_dim and raise a ValueError with a
descriptive message on mismatch; also enforce kv_lora_rank == _LATENT_DIM and
qk_rope_head_dim == _ROPE_DIM with ValueError if violated. Apply the same
replacements to the analogous checks around the other block referenced (the
assertions at lines 349-353) so all guards use stable exceptions instead of
assert.
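The reason to prefer explicit raises over assert is that running under python -O strips assert statements entirely, silently disabling the guards. A sketch of the validation style the prompt asks for (names mirror the wrapper's locals; exact messages and the 512/64 constants are taken from the MLA dims mentioned elsewhere in this PR):

```python
SUPPORTED_DTYPES = ("float16", "bfloat16", "float8_e4m3fn")
LATENT_DIM, ROPE_DIM = 512, 64

def validate_inputs(q_dtype, kv_dtype, d_qk, kv_lora_rank, qk_rope_head_dim):
    # Explicit exceptions survive `python -O`, unlike `assert` statements.
    if q_dtype not in SUPPORTED_DTYPES:
        raise TypeError(f"unsupported query dtype: {q_dtype}")
    if kv_dtype != q_dtype:
        raise TypeError(f"kv_cache dtype {kv_dtype} != query dtype {q_dtype}")
    if kv_lora_rank != LATENT_DIM or qk_rope_head_dim != ROPE_DIM:
        raise ValueError(
            f"cute-dsl MLA decode requires kv_lora_rank={LATENT_DIM}, "
            f"qk_rope_head_dim={ROPE_DIM}"
        )
    if d_qk != kv_lora_rank + qk_rope_head_dim:
        raise ValueError(f"head dim {d_qk} != {kv_lora_rank} + {qk_rope_head_dim}")

validate_inputs("bfloat16", "bfloat16", 576, 512, 64)   # passes silently
try:
    validate_inputs("float32", "float32", 576, 512, 64)
except TypeError as e:
    print(e)                                            # unsupported query dtype: float32
```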
- Around line 310-323: Before reinterpreting memory for the kernel, validate
that tensors use a dense row-major layout: check that query, normalized kv_cache
(after handling 4D squeeze) and out are contiguous and have expected
strides/dimensions (for a 4D kv_cache ensure the second dimension == 1 before
squeeze); if any check fails, raise a clear error explaining the required
compact layout. Add these checks immediately before the blocks that split query
into q_latent_k/q_rope_k and kv_cache into c_latent_k/c_rope_k (and the
analogous checks in the later block around lines 358-366), and include the
variable names query, kv_cache, out and a reference to the kernel's __call__
reinterpretation in the error message so callers know why the tensor must be
compact.
---
Nitpick comments:
In `@tests/attention/test_cute_dsl_mla_decode.py`:
- Around line 311-326: The test only asserts shape for the
trtllm_batch_decode_with_kv_cache_mla call; instead compute a reference result
(using torch_reference_mla or calling cute_dsl_mla_decode with the same inputs)
and assert numeric equivalence: compare out to the reference tensor with
torch.allclose (or torch.testing.assert_allclose) using sensible rtol/atol for
the dtype to catch scale/dtype wiring bugs. Ensure you use the same inputs
(query, kv_cache, workspace_buffer, block_tables, seq_lens, max_seq_len,
bmm1_scale, bmm2_scale) and keep the existing shape check if desired.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: ffbe60f4-3f20-4127-accb-5a06707bee0b
📒 Files selected for processing (6)
- benchmarks/bench_trtllm_gen_mla.py
- flashinfer/cute_dsl/__init__.py
- flashinfer/cute_dsl/mla_decode.py
- flashinfer/cute_dsl/mla_decode_fp16.py
- flashinfer/cute_dsl/utils.py
- tests/attention/test_cute_dsl_mla_decode.py
🚧 Files skipped from review as they are similar to previous changes (1)
- flashinfer/cute_dsl/__init__.py
# Conflicts:
#   tests/attention/test_trtllm_gen_mla.py
/bot run

[FAILED] Pipeline #46310895: 7/20 passed
Requested changes have been made. Dismissing "requested changes"
Resolve conflict in flashinfer/mla/_core.py by keeping both is_var_seq and uses_shared_paged_kv_idx parameters.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
/bot run
- Allow BFloat16 output for FP8 input (matching the trtllm-gen backend default)
- FP16/BF16 input defaults to same-dtype output; FP8 input defaults to BF16 output
- Add an out_dtype parameter to cute_dsl_mla_decode for explicit override
- Add uses_shared_paged_kv_idx=False validation for the cute-dsl backend
- Skip unsupported 3D page table tests for cute-dsl

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
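The output-dtype rule described in this commit is simple enough to state as a small function. A sketch (dtype names as strings here; the real code works with torch.dtype objects):

```python
def default_out_dtype(in_dtype, out_dtype=None):
    # FP16/BF16 input -> same-dtype output; FP8 input -> BF16 output
    # (matching the trtllm-gen backend default). An explicit out_dtype
    # overrides the default.
    if out_dtype is not None:
        return out_dtype
    if in_dtype == "float8_e4m3fn":
        return "bfloat16"
    return in_dtype

assert default_out_dtype("float16") == "float16"
assert default_out_dtype("bfloat16") == "bfloat16"
assert default_out_dtype("float8_e4m3fn") == "bfloat16"
assert default_out_dtype("float8_e4m3fn", "float16") == "float16"
```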
/bot run

[FAILED] Pipeline #46949378: 6/20 passed
bkryu left a comment
Thanks @limin2021 -- Requesting one change in the unit test file to skip outside of SM100f
```python
if backend == "cute-dsl":
    if compute_capability[0] < 10:
        pytest.skip("cute-dsl MLA requires SM100+")
```
@limin2021 now most tests seem to be passing with the latest updates to CuTe DSL.
Only failure I am noticing is that we are failing this test on SM120 because the CuTe DSL kernel you added is for SM100f. Can you add a skip here?
The tcgen05 MMA operations only support SM100-SM110. Tighten arch checks so SM120a (and above) are correctly skipped and SM110 is correctly allowed.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
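The tightened check can be sketched as a pure function of the compute capability, with boundaries taken from this commit message (tcgen05 MMA on SM100-SM110 only, so SM120a and above must be skipped despite the larger number):

```python
def supports_cute_dsl_mla(major, minor):
    # tcgen05 MMA is available on SM100-SM110 only; a simple `major >= 10`
    # check would wrongly admit SM120a and newer.
    return (10, 0) <= (major, minor) <= (11, 0)

assert supports_cute_dsl_mla(10, 0)        # SM100: allowed
assert supports_cute_dsl_mla(11, 0)        # SM110: allowed
assert not supports_cute_dsl_mla(12, 0)    # SM120: skipped
assert not supports_cute_dsl_mla(9, 0)     # SM90: skipped
```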
/bot run
Fix has been delivered. Dismissing review request
/bot run

[FAILED] Pipeline #47024987: 13/20 passed
📌 Description
Integrate NVIDIA's CuTe DSL MLA (Multi-Head Latent Attention) decode kernels for Blackwell SM100 into FlashInfer, supporting both BF16/FP16 and FP8 dtypes.
Test plan
Functionality:
All 396 configs PASSED with 0 failures.
Test matrix:
Performance:
1. FP8 fixed-len (is_var_seq=False → persistent)
2. FP8 var-seqlen (is_var_seq=True → non-persistent)
3. BF16 fixed-len (is_var_seq=False → persistent)
4. BF16 var-seqlen (is_var_seq=True → non-persistent)
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed.
- All tests are passing (unittest, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Public API
Tests
Benchmarking