
[Feat] Add num_heads < 128 support for mla decode kernel#3309

Open
Observer007 wants to merge 9 commits into
flashinfer-ai:mainfrom
Observer007:feat/mla_num_heads_64

Conversation

@Observer007
Contributor

@Observer007 Observer007 commented May 13, 2026

📌 Description

Feature additions:

  1. Add num_heads < 128 support by folding seqlen_q into num_heads to fill the M tile of the MLA kernel.
  2. Add support for returning the LSE.
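
A minimal sketch of how a fold factor F could be chosen, to make the folding idea concrete. This is not the PR's implementation: the function names, the rule "largest divisor of seq_len_q that keeps num_heads * F within a 128-row M tile", and the choice seq_len_q = 4 for MTP == 4 are all assumptions made for illustration. Under those assumptions the result happens to agree with the F values quoted in the benchmark headings below.

```python
def fold_ratio(num_heads: int, seq_len_q: int, m_tile: int = 128) -> int:
    """Hypothetical rule: largest F dividing seq_len_q with num_heads * F <= m_tile."""
    best = 1
    for f in range(1, seq_len_q + 1):
        if seq_len_q % f == 0 and num_heads * f <= m_tile:
            best = f
    return best


def folded_dims(num_heads: int, seq_len_q: int, m_tile: int = 128) -> tuple:
    """Return (effective_num_heads, effective_seq_len_q) after folding."""
    f = fold_ratio(num_heads, seq_len_q, m_tile)
    return num_heads * f, seq_len_q // f


if __name__ == "__main__":
    # seq_len_q = 4 is an assumption standing in for MTP == 4.
    for h in (16, 32, 64, 128):
        print(h, fold_ratio(h, 4), folded_dims(h, 4))
```

With 128 heads the M tile is already full, so the sketch returns F = 1 and leaves the shape unchanged.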

Performance enhancement:

  1. Mainly improves FP8 MLA decode performance.

Performance measured before and after this commit with MTP == 4,
on a Blackwell B200 without frequency locking; seqlen_k = 80k.

num_heads=16 (F=4 fold)
      bs   Before  This MR   speedup
       1       44       30    1.46x
       2       72       36    1.99x
       4      138       53    2.60x
       8      263       85    3.09x
      16      524      152    3.44x
      32     1084      300    3.61x
      64     2256      623    3.62x
      96     3507     1161    3.02x
     128     4624     1403    3.30x
     192     7789     1999    3.90x
     256     9386     2915    3.22x
     384    15215     4440    3.43x
     512    21023     6097    3.45x

num_heads=32 (F=4 fold)
      bs   Before  This MR   speedup
       1       46       31    1.50x
       4      141       55    2.59x
       8      268       90    2.99x
      16      532      157    3.38x
      32     1111      316    3.52x
      64     2309      655    3.53x
      96     3702     1196    3.10x
     128     4493     1345    3.34x
     192     7524     2137    3.52x
     256    10628     3028    3.51x

num_heads=64 (F=2 fold)
      bs   Before  This MR   speedup
       1       79       33    2.41x
       4      167       76    2.20x
       8      278      146    1.90x
      16      545      290    1.88x
      32     1144      581    1.97x
      64     2354     1248    1.89x
      96     3766     1973    1.91x
     128     5224     2652    1.97x
     192     8211     3947    2.08x
     256    10515     5486    1.92x

num_heads=128 (F=1, no fold — sanity check)
      bs   Before  This MR   speedup
       1       48       46    1.05x
       4      152      143    1.06x
       8      296      280    1.06x
      16      580      545    1.07x
      32     1357     1157    1.17x
      64     2729     2496    1.09x
      96     4137     3791    1.09x
     128     5191     4863    1.07x
     192     8186     7474    1.10x
     256    11432    10167    1.12x

🔍 Related Issues

#3287

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Optional log-sum-exp (LSE) output for attention APIs.
    • Spec-decoding (MTP) causal masking support in applicable attention paths.
    • Token-folding optimization for Blackwell GPUs to improve throughput.
  • Tests

    • Expanded coverage across both attention implementations, LSE behavior, masking, sinks/dispatch, and end-to-end API scenarios (including expected error cases).
  • Chores

    • Standardized benchmark q_len grid across backends.

Review Change Stack

pgera and others added 2 commits May 12, 2026 01:49
…te_dsl_impl= (AI-assisted)

PR flashinfer-ai#2805 refactored the monolithic CuTe-DSL MLA decode kernel into a
modular structure and removed the original implementation.  The original
authors want it kept available because the modular path is still
maturing.  Restore it under the cute-dsl backend (no new backend name)
and let the user pick:

  flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
      ..., backend="cute-dsl", cute_dsl_impl="auto" | "modular" | "monolithic")

Layout:
  flashinfer/cute_dsl/attention/
    monolithic/         - restored kernels (verbatim from before flashinfer-ai#2805,
                          relocated to live next to the modular code).
                          Includes the H<128 / Kimi K2.5 fix from flashinfer-ai#3235
                          backported (workspace pads H to max(H, 128)
                          when split_kv != 1; can_implement no longer
                          rejects H<128).
    wrappers/           - existing modular standalone + wrapper.
    mla_dispatch.py     - new dispatcher in front of both impls.

Dispatcher contract:
  - "auto"       (default): monolithic, but auto-promotes to modular
                            when a modular-only feature is requested
                            (currently sinks).
  - "modular"    : strict, always modular.
  - "monolithic" : strict, raises ValueError if a modular-only feature
                   is requested rather than silently substituting.

The dispatcher strips modular-only kwargs (sinks=None) before
forwarding to monolithic, so callers can pass sinks= unconditionally
without breaking the monolithic path.
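
The dispatcher contract above can be sketched as follows. This is a hedged illustration, not the code in mla_dispatch.py: the names `resolve_impl`, `strip_modular_only`, and `MODULAR_ONLY_KWARGS` are hypothetical, and sinks is treated as the only modular-only feature, as the contract currently states.

```python
from typing import Any, Optional

# Assumption: sinks is the only modular-only kwarg today.
MODULAR_ONLY_KWARGS = ("sinks",)


def resolve_impl(cute_dsl_impl: str = "auto", sinks: Optional[Any] = None) -> str:
    """Pick the implementation according to the auto/modular/monolithic contract."""
    if cute_dsl_impl not in ("auto", "modular", "monolithic"):
        raise ValueError(f"unknown cute_dsl_impl: {cute_dsl_impl!r}")
    needs_modular = sinks is not None
    if cute_dsl_impl == "modular":
        return "modular"
    if cute_dsl_impl == "monolithic":
        if needs_modular:
            # Strict mode: fail loudly instead of silently substituting.
            raise ValueError(
                "monolithic impl does not support sinks; use modular or auto"
            )
        return "monolithic"
    # "auto": monolithic by default, promoted when a modular-only feature is used.
    return "modular" if needs_modular else "monolithic"


def strip_modular_only(kwargs: dict) -> dict:
    """Drop modular-only kwargs before forwarding to the monolithic path."""
    return {k: v for k, v in kwargs.items() if k not in MODULAR_ONLY_KWARGS}
```

Keeping "monolithic" strict means a caller who pins the implementation never gets a different kernel than the one requested.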

Sinks support:
  - trtllm_batch_decode_with_kv_cache_mla(sinks=...) on backend=
    "cute-dsl" now constructs an AttentionWithSink variant inside the
    modular standalone, instead of being rejected at the API boundary.
  - AttentionWithSink gains value-based __hash__/__eq__ keyed on
    (data_ptr, shape, dtype) so @functools.cache on _compile_mla_kernel
    correctly reuses compiled kernels across invocations with the same
    sinks tensor.  Without this, a fresh variant per call hashed by
    object identity, JIT-recompiled the kernel on every iteration, and
    made cuda-graph + sinks bench measurements appear to hang.
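
The caching problem described above can be reproduced without any GPU code. The sketch below uses stand-in classes (`FakeTensor`, `IdentityVariant`, `ValueVariant`, all hypothetical) to show why `functools.cache` misses on every call when the argument hashes by object identity, and hits once the hash and equality are keyed on (data_ptr, shape, dtype).

```python
import functools
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeTensor:
    """Stand-in for a tensor: only the fields used in the cache key."""
    data_ptr: int
    shape: tuple
    dtype: str


class IdentityVariant:
    """Default object hashing: every new instance is a distinct cache key."""
    def __init__(self, sinks):
        self.sinks = sinks


class ValueVariant:
    """Value-based hashing keyed on (data_ptr, shape, dtype)."""
    def __init__(self, sinks):
        self.sinks = sinks

    def _key(self):
        return (self.sinks.data_ptr, self.sinks.shape, self.sinks.dtype)

    def __hash__(self):
        return hash(self._key())

    def __eq__(self, other):
        return isinstance(other, ValueVariant) and self._key() == other._key()


compiles = []


@functools.cache
def compile_kernel(variant):
    """Pretend JIT compile; records each real (uncached) invocation."""
    compiles.append(variant)
    return len(compiles)


def count_compiles(variant_cls, calls: int = 3) -> int:
    """Build a fresh variant per call, as the wrapper would, and count compiles."""
    compile_kernel.cache_clear()
    compiles.clear()
    sinks = FakeTensor(data_ptr=0x1000, shape=(128,), dtype="float32")
    for _ in range(calls):
        compile_kernel(variant_cls(sinks))
    return len(compiles)
```

Calling `count_compiles(IdentityVariant)` compiles once per call, while `count_compiles(ValueVariant)` compiles only once for the same sinks tensor.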

Tests:
  - test_cute_dsl_mla_decode.py: existing standalone and public-API
    tests now parametrize over modular/monolithic via a cute_dsl_impl
    fixture; new minimal sinks tests pin the auto/modular dispatch
    branches and the monolithic+sinks ValueError contract.  Wrapper
    sinks numerics remain covered by the pre-existing
    test_cute_dsl_mla_decode_attention_sink.
  - test_trtllm_gen_mla.py: comment near the cute-dsl skip refreshed
    to reflect the dispatcher's cute_dsl_impl behaviour.

Bench:
  - bench_trtllm_gen_mla.py grows a focused 6-cell with_sinks=True
    sub-sweep (B in {1,16,128} x S in {1024,8192} at q_len=1, page=64,
    bf16) on top of the existing main sweep, instead of doubling the
    full grid.  Argument list deduplicated into a common_kwargs dict
    so warmup and benchmark calls cannot drift.
@coderabbitai
Contributor

coderabbitai Bot commented May 13, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: a96a1ba0-6788-4384-97bc-4da71d412fea

📥 Commits

Reviewing files that changed from the base of the PR and between 9f0d7c1 and 131dab4.

📒 Files selected for processing (1)
  • flashinfer/cute_dsl/attention/wrappers/batch_mla.py

📝 Walkthrough

Walkthrough

Adds fold-S_q packing and fold-aware causal masking to monolithic FP16/FP8 MLA kernels; converts LSE outputs to natural-log; exposes optional LSE buffers via wrapper/API with sinks and cute_dsl_impl plumbing; reworks FP8 dual-group softmax; parametrizes tests; fixes tile-scheduler index and benchmark q_lens grid.
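
For readers unfamiliar with the LSE conversion, here is a small numeric sketch. It assumes the kernel accumulates the softmax in the base-2 domain (a common choice when `exp2` is used), which the walkthrough does not state explicitly; under that assumption, the natural-log LSE is the base-2 LSE multiplied by ln 2. All function names are hypothetical.

```python
import math


def lse_natural(scores):
    """Reference natural-log LSE: ln(sum(exp(s))), computed stably."""
    m = max(scores)
    return m + math.log(sum(math.exp(s - m) for s in scores))


def lse_base2(scores):
    """LSE as an exp2-based kernel might accumulate it: log2(sum(2 ** (s * log2(e))))."""
    scaled = [s * math.log2(math.e) for s in scores]
    m = max(scaled)
    return m + math.log2(sum(2.0 ** (x - m) for x in scaled))


def to_natural(lse2: float) -> float:
    """Convert a base-2 LSE to natural log: ln(x) = log2(x) * ln(2)."""
    return lse2 * math.log(2.0)


if __name__ == "__main__":
    scores = [0.5, -1.25, 3.0, 2.0]
    print(lse_natural(scores), to_natural(lse_base2(scores)))
```

Returning the LSE in natural log lets callers merge partial attention results with the usual `exp`/`log` formulas without tracking the kernel's internal base.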

Changes

Monolithic Kernel Folding, Masking, LSE, and Test Updates

Layer / File(s) Summary
FP16 Monolithic Kernel: Folding, Masking Phases, and LSE Support
flashinfer/cute_dsl/attention/monolithic/mla_decode_fp16.py
Constructor adds num_heads, seq_len_q, fold_sq; compute_fold_sq_ratio() helper; conditional folding remap in __call__; phased masking in compute() with first_mask_tile_idx; softmax() uses apply_mask; epilogue and split-KV write LSE in natural-log; can_implement() updated.
FP8 Monolithic Kernel: 2-Softmax Restructuring, TMEM Ownership, and Fold Support
flashinfer/cute_dsl/attention/monolithic/mla_decode_fp8.py
Constructor adds folding params and fold ratio; kernel restructured to two compute warp groups, separate QK/PV MMA flows, doubled SMEM exchange, ping-pong correction TMEM flow, fold-aware masking with sentinel, and reduction LSE converted to natural-log.
Wrappers & Public API: Fold Caching, LSE Support, and Sinks Integration
flashinfer/cute_dsl/attention/monolithic/mla_decode.py, flashinfer/cute_dsl/attention/wrappers/batch_mla.py, flashinfer/mla/_core.py
Compute effective fold-aware dims for cached split-KV/workspace sizing; add num_heads/seq_len_q to compiled-kernel cache key and pass fold_sq into kernel options; extend cute_dsl_mla_decode/wrapper to accept sinks, optional lse buffer and return_lse (modular path raises NotImplemented for LSE), and forward cute_dsl_impl/lse/return_lse via _core dispatch.
Testing: Reference Implementation, Parametrization, and Validation
tests/attention/test_cute_dsl_mla_decode.py
torch_reference_mla gains apply_mtp_mask and return_lse; tests parametrized over cute_dsl_impl (modular/monolithic); monolithic-only fold_sq regression added; FP8 tests accept caller-provided 2D LSE buffer and verify identity/shape/dtype; API-level sinks dispatch and error tests added.
Helpers & Housekeeping: Tile Scheduler Coordinate Fix and Benchmark Simplification
flashinfer/cute_dsl/attention/monolithic/mla_helpers.py, benchmarks/bench_trtllm_gen_mla.py
Fixes MLAStaticTileScheduler.get_current_work() non-persistent coordinate decode ordering; simplifies benchmark q_lens grid to [1,2,4,8,16] for all backends.

Sequence Diagram(s):

sequenceDiagram
  participant Caller as API Caller
  participant Wrapper as cute_dsl_mla_decode (batch)
  participant Monolithic as Monolithic Wrapper
  participant Kernel as FP16/FP8 Kernel
  Caller->>Wrapper: call with sinks, lse, return_lse, cute_dsl_impl
  Wrapper->>Monolithic: validate/reshape/provide lse (modular: raise NotImplemented)
  Monolithic->>Kernel: compile/run with num_heads, seq_len_q, fold_sq, lse buffer
  Kernel-->>Monolithic: compute outputs, write LSE (natural-log)
  Monolithic-->>Wrapper: return out or (out, lse)
  Wrapper-->>Caller: return to caller

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Possibly related PRs

Suggested labels

run-ci

Suggested reviewers

  • yzh119
  • aleozlx
  • cyx-6
  • bkryu
  • nv-yunzheq
  • sricketts
  • qsang-nv
  • saltyminty
  • jimmyzho
  • dhiraj113

Poem

"I hopped through folds and masked the night,
two warps in rhythm, softmax took flight.
LSEs turned natural, buffers kept their place,
kernels hummed softly, stride and grace.
A rabbit cheers these folded bits of pace."

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly describes the main feature: adding num_heads < 128 support for the MLA decode kernel, which is the primary objective across all file changes.
Description check ✅ Passed The description covers the main features (num_heads < 128 support with folding, LSE support), performance enhancements with detailed before/after metrics, related issues, completed checklists, and test validation.
Docstring Coverage ✅ Passed Docstring coverage is 84.48% which is sufficient. The required threshold is 80.00%.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests


Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces a dispatcher for the cute-dsl backend to select between modular and monolithic MLA decode implementations. It adds support for sinks in the cute-dsl backend via automatic promotion to the modular implementation and restores monolithic kernels to support fold_sq and spec-decoding causal masks. Feedback was provided regarding a leftover debug print statement in the monolithic implementation.

Comment thread flashinfer/cute_dsl/attention/monolithic/mla_decode.py Outdated
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 5

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@flashinfer/cute_dsl/attention/monolithic/mla_decode.py`:
- Line 456: The debug print statement in the MLA decoding hot path
(print(f"is_persistent: {is_persistent}")) should be removed to eliminate noisy
stdout and runtime overhead; locate the print in mla_decode.py (function/method
around mla_decode / the decode hot path) and delete it or replace it with a
conditional debug log using an existing logger (e.g., logger.debug) guarded by
log level checks so no per-call I/O occurs in production.
- Around line 396-399: The code assumes kv_cache becomes 3D but currently
squeezes dim 1 unconditionally when kv_cache.dim() == 4; update the check in the
mla_decode logic to validate the 4D layout before removing dim 1: only call
squeeze(1) if kv_cache.dim() == 4 and kv_cache.shape[1] == 1, otherwise raise a
clear ValueError (or reshape/transpose as appropriate) explaining the unexpected
kv_cache shape so downstream code that uses page_size = kv_cache.shape[1] and
other 3D assumptions fails fast and clearly; reference the kv_cache variable and
the block around the squeeze and page_size computation.

In `@flashinfer/cute_dsl/attention/wrappers/batch_mla.py`:
- Around line 838-854: The code currently only checks sinks.is_contiguous()
before creating AttentionWithSink, but you must also validate sinks.shape
matches the expected head count (e.g., 1-D and length == H) to avoid invalid
head indexing; before instantiating AttentionWithSink (and before calling
variant.extra_params.to(...)), assert sinks.dim() == 1 and sinks.size(0) == H
(or the variable holding the number of heads) and raise a ValueError that prints
sinks.shape and the expected H if the check fails, then proceed with creating
AttentionWithSink and converting variant.extra_params to torch.float32 on
query.device.

In `@flashinfer/mla/_core.py`:
- Around line 677-681: Update the docstring for the sinks parameter to reflect
actual backend support: state that sinks are supported by the cute-dsl backend
(requires modular implementation via cute_dsl_impl="auto" promoting to modular
or manual modular selection) but are rejected by the XQA execution path (which
raises a ValueError); ensure the wording mentions the sinks parameter name,
cute_dsl_impl, and the ValueError raised by the XQA path so the docstring aligns
with real behavior.

In `@tests/attention/test_cute_dsl_mla_decode.py`:
- Line 1130: The regex pattern passed to pytest.raises in the
pytest.raises(ValueError, match="monolithic.*sinks.*modular") call should be a
raw string to avoid unintended escaping; update the match argument to use a raw
string literal (r"monolithic.*sinks.*modular") in the pytest.raises(...)
invocation so the regex metacharacters are interpreted correctly and the RUF043
warning is resolved.
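
Two of the shape checks requested above (validating the 4D kv_cache layout before squeezing dim 1, and checking the sinks shape against the head count) can be sketched on plain shape tuples. This is an illustration only: the helper names are hypothetical, and real code would operate on tensors rather than tuples.

```python
def normalize_kv_cache_shape(shape: tuple) -> tuple:
    """Return the 3D shape, squeezing dim 1 only when it is a singleton."""
    if len(shape) == 3:
        return shape
    if len(shape) == 4 and shape[1] == 1:
        return (shape[0],) + tuple(shape[2:])
    raise ValueError(
        f"unexpected kv_cache shape {shape}; expected 3D, or 4D with shape[1] == 1"
    )


def check_sinks_shape(shape: tuple, num_heads: int) -> None:
    """Require a 1-D sinks tensor with one entry per head."""
    if len(shape) != 1 or shape[0] != num_heads:
        raise ValueError(
            f"sinks shape {shape} does not match expected ({num_heads},)"
        )
```

Failing fast here keeps later code that reads `page_size` from `shape[1]` from silently using the wrong dimension.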
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 39130b48-8f1e-4bb4-9335-f7f71e99355f

📥 Commits

Reviewing files that changed from the base of the PR and between 103fcf8 and ebcec6c.

📒 Files selected for processing (13)
  • benchmarks/bench_trtllm_gen_mla.py
  • flashinfer/cute_dsl/attention/__init__.py
  • flashinfer/cute_dsl/attention/fusion/variant.py
  • flashinfer/cute_dsl/attention/mla_dispatch.py
  • flashinfer/cute_dsl/attention/monolithic/__init__.py
  • flashinfer/cute_dsl/attention/monolithic/mla_decode.py
  • flashinfer/cute_dsl/attention/monolithic/mla_decode_fp16.py
  • flashinfer/cute_dsl/attention/monolithic/mla_decode_fp8.py
  • flashinfer/cute_dsl/attention/monolithic/mla_helpers.py
  • flashinfer/cute_dsl/attention/wrappers/batch_mla.py
  • flashinfer/mla/_core.py
  • tests/attention/test_cute_dsl_mla_decode.py
  • tests/attention/test_trtllm_gen_mla.py

Comment thread flashinfer/cute_dsl/attention/monolithic/mla_decode.py
Comment thread flashinfer/cute_dsl/attention/monolithic/mla_decode.py Outdated
Comment thread flashinfer/cute_dsl/attention/wrappers/batch_mla.py
Comment thread flashinfer/mla/_core.py
Comment thread tests/attention/test_cute_dsl_mla_decode.py Outdated
@nvpohanh
Contributor

SGLang team is verifying the perf of this PR.

See sgl-project/sglang#24737

@b8zhong
Contributor

b8zhong commented May 19, 2026

Hi, this new kernel seems to have a correctness issue with MTP = 4 (and bs = 1). However, with num_heads = 128 in non-speculative decode (DSR1, TP = 4), it works correctly.
We don't think it comes from the integration, since I don't see any API change from the trtllm backend.

Edit: the bug was actually caused by an accidentally shared workspace on the integration side, not by the kernel. This PR LGTM.

@Observer007 Observer007 requested a review from dhiraj113 as a code owner May 21, 2026 01:30
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
flashinfer/cute_dsl/attention/monolithic/mla_decode_fp16.py (1)

103-120: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Update the constraints block to match the new folded-head path.

This block still documents a fixed 128-head requirement, but the kernel now explicitly supports num_heads < 128 via folding/padding. Leaving the old invariant here will mislead the next person debugging can_implement() or the fold logic.

📝 Suggested doc update
 Constraints:
 * Data type requirements:
   - Input/output: Float16
   - Accumulation and LSE: Float32
 * Fixed architecture parameters:
-  - Number of attention heads: 128
   - Latent dimension: 512
   - RoPE dimension: 64
+* Head count may be smaller than the 128-row MMA-M tile; when enabled,
+  `seq_len_q` can be folded into the head dimension to populate the tile.
 * Input query modes should be (NumHeads, LatentDim/RopeDim, SeqLenQ, BatchSize)
 * Input kv latent/rope modes should be (SeqLenK, LatentDim/RopeDim, BatchSize)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@flashinfer/cute_dsl/attention/monolithic/mla_decode_fp16.py` around lines 103
- 120, The constraints comment in mla_decode_fp16.py still states a fixed
128-head requirement which is outdated; update that block to document that the
kernel supports folded/padded heads (num_heads <= 128) and may accept num_heads
< 128 via folding/padding, and clarify how RopeDim/LatentDim mapping and
sequence-length limits (1-4) interact with folding. Edit the constraints nearby
the class/definition for the monolithic FP16 kernel (the MLA monolithic
implementation) and mention the folding/padding behavior referenced by
can_implement() and the fold logic so readers know the kernel will pad/fold
heads rather than requiring exactly 128 heads.
flashinfer/cute_dsl/attention/monolithic/mla_decode_fp8.py (1)

3754-3765: ⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Update the pipeline participant counts for the second compute warp group.

After adding second_compute_warp_ids, both softmax groups participate in mma_s_pipeline.consumer_wait() and produce into p_mma / p_cor, but these cooperative groups are still sized from self.compute_warp_ids only. That leaves half of the active compute warps outside the barrier accounting and can release stages early.

Suggested fix
         consumer_thread_size = (
             self.threads_per_warp
-            * len(self.compute_warp_ids)
+            * self.num_total_compute_warps
             * self.cluster_shape_mnk[0]
         )
@@
         producer_thread_size = (
             self.threads_per_warp
-            * len(self.compute_warp_ids)
+            * self.num_total_compute_warps
             * self.cluster_shape_mnk[0]
         )
@@
-        producer_thread_size = self.threads_per_warp * len(self.compute_warp_ids)
+        producer_thread_size = self.threads_per_warp * self.num_total_compute_warps

Also applies to: 3789-3836
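
To make the barrier-accounting concern concrete, here is a small arithmetic sketch. The warp IDs, the warp size of 32, and the cluster M extent of 2 are hypothetical values chosen for illustration; only the formula (threads per warp times participating warps times the cluster's M extent) follows the suggested fix.

```python
def participant_threads(threads_per_warp: int, warp_groups, cluster_m: int) -> int:
    """Threads that arrive at the barrier: warps across all groups, times cluster M."""
    num_warps = sum(len(group) for group in warp_groups)
    return threads_per_warp * num_warps * cluster_m


# Hypothetical warp assignments for the two softmax groups.
compute_warp_ids = (0, 1, 2, 3)
second_compute_warp_ids = (4, 5, 6, 7)

# Counting only the first group undercounts by half.
before = participant_threads(32, [compute_warp_ids], cluster_m=2)
after = participant_threads(32, [compute_warp_ids, second_compute_warp_ids], cluster_m=2)

if __name__ == "__main__":
    print(before, after)
```

If the barrier expects the smaller count, it can complete after only half of the participating threads have arrived, which is the early-release hazard the review describes.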

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@flashinfer/cute_dsl/attention/monolithic/mla_decode_fp8.py` around lines 3754
- 3765, The cooperative group sizes for the MMA softmax pipeline are still built
using only self.compute_warp_ids, which leaves the warps in
second_compute_warp_ids out of the barrier; update construction of
mma_s_producer_group and mma_s_consumer_group to account for both compute sets
(e.g., use len(self.compute_warp_ids) + len(self.second_compute_warp_ids) or the
combined list length) so that all threads participating in
mma_s_pipeline.consumer_wait(), all producers writing to p_mma/p_cor, and all
consumer threads are included in the CooperativeGroup counts (adjust any other
similar groups in the region around mma_s_* and the block at 3789-3836).
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@flashinfer/cute_dsl/attention/monolithic/mla_decode.py`:
- Around line 465-479: The code currently accepts lse by shape/dtype only but
may misinterpret non-compact memory; update the lse handling to require
contiguous C-order buffers and reject non-compact inputs: when lse is provided,
check lse.is_contiguous() (and lse.dim() in {2,3}) and raise a ValueError if not
contiguous (or not 2D/3D), then proceed to assign lse_k (either use lse directly
when shape==(B,q_len,H) or view when shape==(B*q_len,H)); reference lse, lse_k,
and B/q_len/H and use .is_contiguous() in the validation to ensure callers pass
compact storage.
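
A minimal sketch of the validation described in this finding, written against plain shape and stride tuples (strides in elements, as PyTorch reports them) so it runs without a tensor library. The helper names are hypothetical; real code would call `lse.is_contiguous()` directly.

```python
def is_c_contiguous(shape, strides) -> bool:
    """True when the strides describe compact row-major storage."""
    if any(size == 0 for size in shape):
        return True  # empty buffers are trivially contiguous
    expected = 1
    for size, stride in zip(reversed(shape), reversed(strides)):
        # Size-1 dimensions may carry any stride.
        if size != 1 and stride != expected:
            return False
        expected *= size
    return True


def resolve_lse_shape(shape, strides, B: int, q_len: int, H: int) -> tuple:
    """Validate a caller-provided LSE buffer and return the (B, q_len, H) view shape."""
    if len(shape) not in (2, 3):
        raise ValueError(f"lse must be 2D or 3D, got shape {tuple(shape)}")
    if not is_c_contiguous(shape, strides):
        raise ValueError("lse must be contiguous")
    if tuple(shape) in ((B, q_len, H), (B * q_len, H)):
        return (B, q_len, H)
    raise ValueError(
        f"lse shape {tuple(shape)} must be ({B}, {q_len}, {H}) or ({B * q_len}, {H})"
    )
```

Rejecting strided buffers up front avoids reinterpreting a non-compact 2D buffer as a 3D view with the wrong element layout.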

---

Outside diff comments:
In `@flashinfer/cute_dsl/attention/monolithic/mla_decode_fp16.py`:
- Around line 103-120: The constraints comment in mla_decode_fp16.py still
states a fixed 128-head requirement which is outdated; update that block to
document that the kernel supports folded/padded heads (num_heads <= 128) and may
accept num_heads < 128 via folding/padding, and clarify how RopeDim/LatentDim
mapping and sequence-length limits (1-4) interact with folding. Edit the
constraints nearby the class/definition for the monolithic FP16 kernel (the MLA
monolithic implementation) and mention the folding/padding behavior referenced
by can_implement() and the fold logic so readers know the kernel will pad/fold
heads rather than requiring exactly 128 heads.

In `@flashinfer/cute_dsl/attention/monolithic/mla_decode_fp8.py`:
- Around line 3754-3765: The cooperative group sizes for the MMA softmax
pipeline are still built using only self.compute_warp_ids, which leaves the
warps in second_compute_warp_ids out of the barrier; update construction of
mma_s_producer_group and mma_s_consumer_group to account for both compute sets
(e.g., use len(self.compute_warp_ids) + len(self.second_compute_warp_ids) or the
combined list length) so that all threads participating in
mma_s_pipeline.consumer_wait(), all producers writing to p_mma/p_cor, and all
consumer threads are included in the CooperativeGroup counts (adjust any other
similar groups in the region around mma_s_* and the block at 3789-3836).
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 830ef233-a6a0-4228-945d-132638b8a222

📥 Commits

Reviewing files that changed from the base of the PR and between f285855 and 39070d7.

📒 Files selected for processing (8)
  • benchmarks/bench_trtllm_gen_mla.py
  • flashinfer/cute_dsl/attention/monolithic/mla_decode.py
  • flashinfer/cute_dsl/attention/monolithic/mla_decode_fp16.py
  • flashinfer/cute_dsl/attention/monolithic/mla_decode_fp8.py
  • flashinfer/cute_dsl/attention/monolithic/mla_helpers.py
  • flashinfer/cute_dsl/attention/wrappers/batch_mla.py
  • flashinfer/mla/_core.py
  • tests/attention/test_cute_dsl_mla_decode.py
💤 Files with no reviewable changes (3)
  • flashinfer/mla/_core.py
  • tests/attention/test_cute_dsl_mla_decode.py
  • flashinfer/cute_dsl/attention/wrappers/batch_mla.py

Comment thread flashinfer/cute_dsl/attention/monolithic/mla_decode.py
@Observer007
Contributor Author

/bot run

@flashinfer-bot
Collaborator

@Observer007 is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

@nvpohanh
Contributor

/bot run

@flashinfer-bot
Collaborator

GitLab MR !709 has been created, and the CI pipeline #52588718 is currently running. I'll report back once the pipeline job completes.

5 participants