Skip to content

Support Mamba KV cache prefix caching for Ascend NPU#5

Closed
lHrHenry233 wants to merge 1 commit intomainfrom
feat/mamba-prefix-caching-align-full
Closed

Support Mamba KV cache prefix caching for Ascend NPU#5
lHrHenry233 wants to merge 1 commit intomainfrom
feat/mamba-prefix-caching-align-full

Conversation

@lHrHenry233
Copy link
Copy Markdown
Owner

This pull request introduces support for Mamba cache alignment and optimizes memory copy operations for Ascend NPU in the vLLM-Ascend worker. The main changes include adding a Triton-based batch memory copy utility for Mamba, updating the model runner to handle Mamba-specific preprocessing, and enhancing input batch initialization to better support Mamba cache configurations.

Mamba cache alignment and memory copy enhancements

  • Added a new module patch_mamba_utils.py that implements a Triton kernel for batch memory copy (batch_memcpy_kernel) and exposes batch_memcpy for efficient memory operations on Ascend NPU, replacing the CUDA implementation.
  • Registered the new patch_mamba_utils module in the worker initialization to ensure the Mamba utilities are available. (__init__.py)

Model runner updates for Mamba support

  • Integrated Mamba preprocessing logic in execute_model, calling preprocess_mamba when cache alignment is enabled, and added state tracking for Mamba in the model runner. (model_runner_v1.py) [1] [2]
  • Reset Mamba copy buffers during KV cache initialization to ensure correct state management. (model_runner_v1.py)

Input batch initialization improvements

  • Enhanced input batch initialization to dynamically calculate and provide max_num_blocks_per_req, especially for Mamba cache groups, improving resource allocation and handling speculative blocks. (model_runner_v1.py, npu_input_batch.py) [1] [2] [3] [4] [5]This pull request introduces support for Mamba cache alignment and improves handling of Mamba-related state and buffer management in the Ascend NPU backend. The main changes include adding a new patch for Mamba utilities, updating the model runner to use these utilities, and enhancing input batch initialization to better accommodate Mamba cache requirements.

Mamba cache and state management enhancements:

  • Added new module patch_mamba_utils.py implementing Triton-based batch memory copy for Ascend NPU, replacing CUDA kernel and exposing batch_memcpy and batch_memcpy_kernel functions. (vllm_ascend/patch/worker/patch_mamba_utils.py)
  • Imported patch_mamba_utils in the worker patch initialization to ensure Mamba utilities are available. (vllm_ascend/patch/worker/__init__.py)
  • Integrated Mamba cache preprocessing in execute_model, calling mamba_utils.preprocess_mamba when cache mode is "align". (vllm_ascend/worker/model_runner_v1.py)
  • Added Mamba-specific state index and buffer attributes to ModelRunnerV1, and ensured Mamba copy buffers are reset when initializing KV cache. (vllm_ascend/worker/model_runner_v1.py) [1] [2]

Input batch initialization improvements:

  • Modified input batch reinitialization logic to calculate and pass max_num_blocks_per_req, considering Mamba cache requirements and speculative blocks, and updated NPUInputBatch to accept this parameter. (vllm_ascend/worker/model_runner_v1.py, vllm_ascend/worker/npu_input_batch.py) [1] [2] [3] [4] [5]

These changes collectively improve support for Mamba cache alignment and ensure proper buffer and state management for models using Mamba on Ascend NPU.…cend NPU

  • Add patch_mamba_utils.py: Triton-based batch_memcpy kernel for NPU (replaces upstream CUDA kernel in vllm.v1.worker.mamba_utils)
  • Register patch in patch/worker/init.py
  • Add mamba_state_idx and _mamba_copy_bufs to NPUModelRunner.init
  • Reset _mamba_copy_bufs in initialize_kv_cache
  • Call mamba_utils.preprocess_mamba() in execute_model when mamba_cache_mode == 'align' (following vllm PR #30877)
  • Compute per-group max_num_blocks (honouring Mamba prefix caching) in may_reinitialize_input_batch and pass to NPUInputBatch
  • Add max_num_blocks_per_req param to NPUInputBatch and forward it to MultiGroupBlockTable

Enable with: --enable-prefix-caching --mamba-cache-mode align

(cherry picked from commit b2f6257)

What this PR does / why we need it?

Does this PR introduce any user-facing change?

How was this patch tested?

…cend NPU

- Add patch_mamba_utils.py: Triton-based batch_memcpy kernel for NPU
  (replaces upstream CUDA kernel in vllm.v1.worker.mamba_utils)
- Register patch in patch/worker/__init__.py
- Add mamba_state_idx and _mamba_copy_bufs to NPUModelRunner.__init__
- Reset _mamba_copy_bufs in initialize_kv_cache
- Call mamba_utils.preprocess_mamba() in execute_model when
  mamba_cache_mode == 'align' (following vllm PR #30877)
- Compute per-group max_num_blocks (honouring Mamba prefix caching)
  in may_reinitialize_input_batch and pass to NPUInputBatch
- Add max_num_blocks_per_req param to NPUInputBatch and forward it
  to MultiGroupBlockTable

Enable with: --enable-prefix-caching --mamba-cache-mode align

Signed-off-by: Copilot <copilot@github.com>

Co-authored-by: lHrHenry233 <99803155+lHrHenry233@users.noreply.github.com>
(cherry picked from commit b2f6257)
@lHrHenry233 lHrHenry233 closed this Apr 1, 2026
@lHrHenry233 lHrHenry233 deleted the feat/mamba-prefix-caching-align-full branch April 1, 2026 06:42
@lHrHenry233 lHrHenry233 restored the feat/mamba-prefix-caching-align-full branch April 1, 2026 06:43
@lHrHenry233 lHrHenry233 deleted the feat/mamba-prefix-caching-align-full branch April 1, 2026 07:20
lHrHenry233 pushed a commit that referenced this pull request Apr 10, 2026
…(v3.1)

- Port upstream _causal_conv1d_fwd_kernel as NPU Triton kernel
  - Handles initial/final/intermediate conv state in-kernel
  - Supports APC block boundary state writes
  - NPU adaptations: removed .cache_modifier, kept debug_barrier
- Rewrite causal_conv1d_fn to dispatch to new Triton kernel
- Rewrite gdn.py conv1d path: split decode/prefill like upstream
  - Decode: causal_conv1d_update_npu with block params
  - Prefill: causal_conv1d_fn with APC params (new kernel)
- Fix SSM #6: _build_initial_state only zeros prefill sequences
- Fix SSM #7: _write_final_states adds slot >= 0 validation
- Fix SSM #8: _scatter_intermediate_states adds unaligned offset
- Update all 36 UTs to pass with new num_computed_tokens_all field

Alignment status vs upstream #26807:
  #1 conv1d prefill kernel:     FIXED (kernel ported)
  #3 causal_conv1d_fn params:   FIXED (rewritten)
  #4 intermediate conv state:   FIXED (kernel internal)
  #6 SSM zeroing scope:         FIXED
  #7 _write_final_states guard: FIXED
  #8 SSM scatter alignment:     FIXED
  #9 causal_conv1d_fn signature: FIXED
  #2 decode pre-copy:           KEEP (NPU needs it)
  #5 SSM decode index:          OK (correct approach)
  #10 conv layout hardcoded:    DEFERRED

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants