Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 13 additions & 11 deletions vllm/v1/attention/backends/mla/cutlass_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@

logger = init_logger(__name__)

# Shared across all CutlassMLAImpl instances (all layers) to avoid
# 61× per-layer allocation (~15 GiB → ~256 MiB for DeepSeek-R1).
g_cutlass_decode_out: torch.Tensor | None = None


class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
# enable full CUDA Graph support for decode-only capture
Expand Down Expand Up @@ -162,10 +166,6 @@
# Share workspace buffer across all executions
self._workspace = g_sm100_workspace

# Pre-allocated output buffer, lazily sized on first call.
# Zero-init once to prevent NaN in padding slots (seq_lens=0)
# from contaminating downstream per-tensor reductions.
self._decode_out: torch.Tensor | None = None

def _sm100_cutlass_mla_decode(
self,
Expand Down Expand Up @@ -223,15 +223,17 @@
if is_quantized_kv_cache(self.kv_cache_dtype)
else q_nope.dtype
)
# Reuse pre-allocated zero-init output buffer to avoid a memset
# kernel on every CUDA graph replay.
# Reuse a single zero-init output buffer shared across all layers
# to prevent NaN in padding slots. Shared buffer reduces memory
# from ~15 GiB (per-layer) to ~256 MiB for DeepSeek-R1.
global g_cutlass_decode_out
if (
self._decode_out is None
or self._decode_out.shape[0] < B_q
or self._decode_out.dtype != dtype
g_cutlass_decode_out is None
or g_cutlass_decode_out.shape[0] < B_q
or g_cutlass_decode_out.dtype != dtype
):
self._decode_out = q_nope.new_zeros((B_q, MAX_HEADS, D_latent), dtype=dtype)
out = self._decode_out[:B_q]
g_cutlass_decode_out = q_nope.new_zeros((B_q, MAX_HEADS, D_latent), dtype=dtype)

Check failure on line 235 in vllm/v1/attention/backends/mla/cutlass_mla.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/v1/attention/backends/mla/cutlass_mla.py:235:89: E501 Line too long (92 > 88)
out = g_cutlass_decode_out[:B_q]
Comment on lines +229 to +236
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The use of a global mutable variable g_cutlass_decode_out without any synchronization mechanism can lead to race conditions in a multi-threaded environment. If multiple forward passes are executed concurrently in different threads (e.g., with multiple adapters), one thread might be reading g_cutlass_decode_out while another is reallocating it. This could cause crashes or silent data corruption.

To ensure thread safety, this critical section should be protected by a lock. You'll need to add import threading and define a lock at the module level (e.g., g_cutlass_decode_out_lock = threading.Lock()), then use it to guard access to the shared buffer.

lse = (
torch.empty((B_q, MAX_HEADS), dtype=torch.float32, device=q_nope.device)
if self.need_to_return_lse_for_decode
Expand Down
27 changes: 15 additions & 12 deletions vllm/v1/attention/backends/mla/flashinfer_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@

FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024

# Shared across all FlashInferMLAImpl instances (all layers) to avoid
# 61× per-layer allocation (~15 GiB → ~256 MiB for DeepSeek-R1).
g_fi_decode_out: torch.Tensor | None = None


class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
Expand Down Expand Up @@ -152,10 +156,6 @@ def __init__(
self.bmm1_scale: float | None = None
self.bmm2_scale: float | None = None

# Pre-allocated output buffer, lazily sized on first call.
# Zero-init once to prevent NaN in padding slots (seq_lens=0)
# from contaminating downstream per-tensor reductions.
self._decode_out: torch.Tensor | None = None

def forward_mqa(
self,
Expand Down Expand Up @@ -192,15 +192,18 @@ def forward_mqa(
if self.kv_cache_dtype.startswith("fp8"):
self.bmm2_scale *= layer._k_scale_float

# Reuse pre-allocated zero-init output buffer to avoid a memset
# kernel on every CUDA graph replay.
# q is 4D: (batch, q_len_per_req, num_heads, head_dim)
# Reuse a single zero-init output buffer shared across all layers
# to prevent NaN in padding slots (seq_lens=0) from contaminating
# downstream per-tensor reductions. Shared buffer reduces memory
# from ~15 GiB (per-layer) to ~256 MiB for DeepSeek-R1.
#
# FlashInfer has a bug where out= validation hardcodes 3D shape
# (batch, num_heads, kv_lora_rank), but the kernel writes 4D
# (batch, q_len, num_heads, kv_lora_rank) when q_len > 1.
# So we can only pass out= for single-token decode (q_len == 1).
# For q_len > 1, we zero padding slots after the kernel returns.
# TODO: upstream fix to FlashInfer
global g_fi_decode_out
B, q_len_per_req = q.shape[0], q.shape[1]
out_kwargs: dict[str, torch.Tensor] = {}
if q_len_per_req == 1:
Expand All @@ -210,18 +213,18 @@ def forward_mqa(
else q.dtype
)
if (
self._decode_out is None
or self._decode_out.shape[0] < B
or self._decode_out.dtype != dtype
g_fi_decode_out is None
or g_fi_decode_out.shape[0] < B
or g_fi_decode_out.dtype != dtype
):
Comment on lines 215 to 219
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The check for re-allocating the shared g_fi_decode_out buffer is incomplete. It only considers the batch size (B) and dtype, but not the number of heads (q.shape[2]) or self.kv_lora_rank. If different models or layers with different num_heads or kv_lora_rank are used in the same process, an incorrectly sized buffer could be reused, leading to shape mismatch errors or silent correctness issues inside the trtllm_batch_decode_with_kv_cache_mla kernel.

To prevent this, the check should be extended to also verify that the number of heads and the LoRA rank match the existing buffer's dimensions.

            if (
                g_fi_decode_out is None
                or g_fi_decode_out.shape[0] < B
                or g_fi_decode_out.shape[1] != q.shape[2]
                or g_fi_decode_out.shape[2] != self.kv_lora_rank
                or g_fi_decode_out.dtype != dtype
            ):

self._decode_out = torch.zeros(
g_fi_decode_out = torch.zeros(
B,
q.shape[2],
self.kv_lora_rank,
dtype=dtype,
device=q.device,
)
out_kwargs["out"] = self._decode_out[:B]
out_kwargs["out"] = g_fi_decode_out[:B]

o = trtllm_batch_decode_with_kv_cache_mla(
query=q,
Expand Down
Loading