diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index d22e9a96e0f3..ef4ac2f118c8 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -1003,6 +1003,7 @@ def set_splitting_ops_for_v1( # https://github.com/vllm-project/vllm/issues/33267 if not self.use_inductor_graph_partition: self.splitting_ops.append("vllm::unified_kv_cache_update") + self.splitting_ops.append("vllm::unified_mla_kv_cache_update") elif len(self.splitting_ops) == 0: if ( diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index d444e20dad9e..75af6f3deabd 100644 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -435,6 +435,17 @@ def forward( attn_metadata = attn_metadata[self.layer_name] self_kv_cache = self.kv_cache[forward_context.virtual_engine] + # Write the latent and rope to kv cache + if self_kv_cache.numel() > 0: + self.impl.do_kv_cache_update( + kv_c_normed, + k_pe, + self_kv_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + self._k_scale, + ) + if self.attn_backend.accept_output_buffer: output = torch.empty(output_shape, dtype=q.dtype, device=q.device) self.forward_impl( @@ -451,6 +462,11 @@ def forward( q, kv_c_normed, k_pe, self_kv_cache, attn_metadata ) else: + kv_cache_dummy_dep = torch.ops.vllm.unified_mla_kv_cache_update( + kv_c_normed, + k_pe, + self.layer_name, + ) if self.attn_backend.accept_output_buffer: output = torch.empty(output_shape, dtype=q.dtype, device=q.device) torch.ops.vllm.unified_mla_attention_with_output( @@ -459,6 +475,7 @@ def forward( k_pe, output, self.layer_name, + kv_cache_dummy_dep=kv_cache_dummy_dep, ) return output else: @@ -467,6 +484,7 @@ def forward( kv_c_normed, k_pe, self.layer_name, + kv_cache_dummy_dep=kv_cache_dummy_dep, ) def forward_impl( @@ -520,17 +538,6 @@ def forward_impl( k_c_normed = k_c_normed[:num_actual_toks, ...] k_pe = k_pe[:num_actual_toks, ...] - # write the latent and rope to kv cache - if kv_cache.numel() > 0: - ops.concat_and_cache_mla( - k_c_normed, - k_pe.squeeze(1), - kv_cache, - attn_metadata.slot_mapping.flatten(), - kv_cache_dtype=self.kv_cache_dtype, - scale=self._k_scale, - ) - if fp8_attention and self.kv_cache_dtype != "fp8_ds_mla": kv_cache = kv_cache.view(current_platform.fp8_dtype()) @@ -821,13 +828,49 @@ def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor): out.copy_(out_new) # Copy result +def unified_mla_kv_cache_update( + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + layer_name: str, +) -> torch.Tensor: + _, layer, kv_cache, layer_slot_mapping = get_attention_context(layer_name) + if layer_slot_mapping is not None and kv_cache.numel() > 0: + layer.impl.do_kv_cache_update( + kv_c_normed, + k_pe, + kv_cache, + layer_slot_mapping, + layer.kv_cache_dtype, + layer._k_scale, + ) + return torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype) + + +def unified_mla_kv_cache_update_fake( + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + layer_name: str, +) -> torch.Tensor: + return torch.empty(0, device=kv_c_normed.device, dtype=kv_c_normed.dtype) + + +direct_register_custom_op( + op_name="unified_mla_kv_cache_update", + op_func=unified_mla_kv_cache_update, + fake_impl=unified_mla_kv_cache_update_fake, + mutates_args=[], +) + + @maybe_transfer_kv_layer def unified_mla_attention( q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor, layer_name: str, + kv_cache_dummy_dep: torch.Tensor | None = None, ) -> torch.Tensor: + del kv_cache_dummy_dep attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name) output = layer.forward_impl(q, kv_c_normed, k_pe, kv_cache, attn_metadata) @@ -839,6 +882,7 @@ def unified_mla_attention_fake( kv_c_normed: torch.Tensor, k_pe: torch.Tensor, layer_name: str, + kv_cache_dummy_dep: torch.Tensor | None = None, ) -> torch.Tensor: return torch.empty_like(q).contiguous() @@ -861,7 +905,9 @@ def unified_mla_attention_with_output( layer_name: str, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, + kv_cache_dummy_dep: torch.Tensor | None = None, ) -> None: + del kv_cache_dummy_dep attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name) layer.forward_impl( q, @@ -883,6 +929,7 @@ def unified_mla_attention_with_output_fake( layer_name: str, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, + kv_cache_dummy_dep: torch.Tensor | None = None, ) -> None: return diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py index 43fa5991112a..88111d6913b2 100644 --- a/vllm/v1/attention/backend.py +++ b/vllm/v1/attention/backend.py @@ -811,6 +811,29 @@ def forward_mqa( """MQA-style decode forward pass.""" raise NotImplementedError + def do_kv_cache_update( + self, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + ) -> None: + if kv_cache.numel() == 0: + return + + from vllm import _custom_ops as ops + + ops.concat_and_cache_mla( + kv_c_normed, + k_pe.squeeze(1), + kv_cache, + slot_mapping.flatten(), + kv_cache_dtype=kv_cache_dtype, + scale=k_scale, + ) + class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]): """Sparse MLA attention implementation with only forward_mqa method. @@ -856,6 +879,29 @@ def forward_mqa( """MQA-style decode forward pass.""" raise NotImplementedError + def do_kv_cache_update( + self, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + ) -> None: + if kv_cache.numel() == 0: + return + + from vllm import _custom_ops as ops + + ops.concat_and_cache_mla( + kv_c_normed, + k_pe.squeeze(1), + kv_cache, + slot_mapping.flatten(), + kv_cache_dtype=kv_cache_dtype, + scale=k_scale, + ) + def is_quantized_kv_cache(kv_cache_dtype: str) -> bool: return kv_cache_dtype.startswith("fp8")