[Bugfix] Fix mamba cache mode null-block padding#33937
[Bugfix] Fix mamba cache mode null-block padding#33937tianshu-Michael-yu wants to merge 3 commits into
Conversation
Signed-off-by: Tianshu Yu <tianshuyu.formal@gmail.com>
Signed-off-by: Tianshu Yu <tianshuyu.formal@gmail.com>
There was a problem hiding this comment.
Code Review
This pull request introduces a critical bugfix for Mamba cache handling. The change correctly maps the null block ID (0) to the padding slot ID (-1) in mamba_get_block_table_tensor for all cache modes. This prevents potential cross-request state corruption in Mamba kernels, which could occur if they were to read from or write to the shared null block. The fix is implemented cleanly using torch.where to avoid in-place modification of the block table, which is a good safety measure. The newly added tests in tests/v1/attention/test_mamba_block_table_tensor.py are comprehensive, covering align, all, and none cache modes, and effectively validate the correctness of the fix. The changes are well-reasoned and improve the robustness of the Mamba implementation.
|
Hi @tianshu-Michael-yu, the pre-commit checks have failed. Please run: uv pip install pre-commit
pre-commit install
pre-commit run --all-filesThen, commit the changes and push to your branch. For future commits, Tip Is
|
Signed-off-by: Tianshu Yu <tianshuyu.formal@gmail.com>
Do we have a reproducible bug where this is actually happening? |
|
Yes — we were able to reproduce this in real inference workloads (before any training update), not just by code inspection. What we saw:
After mapping |
|
Hi @tianshu-Michael-yu, thanks for the PR! I’d like to ask for a few more details to better understand the situation. Specifically, was spec decoding enabled for the model you were using? |
|
Great question. Here are the concrete details from our repro path.
Why
So null-blocks are not only a tail-padding artifact; they can be in internal/earlier positions of the per-request block table in align mode. If the gathered index lands there, the output seen by Mamba kernels is |
|
I prepared a minimal runnable repro + log trace that shows exactly where the null-block appears. Repro script (spec decode disabled)#!/usr/bin/env python3
import torch
from vllm.v1.attention.backends.utils import PAD_SLOT_ID, mamba_get_block_table_tensor
from vllm.v1.kv_cache_interface import MambaSpec
spec = MambaSpec(block_size=4, shapes=((1,),), dtypes=(torch.float16,))
spec_decode_enabled = spec.num_speculative_blocks > 0
# null block id 0 is in the middle (index 1) for req0
block_table = torch.tensor(
[
[11, 0, 15, 19],
[21, 22, 23, 24],
],
dtype=torch.int32,
)
seq_lens = torch.tensor([5, 13], dtype=torch.int32)
start_indices = torch.clamp((seq_lens - 1) // spec.block_size, min=0)
indices_to_gather = start_indices.unsqueeze(1)
raw = torch.gather(block_table, 1, indices_to_gather)
current = mamba_get_block_table_tensor(block_table, seq_lens, spec, "align")
expected_fixed = torch.where(raw == 0, raw.new_full((), PAD_SLOT_ID), raw)
print("=== Repro: Mamba align-mode null block gather ===")
print(f"spec_decode_enabled: {spec_decode_enabled}")
print(f"num_speculative_blocks: {spec.num_speculative_blocks}")
print(f"PAD_SLOT_ID: {PAD_SLOT_ID}")
print(f"block_table:\n{block_table}")
print(f"seq_lens: {seq_lens.tolist()}")
print(f"start_indices: {start_indices.tolist()}")
print(f"indices_to_gather:\n{indices_to_gather}")
print(f"raw_gather_before_sanitize:\n{raw}")
print(f"mamba_get_block_table_tensor(current):\n{current}")
print(f"expected_if_fixed(0->-1):\n{expected_fixed}")
for i in range(block_table.shape[0]):
idx = int(start_indices[i])
raw_val = int(raw[i, 0])
cur_val = int(current[i, 0])
exp_val = int(expected_fixed[i, 0])
print(
f"req={i}: gathered block_table[{i},{idx}]={raw_val} "
f"(null={raw_val==0}) current={cur_val} expected_fixed={exp_val}"
)Run commandpython repro_mamba_align_null_block.py | tee repro_mamba_align_null_block.logCaptured log trace (from my run)This shows the null-block is at an internal position (index 1 here), and with spec decode disabled we can still gather that |
|
@peakcrosser7 Thanks for the push on this. You’re right that the best evidence should come from a standalone vLLM serve path. I re-ran this with a standalone OpenAI server flow, using public vLLM source + LFM2 model. Standalone repro (serve path)git clone https://github.com/vllm-project/vllm.git
cd vllm
# checkout this PR branch/commit under test
uv venv .venv
source .venv/bin/activate
VLLM_USE_PRECOMPILED=1 uv pip install -e .
cat > sitecustomize.py <<'PY'
import logging, os, torch
if os.getenv("VLLM_DEBUG_MAMBA_RAW_NULL", "0") == "1":
import vllm.v1.attention.backends.utils as u
orig = u.mamba_get_block_table_tensor
logger = logging.getLogger("vllm.mamba.rawnull")
def wrapped(block_table, seq_lens, kv_cache_spec, mamba_cache_mode):
if mamba_cache_mode == "align":
start_indices = torch.clamp((seq_lens - 1) // kv_cache_spec.block_size, min=0)
offsets = torch.arange(1 + kv_cache_spec.num_speculative_blocks, device=block_table.device)
raw = torch.gather(block_table, 1, start_indices.unsqueeze(1) + offsets)
null_count = int((raw == 0).sum().item())
if null_count > 0:
logger.warning(
"MAMBA_ALIGN_RAW_NULL_DETECTED null_count=%d raw_shape=%s num_speculative_blocks=%d start_indices_head=%s",
null_count, tuple(raw.shape), int(kv_cache_spec.num_speculative_blocks),
start_indices[:8].cpu().tolist(),
)
return orig(block_table, seq_lens, kv_cache_spec, mamba_cache_mode)
u.mamba_get_block_table_tensor = wrapped
PY
export PYTHONPATH=$PWD:$PYTHONPATH
export VLLM_DEBUG_MAMBA_RAW_NULL=1
python -m vllm.entrypoints.openai.api_server \
--model LiquidAI/LFM2-1.2B \
--enable-prefix-caching \
--mamba-cache-mode align \
--tensor-parallel-size 1 \
--dtype bfloat16 \
--max-model-len 4096 \
--gpu-memory-utilization 0.7Then send normal completion requests to What the server log shows (from my run)So in a standalone serve setup (spec decode disabled), align-mode raw gathers can still contain |
|
@tianshu-Michael-yu Thanks for the detailed explanation and the script! I tried reproducing this with your script and found that the 0 (null-block) only appears during the CUDA-Graph capture phase. As shown below. However, I haven't encountered any null-blocks during actual request execution. Even in cases involving CUDA-Graph padding, the returned tensor contains -1, as shown here. I’m not sure if returning 0 during the capture stage would cause any issues. And if you’ve observed null-blocks during actual inference requests, could you please provide a specific case? Thanks! |
|
@peakcrosser7 Thanks — your concern is valid, and I re-ran this on public vLLM main with a public model/dataset and captured both non-eager and eager traces. Environment / commitgit clone https://github.com/vllm-project/vllm.git
cd vllm
git checkout 084aa19f02b00198b36bd8d742d4169d6f5a32ce # 2026-02-08
uv venv .venv
source .venv/bin/activate
VLLM_USE_PRECOMPILED=1 uv pip install -e .Commit used: Repro case (public model + public data)
Non-eager run (
|
|
@peakcrosser7 I re-ran the same public repo with TP=1 (to remove TP=2 complexity). Setup is unchanged except
TP=1, non-eager (
|
|
@tianshu-Michael-yu Thanks for providing the detailed test data. This indeed seems related to CUDA-Graph. I have a few specific questions:
|
|
This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you! |
|
This pull request has been automatically closed due to inactivity. Please feel free to reopen if you intend to continue working on it. Thank you! |
Purpose
In
mamba_get_block_table_tensor, block id0is reserved forBlockPool.null_block(never allocated) but can appear in block tables as a placeholder (e.g. mamba align mode). Mamba kernels treatPAD_SLOT_ID(-1) as padding; if we pass block id0through, kernels can read/write state for the shared null block, causing cross-request state corruption.This PR maps block id
0toPAD_SLOT_IDfor all mamba cache modes.Test Plan
python -m pytest tests/v1/attention/test_mamba_block_table_tensor.pyTest Result
2 passed(CPU)