
[Bugfix] Fix mamba cache mode null-block padding #33937

Closed
tianshu-Michael-yu wants to merge 3 commits into vllm-project:main from tianshu-Michael-yu:fix/mamba-cache-mode-null-block-pad-dco

Conversation

@tianshu-Michael-yu

Contributor

Purpose

In mamba_get_block_table_tensor, block id 0 is reserved for BlockPool.null_block (never allocated) but can appear in block tables as a placeholder (e.g. mamba align mode). Mamba kernels treat PAD_SLOT_ID (-1) as padding; if we pass block id 0 through, kernels can read/write state for the shared null block, causing cross-request state corruption.

This PR maps block id 0 to PAD_SLOT_ID for all mamba cache modes.
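
As a rough illustration of the mapping described above, the following pure-Python sketch replaces every null block id with the padding id and leaves the input untouched. The helper name `sanitize_block_ids` and the constant `NULL_BLOCK_ID` are hypothetical and chosen for this example only. The values 0 and -1 come from the description above. The PR itself operates on tensors, not nested lists.

```python
PAD_SLOT_ID = -1    # padding value the Mamba kernels skip (per the PR description)
NULL_BLOCK_ID = 0   # id reserved for BlockPool.null_block (hypothetical constant name)


def sanitize_block_ids(block_ids):
    """Return a new nested list with every null block id replaced by PAD_SLOT_ID.

    The input is not modified, so callers that still need the raw ids keep them.
    """
    return [
        [PAD_SLOT_ID if block_id == NULL_BLOCK_ID else block_id for block_id in row]
        for row in block_ids
    ]


gathered = [[0], [24]]                    # one gathered block id per request
sanitized = sanitize_block_ids(gathered)
print(sanitized)                          # [[-1], [24]]
print(gathered)                           # [[0], [24]] (unchanged)
```

Returning a fresh structure instead of editing in place keeps the block table itself intact for any other consumer that reads it.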

Test Plan

  • python -m pytest tests/v1/attention/test_mamba_block_table_tensor.py

Test Result

  • 2 passed (CPU)

Signed-off-by: Tianshu Yu <tianshuyu.formal@gmail.com>
Signed-off-by: Tianshu Yu <tianshuyu.formal@gmail.com>
@mergify mergify Bot added the v1 and bug (Something isn't working) labels Feb 5, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces a critical bugfix for Mamba cache handling. The change correctly maps the null block ID (0) to the padding slot ID (-1) in mamba_get_block_table_tensor for all cache modes. This prevents potential cross-request state corruption in Mamba kernels, which could occur if they were to read from or write to the shared null block. The fix is implemented cleanly using torch.where to avoid in-place modification of the block table, which is a good safety measure. The newly added tests in tests/v1/attention/test_mamba_block_table_tensor.py are comprehensive, covering align, all, and none cache modes, and effectively validate the correctness of the fix. The changes are well-reasoned and improve the robustness of the Mamba implementation.

@mergify

mergify Bot commented Feb 5, 2026

Contributor

Hi @tianshu-Michael-yu, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

Signed-off-by: Tianshu Yu <tianshuyu.formal@gmail.com>
@tdoublep

tdoublep commented Feb 6, 2026

Member

> if we pass block id 0 through, kernels can read/write state for the shared null block, causing cross-request state corruption.

Do we have a reproducible bug where this is actually happening?

@tianshu-Michael-yu

tianshu-Michael-yu commented Feb 6, 2026

Contributor Author

Yes — we were able to reproduce this in real inference workloads (before any training update), not just by code inspection.

What we saw:

  • Model: LFM2 (Mamba/hybrid) with prefix caching enabled.
  • Config path: hybrid/Mamba falls back to mamba_cache_mode=align.
  • Symptom: occasional corrupted generations (including long runs of U+FFFD / nonsensical text) on step 1.
  • If we disable prefix caching (or avoid the align path), corruption disappears.

Why this happens in practice:

  • The block table can contain 0 for padded entries (shared null_block).
  • Mamba kernels interpret padding as -1 (PAD_SLOT_ID).
  • Passing 0 through lets kernels index the shared null block state, which can leak/corrupt state across requests.

After mapping 0 -> -1 in mamba_get_block_table_tensor, we no longer observed this corruption in repeated rollout/inference runs under the same settings.
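
To make the failure mode above concrete, here is a toy model of the state leak. It is only a sketch: `toy_state_update` is a hypothetical stand-in for a state-updating kernel, not vLLM code, and the state pool is a plain list of floats.

```python
PAD_SLOT_ID = -1  # padding marker, per the PR description


def toy_state_update(state_pool, slot_ids, updates):
    """Toy stand-in for a state-updating kernel: accumulate each update
    into its slot and skip rows marked as padding."""
    for slot, update in zip(slot_ids, updates):
        if slot == PAD_SLOT_ID:
            continue  # padded row: touch nothing
        state_pool[slot] += update


# Unsanitized: two unrelated requests both carry block id 0.
shared_pool = [0.0, 0.0, 0.0, 0.0]
toy_state_update(shared_pool, [0, 0], [1.5, 2.5])
print(shared_pool[0])  # 4.0: both requests wrote into the same slot

# Sanitized: the same rows carry PAD_SLOT_ID, so slot 0 stays untouched.
clean_pool = [0.0, 0.0, 0.0, 0.0]
toy_state_update(clean_pool, [PAD_SLOT_ID, PAD_SLOT_ID], [1.5, 2.5])
print(clean_pool[0])  # 0.0
```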

@peakcrosser7

Contributor

Hi @tianshu-Michael-yu, thanks for the PR! I’d like to ask for a few more details to better understand the situation. Specifically, was spec decoding enabled for the model you were using?
Without spec decoding, each request in align mode should only return one block id for the current Mamba state, and no null-blocks should be present.
Even if spec decoding is enabled (where it returns 1 + num_speculative_blocks), we still shouldn't see any null-blocks. In the bad case you encountered, did you notice where the null-block was located? For instance, was it at the beginning or the end of the returned tensor? I’m concerned that a bug in another part of the logic might be causing these null-blocks to appear.

@tianshu-Michael-yu

tianshu-Michael-yu commented Feb 7, 2026

Contributor Author

Great question. Here are the concrete details from our repro path.

  • Spec decoding was disabled (speculative_config=None). This was the standard async rollout/inference path (no draft model / no speculative tokens configured).
  • In this setup (align + no spec decode), mamba_get_block_table_tensor returns shape (num_reqs, 1), so there is only one gathered entry per request. In the bad cases, that single gathered entry was 0 (null block id).

Why 0 can still legitimately appear in align mode (even without spec decode):

  • MambaManager.find_longest_cache_hit can insert leading null placeholders for skipped blocks:
    • computed.extend([block_pool.null_block] * i)
  • MambaManager.allocate_new_blocks can set prior positions back to null when reusing blocks:
    • req_blocks[block_idx] = self._null_block
  • MambaManager.remove_skipped_blocks can also replace older blocks with null:
    • blocks[last_state_block_idx] = self._null_block

So null-blocks are not only a tail-padding artifact; they can be in internal/earlier positions of the per-request block table in align mode. If the gathered index lands there, the output seen by Mamba kernels is 0 unless we sanitize to PAD_SLOT_ID (-1).
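
The dependence on where the gathered index lands can be sketched in pure Python. The block tables below are made-up values for illustration. `align_gather` is a hypothetical helper that assumes one gathered entry per request at index `(seq_len - 1) // block_size`, clamped to be non-negative, with spec decoding off.

```python
def align_gather(block_table, seq_lens, block_size):
    """Pick one block id per request at index (seq_len - 1) // block_size.

    Assumes spec decoding is off, so each request yields a single entry.
    """
    gathered = []
    for row, seq_len in zip(block_table, seq_lens):
        index = max((seq_len - 1) // block_size, 0)  # clamp to a valid column
        gathered.append([row[index]])
    return gathered


# Illustrative tables only: 0 marks a null placeholder.
block_table = [
    [7, 0, 9],    # internal null at column 1
    [0, 0, 12],   # leading nulls, live block at column 2
]
seq_lens = [6, 9]

print(align_gather(block_table, seq_lens, block_size=4))  # [[0], [12]]
```

In the first row the index lands on the placeholder and a 0 is returned. In the second row the leading placeholders are harmless because the index lands on the live block.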

@tianshu-Michael-yu

tianshu-Michael-yu commented Feb 7, 2026

Contributor Author

I prepared a minimal runnable repro + log trace that shows exactly where the null-block appears.

Repro script (spec decode disabled)

#!/usr/bin/env python3
import torch

from vllm.v1.attention.backends.utils import PAD_SLOT_ID, mamba_get_block_table_tensor
from vllm.v1.kv_cache_interface import MambaSpec

spec = MambaSpec(block_size=4, shapes=((1,),), dtypes=(torch.float16,))
spec_decode_enabled = spec.num_speculative_blocks > 0

# null block id 0 is in the middle (index 1) for req0
block_table = torch.tensor(
    [
        [11, 0, 15, 19],
        [21, 22, 23, 24],
    ],
    dtype=torch.int32,
)
seq_lens = torch.tensor([5, 13], dtype=torch.int32)

start_indices = torch.clamp((seq_lens - 1) // spec.block_size, min=0)
indices_to_gather = start_indices.unsqueeze(1)
raw = torch.gather(block_table, 1, indices_to_gather)

current = mamba_get_block_table_tensor(block_table, seq_lens, spec, "align")
expected_fixed = torch.where(raw == 0, raw.new_full((), PAD_SLOT_ID), raw)

print("=== Repro: Mamba align-mode null block gather ===")
print(f"spec_decode_enabled: {spec_decode_enabled}")
print(f"num_speculative_blocks: {spec.num_speculative_blocks}")
print(f"PAD_SLOT_ID: {PAD_SLOT_ID}")
print(f"block_table:\n{block_table}")
print(f"seq_lens: {seq_lens.tolist()}")
print(f"start_indices: {start_indices.tolist()}")
print(f"indices_to_gather:\n{indices_to_gather}")
print(f"raw_gather_before_sanitize:\n{raw}")
print(f"mamba_get_block_table_tensor(current):\n{current}")
print(f"expected_if_fixed(0->-1):\n{expected_fixed}")

for i in range(block_table.shape[0]):
    idx = int(start_indices[i])
    raw_val = int(raw[i, 0])
    cur_val = int(current[i, 0])
    exp_val = int(expected_fixed[i, 0])
    print(
        f"req={i}: gathered block_table[{i},{idx}]={raw_val} "
        f"(null={raw_val==0}) current={cur_val} expected_fixed={exp_val}"
    )

Run command

python repro_mamba_align_null_block.py | tee repro_mamba_align_null_block.log

Captured log trace (from my run)

=== Repro: Mamba align-mode null block gather ===
spec_decode_enabled: False
num_speculative_blocks: 0
PAD_SLOT_ID: -1
block_table:
tensor([[11,  0, 15, 19],
        [21, 22, 23, 24]], dtype=torch.int32)
seq_lens: [5, 13]
start_indices: [1, 3]
indices_to_gather:
tensor([[1],
        [3]], dtype=torch.int32)
raw_gather_before_sanitize:
tensor([[ 0],
        [24]], dtype=torch.int32)
mamba_get_block_table_tensor(current):
tensor([[ 0],
        [24]], dtype=torch.int32)
expected_if_fixed(0->-1):
tensor([[-1],
        [24]], dtype=torch.int32)
req=0: gathered block_table[0,1]=0 (null=True) current=0 expected_fixed=-1
req=1: gathered block_table[1,3]=24 (null=False) current=24 expected_fixed=24

This shows the null-block is at an internal position (index 1 here), and with spec decode disabled we can still gather that 0 in align mode.

@tianshu-Michael-yu

tianshu-Michael-yu commented Feb 7, 2026

Contributor Author

@peakcrosser7 Thanks for the push on this. You’re right that the best evidence should come from a standalone vLLM serve path.

I re-ran this with a standalone OpenAI server flow, using public vLLM source + LFM2 model.

Standalone repro (serve path)

git clone https://github.com/vllm-project/vllm.git
cd vllm
# checkout this PR branch/commit under test

uv venv .venv
source .venv/bin/activate
VLLM_USE_PRECOMPILED=1 uv pip install -e .

cat > sitecustomize.py <<'PY'
import logging, os, torch
if os.getenv("VLLM_DEBUG_MAMBA_RAW_NULL", "0") == "1":
    import vllm.v1.attention.backends.utils as u
    orig = u.mamba_get_block_table_tensor
    logger = logging.getLogger("vllm.mamba.rawnull")
    def wrapped(block_table, seq_lens, kv_cache_spec, mamba_cache_mode):
        if mamba_cache_mode == "align":
            start_indices = torch.clamp((seq_lens - 1) // kv_cache_spec.block_size, min=0)
            offsets = torch.arange(1 + kv_cache_spec.num_speculative_blocks, device=block_table.device)
            raw = torch.gather(block_table, 1, start_indices.unsqueeze(1) + offsets)
            null_count = int((raw == 0).sum().item())
            if null_count > 0:
                logger.warning(
                    "MAMBA_ALIGN_RAW_NULL_DETECTED null_count=%d raw_shape=%s num_speculative_blocks=%d start_indices_head=%s",
                    null_count, tuple(raw.shape), int(kv_cache_spec.num_speculative_blocks),
                    start_indices[:8].cpu().tolist(),
                )
        return orig(block_table, seq_lens, kv_cache_spec, mamba_cache_mode)
    u.mamba_get_block_table_tensor = wrapped
PY

export PYTHONPATH=$PWD:$PYTHONPATH
export VLLM_DEBUG_MAMBA_RAW_NULL=1
python -m vllm.entrypoints.openai.api_server \
  --model LiquidAI/LFM2-1.2B \
  --enable-prefix-caching \
  --mamba-cache-mode align \
  --tensor-parallel-size 1 \
  --dtype bfloat16 \
  --max-model-len 4096 \
  --gpu-memory-utilization 0.7

Then send normal completion requests to http://127.0.0.1:8000/v1/completions (no spec decode config).
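
For readers who want to script that step, a minimal standard-library sketch of one such request follows. The prompt text, `max_tokens` value and the helper name `build_completion_request` are illustrative assumptions. The endpoint and model name are taken from the commands above.

```python
import json
import urllib.request


def build_completion_request(
    prompt,
    base_url="http://127.0.0.1:8000",
    model="LiquidAI/LFM2-1.2B",
    max_tokens=64,
):
    """Build (but do not send) a POST request for the /v1/completions endpoint."""
    payload = {"model": model, "prompt": prompt, "max_tokens": max_tokens}
    return urllib.request.Request(
        f"{base_url}/v1/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


request = build_completion_request("What is 2 + 2?")
print(request.full_url)  # http://127.0.0.1:8000/v1/completions

# With the server running, the request can be sent like this:
# with urllib.request.urlopen(request) as response:
#     print(json.load(response)["choices"][0]["text"])
```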

What the server log shows (from my run)

... speculative_config=None ... enable_prefix_caching=True ...
... MAMBA_ALIGN_RAW_NULL_DETECTED null_count=512 raw_shape=(512, 1) num_speculative_blocks=0 start_indices_head=[0, 0, 0, 0, 0, 0, 0, 0]
... MAMBA_ALIGN_RAW_NULL_DETECTED null_count=496 raw_shape=(496, 1) num_speculative_blocks=0 start_indices_head=[0, 0, 0, 0, 0, 0, 0, 0]
... MAMBA_ALIGN_RAW_NULL_DETECTED null_count=480 raw_shape=(480, 1) num_speculative_blocks=0 start_indices_head=[0, 0, 0, 0, 0, 0, 0, 0]

So in a standalone serve setup (spec decode disabled), align-mode raw gathers can still contain 0 entries before sanitize.
That is exactly the value this PR maps to PAD_SLOT_ID=-1 before entering Mamba kernels.

@peakcrosser7

Contributor

@tianshu-Michael-yu Thanks for the detailed explanation and the script!
You're right that null-blocks are used for padding in the block-table. However, the key point is that mamba_get_block_table_tensor() is designed to get the block-ids for reading and writing the current Mamba states. Therefore, it shouldn't contain null-blocks in normal operation.

I tried reproducing this with your script and found that the 0 (null-block) only appears during the CUDA-Graph capture phase, as shown below.

(EngineCore_DP0 pid=3388700) INFO 02-08 10:15:11 [mamba_attn.py:223] >>> [DEBUG] _compute_common_metadata: common_attn_metadata.num_reqs=4
(EngineCore_DP0 pid=3388700) WARNING 02-08 10:15:11 [sitecustomize.py:13] MAMBA_ALIGN_RAW_NULL_DETECTED null_count=4 raw_shape=(4, 1) num_speculative_blocks=0 start_indices_head=[0, 0, 0, 0]
(EngineCore_DP0 pid=3388700) INFO 02-08 10:15:11 [utils.py:863] >>> [DEBUG] mamba_get_block_table_tensor: seq_lens=tensor([1, 1, 1, 1], device='cuda:0', dtype=torch.int32) indices_to_gather=tensor([[0],
(EngineCore_DP0 pid=3388700) INFO 02-08 10:15:11 [utils.py:863]         [0],
(EngineCore_DP0 pid=3388700) INFO 02-08 10:15:11 [utils.py:863]         [0],
(EngineCore_DP0 pid=3388700) INFO 02-08 10:15:11 [utils.py:863]         [0]], device='cuda:0'), ret=tensor([[0],
(EngineCore_DP0 pid=3388700) INFO 02-08 10:15:11 [utils.py:863]         [0],
(EngineCore_DP0 pid=3388700) INFO 02-08 10:15:11 [utils.py:863]         [0],
(EngineCore_DP0 pid=3388700) INFO 02-08 10:15:11 [utils.py:863]         [0]], device='cuda:0', dtype=torch.int32)

However, I haven't encountered any null-blocks during actual request execution. Even in cases involving CUDA-Graph padding, the returned tensor contains -1, as shown here.

(EngineCore_DP0 pid=3388700) INFO 02-08 10:15:41 [mamba_attn.py:315] >>> [DEBUG] update_block_table: metadata.num_prefills=0, metadata.num_decodes=4
(EngineCore_DP0 pid=3388700) INFO 02-08 10:15:41 [utils.py:863] >>> [DEBUG] mamba_get_block_table_tensor: seq_lens=tensor([612, 612, 612,   0], device='cuda:0', dtype=torch.int32) indices_to_gather=tensor([[38],
(EngineCore_DP0 pid=3388700) INFO 02-08 10:15:41 [utils.py:863]         [38],
(EngineCore_DP0 pid=3388700) INFO 02-08 10:15:41 [utils.py:863]         [38],
(EngineCore_DP0 pid=3388700) INFO 02-08 10:15:41 [utils.py:863]         [ 0]], device='cuda:0'), ret=tensor([[314],
(EngineCore_DP0 pid=3388700) INFO 02-08 10:15:41 [utils.py:863]         [308],
(EngineCore_DP0 pid=3388700) INFO 02-08 10:15:41 [utils.py:863]         [311],
(EngineCore_DP0 pid=3388700) INFO 02-08 10:15:41 [utils.py:863]         [ -1]], device='cuda:0', dtype=torch.int32)

I’m not sure if returning 0 during the capture stage would cause any issues. And if you’ve observed null-blocks during actual inference requests, could you please provide a specific case? Thanks!
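
The gather indices in the two log excerpts above are consistent with the clamped formula used in the repro script earlier in this thread. The sketch below reproduces them in pure Python. The block size of 16 is an assumption inferred from the log (it is the only integer for which 611 // block_size equals 38), and `align_start_index` is a hypothetical helper name.

```python
def align_start_index(seq_len, block_size):
    """Column gathered for a request: (seq_len - 1) // block_size, clamped at 0."""
    return max((seq_len - 1) // block_size, 0)


BLOCK_SIZE = 16  # assumption inferred from the log: (612 - 1) // 16 == 38

capture_indices = [align_start_index(n, BLOCK_SIZE) for n in [1, 1, 1, 1]]
decode_indices = [align_start_index(n, BLOCK_SIZE) for n in [612, 612, 612, 0]]

print(capture_indices)  # [0, 0, 0, 0]
print(decode_indices)   # [38, 38, 38, 0]
```

Both the capture rows (seq_len 1) and the padded decode row (seq_len 0) gather column 0, yet one returns 0 and the other -1. This suggests the difference comes from what the block table holds in that column, not from the index computation.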

@tianshu-Michael-yu

Contributor Author

@peakcrosser7 Thanks — your concern is valid, and I re-ran this on public vLLM main with a public model/dataset and captured both non-eager and eager traces.

Environment / commit

git clone https://github.com/vllm-project/vllm.git
cd vllm
git checkout 084aa19f02b00198b36bd8d742d4169d6f5a32ce   # 2026-02-08
uv venv .venv
source .venv/bin/activate
VLLM_USE_PRECOMPILED=1 uv pip install -e .

Commit used: 084aa19f02b00198b36bd8d742d4169d6f5a32ce (v0.11.0a1.dev3097+g084aa19f0).

Repro case (public model + public data)

  • Model: LiquidAI/LFM2-1.2B
  • Requests: 32 prompts from openai/gsm8k (rows API), then normal LLM.generate with n=8, max_tokens=512, TP=2.
  • Prefix caching: ON
  • Spec decoding: OFF (speculative_config=None)

Non-eager run (enforce_eager=False)

Observed summary from trace:

{
  "vllm_version": "0.11.0a1.dev3097+g084aa19f0",
  "enable_prefix_caching": true,
  "mamba_block_table_zero_events": 280,
  "mamba_block_table_zero_valid_seq_events": 280,
  "linear_attn_state_indices_zero_events": 210,
  "engine_args_used": {
    "model": "LiquidAI/LFM2-1.2B",
    "tensor_parallel_size": 2,
    "enforce_eager": false
  }
}

Example event lines:

{"kind":"mamba_attn_get_block_table_tensor","mamba_cache_mode":"align","seq_lens_min":1,"seq_lens_max":1,"zeros_total":256,"zeros_valid_seq":256,"sample_row0":[0]}
{"kind":"mamba_attn_get_block_table_tensor","mamba_cache_mode":"align","seq_lens_min":1,"seq_lens_max":1,"zeros_total":256,"zeros_valid_seq":256,"sample_row0":[0]}

Engine log confirms this run is non-eager with CUDA graph enabled (enforce_eager=False, cudagraph_mode=FULL_AND_PIECEWISE).

Eager run (enforce_eager=True)

With the same model/requests/sampling except eager mode:

{
  "vllm_version": "0.11.0a1.dev3097+g084aa19f0",
  "enable_prefix_caching": true,
  "mamba_block_table_zero_events": 0,
  "mamba_block_table_zero_valid_seq_events": 0,
  "linear_attn_state_indices_zero_events": 0,
  "engine_args_used": {
    "model": "LiquidAI/LFM2-1.2B",
    "tensor_parallel_size": 2,
    "enforce_eager": true
  }
}

Example eager event lines (no zeros):

{"kind":"mamba_attn_get_block_table_tensor","mamba_cache_mode":"align","seq_lens_min":64,"seq_lens_max":65,"zeros_total":0,"zeros_valid_seq":0,"sample_row0":[25]}

Why null placeholders exist in real inference paths

This part is independent of whether they are gathered in a given step:

  • BlockPool reserves a shared null_block (id 0): vllm/v1/core/block_pool.py (self.null_block = ..., self.null_block.is_null = True).
  • Mamba align-mode manager introduces/reuses null placeholders during normal request scheduling:
    • find_longest_cache_hit: inserts leading null placeholders (computed.extend([block_pool.null_block] * i)) in vllm/v1/core/single_type_kv_cache_manager.py.
    • allocate_new_blocks: extends skipped ranges with self._null_block, and can move old blocks to null during speculative/running-state bookkeeping.
    • remove_skipped_blocks: frees old state block and replaces with self._null_block.

So null placeholders are created by live request cache-management logic; whether mamba_get_block_table_tensor(..., align) returns a 0 depends on the gathered index at that moment. In this public-main repro, that 0 shows up in non-eager graph-capture batches (seq_len=1) and not in eager mode.

If useful, I can also push the exact trace helper script into this PR branch so you can run it directly without retyping.

@tianshu-Michael-yu

tianshu-Michael-yu commented Feb 8, 2026

Contributor Author

@peakcrosser7 I re-ran the same public repo with TP=1 (to remove TP=2 complexity).

Setup is unchanged except tensor_parallel_size=1:

  • commit: 084aa19f02b00198b36bd8d742d4169d6f5a32ce
  • model: LiquidAI/LFM2-1.2B
  • requests: 32 prompts from openai/gsm8k
  • spec decode: off (speculative_config=None)

TP=1, non-eager (enforce_eager=false)

Summary:

{
  "mamba_block_table_zero_events": 140,
  "mamba_block_table_zero_valid_seq_events": 140,
  "linear_attn_state_indices_zero_events": 105,
  "engine_args_used": {
    "tensor_parallel_size": 1,
    "enforce_eager": false,
    "enable_prefix_caching": true
  }
}

Sample events:

{"kind":"mamba_attn_get_block_table_tensor","mamba_cache_mode":"align","seq_lens_min":1,"seq_lens_max":1,"zeros_total":256,"zeros_valid_seq":256,"sample_row0":[0]}

(For TP=1 as well, in this repro all zero-events have seq_lens_max=1.)

TP=1, eager (enforce_eager=true)

Summary:

{
  "mamba_block_table_zero_events": 0,
  "mamba_block_table_zero_valid_seq_events": 0,
  "linear_attn_state_indices_zero_events": 0,
  "engine_args_used": {
    "tensor_parallel_size": 1,
    "enforce_eager": true,
    "enable_prefix_caching": true
  }
}

So TP=1 shows the same pattern as TP=2 in this public repo: zeros appear in non-eager path (seq_len=1 events), and do not appear in eager path.
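
One way to check this pattern over a trace is to collect the seq_lens_max values of every event that reports zeros. The sketch below assumes the trace is a sequence of JSON lines shaped like the sample events in this thread. The helper name `zero_event_seq_lens` is hypothetical and is not the trace script used for the runs above.

```python
import json


def zero_event_seq_lens(trace_lines):
    """Return the sorted seq_lens_max values seen on events with zeros_total > 0."""
    seen = set()
    for line in trace_lines:
        line = line.strip()
        if not line:
            continue  # tolerate blank lines in the trace file
        event = json.loads(line)
        if event.get("zeros_total", 0) > 0:
            seen.add(event["seq_lens_max"])
    return sorted(seen)


# Two event lines copied from the samples posted in this thread.
sample_trace = [
    '{"kind":"mamba_attn_get_block_table_tensor","mamba_cache_mode":"align",'
    '"seq_lens_min":1,"seq_lens_max":1,"zeros_total":256,"zeros_valid_seq":256,"sample_row0":[0]}',
    '{"kind":"mamba_attn_get_block_table_tensor","mamba_cache_mode":"align",'
    '"seq_lens_min":64,"seq_lens_max":65,"zeros_total":0,"zeros_valid_seq":0,"sample_row0":[25]}',
]

print(zero_event_seq_lens(sample_trace))  # [1]
```

A result of `[1]` over a full trace would match the observation that zeros only appear in seq_len=1 batches.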

@peakcrosser7

Contributor

@tianshu-Michael-yu Thanks for providing the detailed test data. This indeed seems related to CUDA-Graph. I have a few specific questions:

  1. Did the null-block count include the CUDA-Graph capture? The length-1 request in your example looks like capture behavior. Normally, MambaManager should always allocate a non-null block for a length-1 request. Additionally, I noticed that TP=1 has exactly half the events of TP=2, which suggests there is a specific pattern here.

  2. Could you clarify the specific definitions of mamba_block_table_zero_events, mamba_block_table_zero_valid_seq_events, and linear_attn_state_indices_zero_events? I noticed the first two are identical in your results, while the latter is lower. It seems like null-blocks are appearing with non-zero indices. This seems a bit unexpected to me. Do you have a specific case for this, or could you share the script you used? Thanks!

@github-actions

This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!

@github-actions

This pull request has been automatically closed due to inactivity. Please feel free to reopen if you intend to continue working on it. Thank you!

@github-actions github-actions Bot closed this Jun 11, 2026

Labels

bug (Something isn't working), stale (Over 90 days of inactivity), v1

3 participants