
[Feat][Mamba] Support Mamba KV cache prefix caching (align mode) on Ascend NPU#1

Closed
Copilot wants to merge 2 commits into main from copilot/add-mamba-kv-cache-prefix-caching

Conversation


Copilot AI commented Mar 10, 2026

Upstream vLLM PR #30877 merged Mamba prefix caching (align mode) into vLLM v0.16.0, but vllm-ascend had no NPU adaptation. This PR ports the work from Angazenn/vllm-ascend@mamba_apc to enable the feature on Ascend hardware.

Changes

  • vllm_ascend/patch/worker/patch_mamba_utils.py (new)
    Replaces the upstream CUDA batch_memcpy kernel used by vllm.v1.worker.mamba_utils.preprocess_mamba with a Triton kernel compatible with Ascend NPU. Uses BLOCK_SIZE=128 for vectorized byte-level copies.

  • vllm_ascend/patch/worker/__init__.py
    Registers patch_mamba_utils so the kernel replacement is applied at startup.

  • vllm_ascend/worker/model_runner_v1.py

    • Imports mamba_utils and get_total_cp_world_size.
    • Adds mamba_state_idx and _mamba_copy_bufs to __init__; resets _mamba_copy_bufs in initialize_kv_cache.
    • Calls mamba_utils.preprocess_mamba() in execute_model when mamba_cache_mode == "align". (postprocess_mamba is inherited via _update_states_after_model_execute without override.)
    • Computes correct per-group max_num_blocks in may_reinitialize_input_batch, accounting for whether prefix caching is enabled for Mamba groups.
  • vllm_ascend/worker/npu_input_batch.py
    Adds max_num_blocks_per_req: list[int] | None = None and forwards it to MultiGroupBlockTable.
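
The kernel replacement above is the core of the patch. As a rough illustration of the semantics such a batched, block-tiled copy implements (this is a NumPy reference sketch, not the Triton kernel; the offset/length arrays and function name are illustrative, not from the patch):

```python
import numpy as np

def batch_memcpy_ref(flat: np.ndarray,
                     src_offsets: np.ndarray,
                     dst_offsets: np.ndarray,
                     lengths: np.ndarray,
                     block_size: int = 128) -> None:
    """Reference semantics for a batched memcpy over one flat byte buffer.

    Batch element i copies lengths[i] bytes from src_offsets[i] to
    dst_offsets[i], processed in block_size-sized chunks the way a Triton
    program would tile the copy (cf. BLOCK_SIZE=128 above).
    """
    for src, dst, n in zip(src_offsets, dst_offsets, lengths):
        for start in range(0, int(n), block_size):
            end = min(start + block_size, int(n))
            flat[dst + start:dst + end] = flat[src + start:src + end]

buf = np.arange(32, dtype=np.uint8)
batch_memcpy_ref(buf, np.array([0]), np.array([16]), np.array([8]),
                 block_size=4)
```

On the NPU the same per-element copies run in parallel, one Triton program per (batch element, block) pair, rather than in a Python loop.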

Usage

vllm serve <model> --enable-prefix-caching --mamba-cache-mode align

Supports all hybrid Mamba model architectures (Mamba1/2, GDN, Short Conv Attention) including Qwen3-Next and LFM2.
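
In align mode, Mamba states are checkpointed only at KV-cache block boundaries, so a cached prefix is reusable only up to the last full block; the remainder is recomputed. A minimal sketch of that rounding (the real bookkeeping lives in vLLM's mamba_utils; the function name here is illustrative):

```python
def aligned_prefix_len(num_matched_tokens: int, block_size: int) -> int:
    """Longest prefix servable from cache when Mamba state snapshots
    exist only at block boundaries (align mode): round down to a full
    block."""
    return (num_matched_tokens // block_size) * block_size

# With 16-token blocks, a 50-token cache hit reuses 48 tokens and
# recomputes the trailing 2.
```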



…cend NPU

- Add patch_mamba_utils.py: Triton-based batch_memcpy kernel for NPU
  (replaces upstream CUDA kernel in vllm.v1.worker.mamba_utils)
- Register patch in patch/worker/__init__.py
- Add mamba_state_idx and _mamba_copy_bufs to NPUModelRunner.__init__
- Reset _mamba_copy_bufs in initialize_kv_cache
- Call mamba_utils.preprocess_mamba() in execute_model when
  mamba_cache_mode == 'align' (following vllm PR #30877)
- Compute per-group max_num_blocks (honouring Mamba prefix caching)
  in may_reinitialize_input_batch and pass to NPUInputBatch
- Add max_num_blocks_per_req param to NPUInputBatch and forward it
  to MultiGroupBlockTable

Enable with: --enable-prefix-caching --mamba-cache-mode align

Signed-off-by: Copilot <copilot@github.com>

Co-authored-by: lHrHenry233 <99803155+lHrHenry233@users.noreply.github.com>
Copilot AI changed the title from "[WIP] Add mamba kv cache prefix caching support" to "[Feat][Mamba] Support Mamba KV cache prefix caching (align mode) on Ascend NPU" Mar 10, 2026
Copilot stopped work on behalf of lHrHenry233 due to an error March 10, 2026 07:45
@lHrHenry233 lHrHenry233 deleted the copilot/add-mamba-kv-cache-prefix-caching branch March 10, 2026 09:01
lHrHenry233 pushed a commit that referenced this pull request Apr 10, 2026
…(v3.1)

- Port upstream _causal_conv1d_fwd_kernel as NPU Triton kernel
  - Handles initial/final/intermediate conv state in-kernel
  - Supports APC block boundary state writes
  - NPU adaptations: removed .cache_modifier, kept debug_barrier
- Rewrite causal_conv1d_fn to dispatch to new Triton kernel
- Rewrite gdn.py conv1d path: split decode/prefill like upstream
  - Decode: causal_conv1d_update_npu with block params
  - Prefill: causal_conv1d_fn with APC params (new kernel)
- Fix SSM #6: _build_initial_state only zeros prefill sequences
- Fix SSM #7: _write_final_states adds slot >= 0 validation
- Fix SSM #8: _scatter_intermediate_states adds unaligned offset
- Update all 36 UTs to pass with new num_computed_tokens_all field

Alignment status vs upstream #26807:
  #1 conv1d prefill kernel:     FIXED (kernel ported)
  #3 causal_conv1d_fn params:   FIXED (rewritten)
  #4 intermediate conv state:   FIXED (kernel internal)
  #6 SSM zeroing scope:         FIXED
  #7 _write_final_states guard: FIXED
  #8 SSM scatter alignment:     FIXED
  #9 causal_conv1d_fn signature: FIXED
  #2 decode pre-copy:           KEEP (NPU needs it)
  #5 SSM decode index:          OK (correct approach)
  #10 conv layout hardcoded:    DEFERRED

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
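
For context on what the ported conv kernel computes: a causal depthwise conv1d only needs the last width-1 inputs as carried state, which is why chunked prefill must read an initial conv state and write a final (and, for APC, per-block-boundary) state per sequence. A NumPy reference of those semantics (illustrative only, not the Triton kernel):

```python
import numpy as np

def causal_conv1d_ref(x, w, state):
    """Depthwise causal conv1d over x [dim, seqlen] with weights
    w [dim, width]. `state` [dim, width-1] holds the trailing inputs of
    the previous chunk and is updated in place, modeling the kernel's
    initial/final conv-state reads and writes."""
    dim, seqlen = x.shape
    width = w.shape[1]
    padded = np.concatenate([state, x], axis=1)   # prepend carried state
    out = np.empty_like(x, dtype=np.float64)
    for t in range(seqlen):
        out[:, t] = (padded[:, t:t + width] * w).sum(axis=1)
    state[:] = padded[:, -(width - 1):]           # final state for next chunk
    return out

# Splitting a sequence into two chunks must reproduce one full pass.
x = np.arange(12, dtype=np.float64).reshape(2, 6)
w = np.ones((2, 3)) / 3.0
s_full = np.zeros((2, 2))
full = causal_conv1d_ref(x, w, s_full)
s = np.zeros((2, 2))
a = causal_conv1d_ref(x[:, :3], w, s)
b = causal_conv1d_ref(x[:, 3:], w, s)
assert np.allclose(np.concatenate([a, b], axis=1), full)
```

The chunk-equivalence assertion is exactly the property APC relies on: resuming from a saved conv state at a block boundary yields the same outputs as recomputing from scratch.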