
[None][feat] fix mamba metadata prefill bubble in chunked prefill serving#12736

Merged
nv-guomingz merged 1 commit into NVIDIA:main from nv-guomingz:user/guomingz/mamba-prefill-bubble-fix
Apr 7, 2026

Conversation

@nv-guomingz
Collaborator

@nv-guomingz nv-guomingz commented Apr 3, 2026

Three optimizations to eliminate GPU idle bubbles during prefill in Mamba2Metadata.prepare() for hybrid GDN models (e.g. Qwen3.5):

  1. Remove tl.constexpr from num_seqs and N in _cu_seqlens_triton_kernel. Triton JIT recompiles for each unique constexpr value (~120ms each). In serving, num_seqs varies every prefill step, causing repeated recompilation. With dynamic parameters, only one compilation occurs.

  2. Accept total_seqlens from caller to skip first GPU->CPU sync. cu_seqlens[-1].item() blocked on all pending GPU work. The caller (Mamba2Metadata.prepare) already has num_ctx_tokens on CPU.

  3. Compute extra_chunks with pure Python arithmetic on CPU seq_lens to eliminate the second GPU->CPU sync (cumsum + p[-1].item()).
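The recompilation cost in item 1 comes from how Triton specializes kernels: every distinct value of a tl.constexpr parameter keys a separate compiled artifact. A toy cache model (illustrative only, not Triton's actual internals; names like toy_jit are hypothetical) shows why moving num_seqs from constexpr to a dynamic argument collapses the number of compiles:

```python
# Toy model of JIT specialization: constexpr-like parameters are baked into
# the compiled artifact, so each distinct value triggers a fresh compile
# (~120 ms each in real Triton). Dynamic arguments are not part of the key.
compile_cache = {}

def toy_jit(kernel_name, constexpr_args):
    """Return a 'compiled kernel', compiling only on a cache miss."""
    key = (kernel_name, tuple(sorted(constexpr_args.items())))
    if key not in compile_cache:
        compile_cache[key] = f"compiled<{key}>"  # stands in for a slow compile
    return compile_cache[key]

# num_seqs as constexpr: every new batch size seen in serving recompiles.
for num_seqs in (1, 2, 7, 32):
    toy_jit("_cu_seqlens_triton_kernel", {"num_seqs": num_seqs})
print(len(compile_cache))  # -> 4

# num_seqs passed dynamically: a single compilation serves all batch sizes.
compile_cache.clear()
for num_seqs in (1, 2, 7, 32):
    toy_jit("_cu_seqlens_triton_kernel", {})
print(len(compile_cache))  # -> 1
```

In real serving traffic, num_seqs varies per prefill step, so the constexpr version pays the compile cost repeatedly while the dynamic version pays it once.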

Before: _prepare_inputs ~120-460 ms per prefill step (Triton recompile + GPU sync bubbles)
After: _prepare_inputs ~1-2 ms steady state
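The CPU-side arithmetic from item 3 can be sketched as follows. This is a minimal sketch, not the actual code: the function name extra_chunks_cpu and the exact boundary convention are assumptions, while the real loop lives in Mamba2Metadata.prepare in tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py. The idea is to count sequence starts whose cumulative token offset is not chunk-aligned, each of which forces one extra chunk when initial states are used:

```python
def extra_chunks_cpu(seq_lens, chunk_size):
    """Illustrative sketch: count sequence start offsets that are not
    chunk-aligned. Each misaligned start splits a chunk, so one extra
    chunk is needed. Pure Python on CPU ints -- no GPU->CPU sync."""
    extra = 0
    cumsum = 0
    for seq_len in seq_lens:
        if cumsum % chunk_size != 0:  # this sequence starts mid-chunk
            extra += 1
        cumsum += seq_len
    return extra

# Chunk-aligned sequences need no extra chunks.
print(extra_chunks_cpu([256, 256, 512], 256))  # -> 0
# A 100-token first sequence leaves the next two starts misaligned.
print(extra_chunks_cpu([100, 300, 200], 256))  # -> 2
```

Because seq_lens is already CPU-resident in AttentionMetadata, this replaces the GPU cumsum + .item() pattern that previously forced the device to drain before prefill could continue.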

Verified: 9200+ random equivalence tests + e2e serving assertion with 1000 requests (0 mismatches). GSM8K accuracy unchanged (90.07% on full 1319 samples).

Summary by CodeRabbit

  • Refactor
    • Updated Mamba2 metadata parameter handling to support optional computation parameters, improving flexibility in sequence length calculations.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

@coderabbitai
Contributor

coderabbitai bot commented Apr 3, 2026

📝 Walkthrough

Walkthrough

Modified Mamba2 metadata processing to reduce GPU-to-CPU data reads by extending the cu_seqlens_to_chunk_indices_offsets_triton function signature with optional parameters (total_seqlens, extra_chunks) and computing extra_chunks on the CPU side using a Python loop instead of GPU-derived values.

Changes

Cohort / File(s) Summary
Mamba2 Metadata Optimization
tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py
Removed tl.constexpr typing from Triton kernel parameters (num_seqs, N). Extended cu_seqlens_to_chunk_indices_offsets_triton with optional total_seqlens and extra_chunks parameters (default -1). Modified conditional logic to only compute values from GPU when not provided. Replaced GPU-derived extra_chunks computation with CPU-side Python loop in Mamba2Metadata.prepare when using initial states.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 66.67%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)

  • Title check — ✅ Passed: the title clearly describes the main change, fixing the mamba metadata prefill bubble in chunked prefill serving.
  • Description check — ✅ Passed: the PR description clearly explains the issue, the three optimizations, the performance metrics, and the verification approach.



Contributor

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (1)
tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py (1)

297-311: Correct CPU-side computation eliminates GPU sync.

The loop correctly replicates the cumulative-sum-based misalignment counting from the GPU version. Verified the logic is equivalent: both compute cumulative sums at sequence boundaries and count non-aligned positions.

This relies on the documented invariant that attn_metadata.seq_lens is CPU-resident (per AttentionMetadata interface). Consider adding a defensive assertion if this invariant is ever violated, as a GPU-resident tensor would silently reintroduce the sync.


🛡️ Optional: Defensive assertion
             if self.use_initial_states:
                 # Compute extra_chunks using pure Python arithmetic on CPU
                 # seq_lens to avoid any GPU->CPU sync point.
+                assert not attn_metadata.seq_lens.is_cuda, \
+                    "seq_lens must be CPU-resident for sync-free extra_chunks computation"
                 _cs = self.chunk_size
                 _cumsum = 0
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py` around lines 297 - 311,
The CPU-side loop computing _extra assumes attn_metadata.seq_lens is
CPU-resident; add a defensive assertion before the loop that
attn_metadata.seq_lens is on the CPU (e.g., check tensor.device.type or
.is_cuda) and raise/handle if not, so we avoid silently triggering a GPU->CPU
sync; keep the subsequent logic and the call to
cu_seqlens_to_chunk_indices_offsets_triton (which sets self.chunk_indices and
self.chunk_offsets) unchanged.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 9c833e0a-4677-409a-85a3-6a7e81f1b43f

📥 Commits

Reviewing files that changed from the base of the PR and between 1045f38 and de232de.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py

@nv-guomingz changed the title from "fix mamba metadata prefill bubble in chunked prefill serving" to "[None][feat] fix mamba metadata prefill bubble in chunked prefill serving" on Apr 3, 2026
@nv-guomingz nv-guomingz force-pushed the user/guomingz/mamba-prefill-bubble-fix branch from de232de to 58fbd68 Compare April 3, 2026 15:28
@nv-guomingz
Collaborator Author

/bot run --add-multi-gpu-test --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #41695 [ run ] triggered by Bot. Commit: 58fbd68

@tensorrt-cicd
Collaborator

PR_Github #41695 [ run ] completed with state SUCCESS. Commit: 58fbd68
/LLM/main/L0_MergeRequest_PR pipeline #32598 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again


@nv-guomingz
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #41774 [ run ] triggered by Bot. Commit: 58fbd68

@tensorrt-cicd
Collaborator

PR_Github #41774 [ run ] completed with state SUCCESS. Commit: 58fbd68
/LLM/main/L0_MergeRequest_PR pipeline #32668 completed with status: 'SUCCESS'

CI Report


Collaborator

@rosenrodt rosenrodt left a comment


LGTM. cc the original author @Wong4j

Collaborator

@Wanli-Jiang Wanli-Jiang left a comment


Good optimization.

@nv-guomingz nv-guomingz merged commit 4496e69 into NVIDIA:main Apr 7, 2026
7 checks passed
xinhe-nv pushed a commit to xinhe-nv/TensorRT-LLM that referenced this pull request Apr 7, 2026
…ving (NVIDIA#12736)

Signed-off-by: Shijie Wang <jaywan@nvidia.com>
yufeiwu-nv pushed a commit to yufeiwu-nv/TensorRT-LLM that referenced this pull request Apr 7, 2026
…ving (NVIDIA#12736)

Signed-off-by: Shijie Wang <jaywan@nvidia.com>
karen-sy pushed a commit to karen-sy/TensorRT-LLM that referenced this pull request Apr 7, 2026
…ving (NVIDIA#12736)

Signed-off-by: Shijie Wang <jaywan@nvidia.com>
suyoggupta pushed a commit to nv-auto-deploy/TensorRT-LLM that referenced this pull request Apr 8, 2026
…ving (NVIDIA#12736)

Signed-off-by: Shijie Wang <jaywan@nvidia.com>
