Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions vllm/config/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,6 +1003,7 @@ def set_splitting_ops_for_v1(
# https://github.com/vllm-project/vllm/issues/33267
if not self.use_inductor_graph_partition:
self.splitting_ops.append("vllm::unified_kv_cache_update")
self.splitting_ops.append("vllm::unified_mla_kv_cache_update")

elif len(self.splitting_ops) == 0:
if (
Expand Down
69 changes: 58 additions & 11 deletions vllm/model_executor/layers/attention/mla_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,17 @@ def forward(
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]

# Write the latent and rope to kv cache
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is missing from the indirect call path below. A new custom op needs to be created (like unified_kv_cache_update for GQA - see attention.py), and then that should be called before calling the MLA attention op

if self_kv_cache.numel() > 0:
self.impl.do_kv_cache_update(
kv_c_normed,
k_pe,
self_kv_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
self._k_scale,
)

if self.attn_backend.accept_output_buffer:
output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
self.forward_impl(
Expand All @@ -451,6 +462,11 @@ def forward(
q, kv_c_normed, k_pe, self_kv_cache, attn_metadata
)
else:
kv_cache_dummy_dep = torch.ops.vllm.unified_mla_kv_cache_update(
kv_c_normed,
k_pe,
self.layer_name,
)
if self.attn_backend.accept_output_buffer:
output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
torch.ops.vllm.unified_mla_attention_with_output(
Expand All @@ -459,6 +475,7 @@ def forward(
k_pe,
output,
self.layer_name,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
return output
else:
Expand All @@ -467,6 +484,7 @@ def forward(
kv_c_normed,
k_pe,
self.layer_name,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)

def forward_impl(
Expand Down Expand Up @@ -520,17 +538,6 @@ def forward_impl(
k_c_normed = k_c_normed[:num_actual_toks, ...]
k_pe = k_pe[:num_actual_toks, ...]

# write the latent and rope to kv cache
if kv_cache.numel() > 0:
ops.concat_and_cache_mla(
k_c_normed,
k_pe.squeeze(1),
kv_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache_dtype=self.kv_cache_dtype,
scale=self._k_scale,
)

if fp8_attention and self.kv_cache_dtype != "fp8_ds_mla":
kv_cache = kv_cache.view(current_platform.fp8_dtype())

Expand Down Expand Up @@ -821,13 +828,49 @@ def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
out.copy_(out_new) # Copy result


def unified_mla_kv_cache_update(
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Write the MLA latent (`kv_c_normed`) and rope (`k_pe`) entries into
    the named layer's KV cache.

    Registered as the eager implementation of the
    ``vllm::unified_mla_kv_cache_update`` custom op.  Returns a zero-element
    dummy tensor that the MLA attention ops accept as ``kv_cache_dummy_dep``,
    creating a data dependency so the cache update is ordered before the
    attention call when the graph is partitioned/compiled.

    The update is skipped when there is no slot mapping for this layer or the
    cache has no storage (e.g. profiling runs).
    """
    _, layer, kv_cache, layer_slot_mapping = get_attention_context(layer_name)
    if layer_slot_mapping is not None and kv_cache.numel() > 0:
        layer.impl.do_kv_cache_update(
            kv_c_normed,
            k_pe,
            kv_cache,
            layer_slot_mapping,
            layer.kv_cache_dtype,
            layer._k_scale,
        )
    # The dummy tensor's metadata must match unified_mla_kv_cache_update_fake,
    # which bases its result on kv_c_normed.  Using kv_cache here would diverge
    # whenever the cache dtype differs from the activation dtype (e.g. fp8 KV
    # caches), making the fake/meta trace disagree with eager execution.
    return torch.empty(0, device=kv_c_normed.device, dtype=kv_c_normed.dtype)


def unified_mla_kv_cache_update_fake(
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake (meta) implementation of ``vllm::unified_mla_kv_cache_update``.

    Produces a zero-element tensor with the same device/dtype as the latent
    input so compile-time tracing sees the same output metadata as the eager
    op, without touching any cache state.
    """
    del k_pe, layer_name  # unused: the fake performs no cache write
    return kv_c_normed.new_empty(0)


# Register the eager/fake pair above as the torch custom op
# `vllm::unified_mla_kv_cache_update` so it survives torch.compile tracing.
# NOTE(review): the eager op mutates the layer's KV cache, but the cache is
# obtained via get_attention_context rather than passed as an argument, so it
# cannot appear in mutates_args; the zero-element return value presumably
# exists to thread an explicit dependency edge into the attention ops instead
# — confirm against the GQA unified_kv_cache_update registration.
direct_register_custom_op(
    op_name="unified_mla_kv_cache_update",
    op_func=unified_mla_kv_cache_update,
    fake_impl=unified_mla_kv_cache_update_fake,
    mutates_args=[],
)


@maybe_transfer_kv_layer
def unified_mla_attention(
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
layer_name: str,
kv_cache_dummy_dep: torch.Tensor | None = None,
) -> torch.Tensor:
del kv_cache_dummy_dep
attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name)
output = layer.forward_impl(q, kv_c_normed, k_pe, kv_cache, attn_metadata)

Expand All @@ -839,6 +882,7 @@ def unified_mla_attention_fake(
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
layer_name: str,
kv_cache_dummy_dep: torch.Tensor | None = None,
) -> torch.Tensor:
return torch.empty_like(q).contiguous()

Expand All @@ -861,7 +905,9 @@ def unified_mla_attention_with_output(
layer_name: str,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
kv_cache_dummy_dep: torch.Tensor | None = None,
) -> None:
del kv_cache_dummy_dep
attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name)
layer.forward_impl(
q,
Expand All @@ -883,6 +929,7 @@ def unified_mla_attention_with_output_fake(
layer_name: str,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
kv_cache_dummy_dep: torch.Tensor | None = None,
) -> None:
return

Expand Down
46 changes: 46 additions & 0 deletions vllm/v1/attention/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,29 @@ def forward_mqa(
"""MQA-style decode forward pass."""
raise NotImplementedError

def do_kv_cache_update(
    self,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: torch.Tensor,
) -> None:
    """Concatenate the MLA latent and rope components and store them in
    ``kv_cache`` at the positions given by ``slot_mapping``.

    Does nothing when the cache has no allocated storage (e.g. during
    memory profiling).
    """
    if kv_cache.numel() > 0:
        from vllm import _custom_ops as ops

        ops.concat_and_cache_mla(
            kv_c_normed,
            k_pe.squeeze(1),
            kv_cache,
            slot_mapping.flatten(),
            kv_cache_dtype=kv_cache_dtype,
            scale=k_scale,
        )


class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]):
"""Sparse MLA attention implementation with only forward_mqa method.
Expand Down Expand Up @@ -856,6 +879,29 @@ def forward_mqa(
"""MQA-style decode forward pass."""
raise NotImplementedError

def do_kv_cache_update(
    self,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: torch.Tensor,
) -> None:
    """Write the compressed latent plus rope keys into the KV cache.

    ``slot_mapping`` selects the destination slots; ``k_scale`` is the
    quantization scale used when the cache dtype requires it.
    """
    if kv_cache.numel() == 0:
        # Cache storage not allocated yet (e.g. profiling) — nothing to do.
        return

    from vllm import _custom_ops as ops

    rope_keys = k_pe.squeeze(1)
    flat_slots = slot_mapping.flatten()
    ops.concat_and_cache_mla(
        kv_c_normed,
        rope_keys,
        kv_cache,
        flat_slots,
        kv_cache_dtype=kv_cache_dtype,
        scale=k_scale,
    )


def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    """Return True when the KV-cache dtype string names an fp8 variant."""
    fp8_prefix = "fp8"
    return kv_cache_dtype[: len(fp8_prefix)] == fp8_prefix
Expand Down