[None][feat] Fix Mamba metadata prefill bubble in chunked prefill serving (#12736)
Conversation
📝 Walkthrough
Modified Mamba2 metadata processing to reduce GPU-to-CPU data reads by extending the …

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
✅ Passed checks (2 passed)
🧹 Nitpick comments (1)
tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py (1)
297-311: Correct CPU-side computation eliminates GPU sync.
The loop correctly replicates the cumulative-sum-based misalignment counting from the GPU version. Verified the logic is equivalent: both compute cumulative sums at sequence boundaries and count non-aligned positions.
This relies on the documented invariant that
`attn_metadata.seq_lens` is CPU-resident (per the `AttentionMetadata` interface). Consider adding a defensive assertion in case this invariant is ever violated, as a GPU-resident tensor would silently reintroduce the sync.
🛡️ Optional: Defensive assertion

```diff
 if self.use_initial_states:
     # Compute extra_chunks using pure Python arithmetic on CPU
     # seq_lens to avoid any GPU->CPU sync point.
+    assert not attn_metadata.seq_lens.is_cuda, \
+        "seq_lens must be CPU-resident for sync-free extra_chunks computation"
     _cs = self.chunk_size
     _cumsum = 0
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py` around lines 297-311: the CPU-side loop computing _extra assumes attn_metadata.seq_lens is CPU-resident; add a defensive assertion before the loop that attn_metadata.seq_lens is on the CPU (e.g., check tensor.device.type or .is_cuda) and raise/handle if not, so we avoid silently triggering a GPU->CPU sync; keep the subsequent logic and the call to cu_seqlens_to_chunk_indices_offsets_triton (which sets self.chunk_indices and self.chunk_offsets) unchanged.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 9c833e0a-4677-409a-85a3-6a7e81f1b43f
📒 Files selected for processing (1)
tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py
Three optimizations to eliminate GPU idle bubbles during prefill in
Mamba2Metadata.prepare() for hybrid GDN models (e.g. Qwen3.5):
1. Remove tl.constexpr from num_seqs and N in _cu_seqlens_triton_kernel.
Triton JIT recompiles for each unique constexpr value (~120ms each).
In serving, num_seqs varies every prefill step, causing repeated
recompilation. With dynamic parameters, only one compilation occurs.
2. Accept total_seqlens from caller to skip first GPU->CPU sync.
cu_seqlens[-1].item() blocked on all pending GPU work. The caller
(Mamba2Metadata.prepare) already has num_ctx_tokens on CPU.
3. Compute extra_chunks with pure Python arithmetic on CPU seq_lens
to eliminate the second GPU->CPU sync (cumsum + p[-1].item()).
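The CPU-side computation in item 3 can be sketched in plain Python. This is a minimal sketch of the counting rule described in the review comment (names and exact semantics assumed, not taken from the patch): walk the CPU-resident seq_lens with a running sum and count sequence boundaries that do not land on a chunk_size multiple, so no GPU tensor is ever read.

```python
def extra_chunks_cpu(seq_lens, chunk_size):
    """Count sequence boundaries not aligned to chunk_size using plain
    Python arithmetic; mirrors the GPU cumsum + .item() path but with no
    device synchronization. `seq_lens` is a CPU-resident list of ints."""
    cumsum = 0
    extra = 0
    for s in seq_lens[:-1]:  # the final boundary needs no extra chunk
        cumsum += s
        if cumsum % chunk_size != 0:
            extra += 1
    return extra

print(extra_chunks_cpu([5, 3, 8], 4))  # boundaries at 5 and 8; only 5 is misaligned -> 1
print(extra_chunks_cpu([4, 4, 4], 4))  # all boundaries aligned -> 0
```

Because the inputs are plain host integers, this runs entirely on the CPU and can overlap with pending GPU work instead of blocking on it.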
Before: _prepare_inputs ~120-460ms per prefill step (Triton recompile +
GPU sync bubbles)
After: _prepare_inputs ~1-2ms steady state
Verified: 9200+ random equivalence tests + e2e serving assertion with
1000 requests (0 mismatches). GSM8K accuracy unchanged (90.07% on full
1319 samples).
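The recompilation cost in item 1 comes from the JIT specializing on each distinct `tl.constexpr` value. A toy Python model of that behavior (a hypothetical cache keyed on specialization values, not Triton's actual implementation) shows why a varying num_seqs triggers repeated compiles while a dynamic argument compiles once:

```python
# Toy model (not Triton itself): JIT specialization caches compiled kernels
# keyed on constexpr values, so a num_seqs that varies per prefill step misses
# the cache whenever a new value appears; a dynamic argument compiles once.
compiled_cache = {}
compile_count = 0

def get_kernel(num_seqs_constexpr):
    """Return a 'compiled' cumulative-sum kernel, compiling on cache miss."""
    global compile_count
    key = ("cu_seqlens_kernel", num_seqs_constexpr)
    if key not in compiled_cache:
        compile_count += 1  # stands in for the ~120 ms Triton compile
        compiled_cache[key] = lambda seq_lens: [
            sum(seq_lens[: i + 1]) for i in range(len(seq_lens))
        ]
    return compiled_cache[key]

# num_seqs baked in as constexpr: one compile per distinct value
for num_seqs in (3, 5, 3, 7, 5):
    get_kernel(num_seqs)
constexpr_compiles = compile_count  # 3 distinct values -> 3 compiles

# num_seqs passed dynamically: the cache key never changes -> 1 compile
compiled_cache.clear()
compile_count = 0
for num_seqs in (3, 5, 3, 7, 5):
    get_kernel(None)
dynamic_compiles = compile_count

print(constexpr_compiles, dynamic_compiles)  # 3 1
```

In steady-state serving this is the difference between paying the compile cost on nearly every prefill step and paying it exactly once.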
Signed-off-by: Shijie Wang <jaywan@nvidia.com>
de232de to 58fbd68 (Compare)
/bot run --add-multi-gpu-test --disable-fail-fast

PR_Github #41695 [ run ] triggered by Bot. Commit:

PR_Github #41695 [ run ] completed with state

/bot run

PR_Github #41774 [ run ] triggered by Bot. Commit:

PR_Github #41774 [ run ] completed with state
…ving (NVIDIA#12736) Signed-off-by: Shijie Wang <jaywan@nvidia.com>
Summary by CodeRabbit
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment `/bot help`.