diff --git a/docs/source/user_guide/configuration/additional_config.md b/docs/source/user_guide/configuration/additional_config.md index f8f398d6ac8..f62f8b18e2b 100644 --- a/docs/source/user_guide/configuration/additional_config.md +++ b/docs/source/user_guide/configuration/additional_config.md @@ -48,6 +48,7 @@ The following table lists additional configuration options available in vLLM Asc | `num_wait_worker_iterations` | int | `30` | The forward iterations when the EPLB worker will finish CPU tasks. In our test default value 30 can cover most cases. | | `expert_map_record_path` | str | `None` | Save the expert load calculation results to a new expert table in the specified directory. | | `init_redundancy_expert` | int | `0` | Specify redundant experts during initialization. | +| `enable_kv_nz` | bool | `False` | Whether to enable the KV cache NZ layout. This option only takes effect on models using MLA (e.g., DeepSeek). | The details of each configuration option are as follows: @@ -105,7 +106,8 @@ An example of additional configuration is as follows: "embedding_tensor_parallel_size": 8, "mlp_tensor_parallel_size": 8, }, + "enable_kv_nz": False, "multistream_overlap_shared_expert": True, - "refresh": False, + "refresh": False } ``` diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_mla_preprocess.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_mla_preprocess.py index 99b383ba2be..6ef9521331d 100644 --- a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_mla_preprocess.py +++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_mla_preprocess.py @@ -1,5 +1,6 @@ import gc +import pytest import torch import torch_npu @@ -8,8 +9,9 @@ enable_custom_op() +@pytest.mark.parametrize("cache_mode", ["krope_ctkv", "nzcache"]) @torch.inference_mode() -def test_mla_preprocess_kernel(): +def test_mla_preprocess_kernel(cache_mode: str): token_num = 1 head_num = 2 N_7168 = 7168 @@ -98,7 +100,7 @@ def test_mla_preprocess_kernel(): bias1=bias1, 
ctkv_scale=ctkv_scale, q_nope_scale=qnope_scale, - cache_mode="krope_ctkv", + cache_mode=cache_mode, quant_mode="per_tensor_quant_asymm", enable_inner_out=False, q_out0=q_nope_out, diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_mla_preprocess_nq.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_mla_preprocess_nq.py index b18c63f64f3..196ffafce3c 100644 --- a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_mla_preprocess_nq.py +++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_mla_preprocess_nq.py @@ -1,5 +1,6 @@ import gc +import pytest import torch import torch_npu @@ -8,8 +9,9 @@ enable_custom_op() +@pytest.mark.parametrize("cache_mode", ["krope_ctkv", "nzcache"]) @torch.inference_mode() -def test_mla_preprocess_kernel(): +def test_mla_preprocess_kernel(cache_mode: str): token_num = 1 head_num = 2 N_7168 = 7168 @@ -82,7 +84,7 @@ def test_mla_preprocess_kernel(): None, None, None, - cache_mode="krope_ctkv", + cache_mode=cache_mode, quant_mode="no_quant", enable_inner_out=False, q_out0=q_nope_out, diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_mla_preprocess_qdown.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_mla_preprocess_qdown.py index 9eb7e1caffb..0475361792b 100644 --- a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_mla_preprocess_qdown.py +++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_mla_preprocess_qdown.py @@ -1,5 +1,6 @@ import gc +import pytest import torch import torch_npu @@ -8,8 +9,9 @@ enable_custom_op() +@pytest.mark.parametrize("cache_mode", ["krope_ctkv", "nzcache"]) @torch.inference_mode() -def test_mla_preprocess_kernel(): +def test_mla_preprocess_kernel(cache_mode: str): token_num = 1 head_num = 2 N_7168 = 7168 @@ -99,7 +101,7 @@ def test_mla_preprocess_kernel(): bias1=bias1, ctkv_scale=ctkv_scale, q_nope_scale=qnope_scale, - cache_mode="krope_ctkv", + cache_mode=cache_mode, quant_mode="per_tensor_quant_asymm", enable_inner_out=True, 
q_out0=q_nope_out, diff --git a/tests/ut/test_ascend_config.py b/tests/ut/test_ascend_config.py index 1a337dea4f1..5cccc02797a 100644 --- a/tests/ut/test_ascend_config.py +++ b/tests/ut/test_ascend_config.py @@ -39,6 +39,7 @@ def test_init_ascend_config_without_additional_config(self): ascend_config = init_ascend_config(test_vllm_config) self.assertIsNone(ascend_config.expert_map_path) self.assertFalse(ascend_config.multistream_overlap_shared_expert) + self.assertFalse(ascend_config.enable_kv_nz) ascend_compilation_config = ascend_config.ascend_compilation_config self.assertTrue(ascend_compilation_config.fuse_norm_quant) @@ -53,6 +54,7 @@ def test_init_ascend_config_with_additional_config(self): "multistream_overlap_shared_expert": True, "expert_map_path": "test_expert_map_path", "refresh": True, + "enable_kv_nz": False } ascend_config = init_ascend_config(test_vllm_config) self.assertEqual(ascend_config.expert_map_path, "test_expert_map_path") @@ -61,6 +63,7 @@ def test_init_ascend_config_with_additional_config(self): ascend_compilation_config = ascend_config.ascend_compilation_config self.assertFalse(ascend_compilation_config.fuse_norm_quant) + self.assertFalse(ascend_config.enable_kv_nz) @_clean_up_ascend_config def test_init_ascend_config_enable_npugraph_ex(self): diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 8be434a18c9..fec3ade854c 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -13,18 +13,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import TYPE_CHECKING, Optional from vllm.logger import logger from vllm.triton_utils import HAS_TRITON +if TYPE_CHECKING: + from vllm.config import VllmConfig + class AscendConfig: """ Configuration Object for additional_config from vllm.configs. 
""" - def __init__(self, vllm_config): + def __init__(self, vllm_config: "VllmConfig"): additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {} xlite_graph_config = additional_config.get("xlite_graph_config", {}) @@ -121,6 +124,19 @@ def __init__(self, vllm_config): self.enable_async_exponential = bool( additional_config.get("enable_async_exponential", False)) + self.enable_kv_nz = additional_config.get("enable_kv_nz", False) + if self.enable_kv_nz: + use_sparse = hasattr(vllm_config.model_config.hf_config, + "index_topk") + if not vllm_config.model_config.is_deepseek_mla or use_sparse: + raise RuntimeError( + "enable_kv_nz is only supported for mla currently.") + if vllm_config.kv_transfer_config is None \ + or not vllm_config.kv_transfer_config.is_kv_consumer: + raise NotImplementedError( + "enable_kv_nz is only supported in pd scenario and can " + "only be used in D node.") + class FinegrainedTPConfig: """ diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 76f2e4102ac..5535660e3c3 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -745,6 +745,7 @@ def __init__( ascend_config = get_ascend_config() self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp self.enable_prefetch = ascend_config.weight_prefetch_config.enabled + self.enable_kv_nz = ascend_config.enable_kv_nz self.ring_mla_mask_size = 512 @@ -1073,7 +1074,7 @@ def exec_kv_decode( # npu_kv_rmsnorm_rope_cache needs [B, N, S, D] kv_no_split = kv_no_split.view( B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) - cache_mode = "PA" + cache_mode = "PA_NZ" if self.enable_kv_nz else "PA" k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache( kv_no_split, self.kv_a_layernorm.weight, @@ -1143,37 +1144,57 @@ def _forward_decode( # shape of knope/k_pe for npu graph mode should be: # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim] actual_seq_lengths = None 
- k_nope = k_nope.view(-1, self.num_kv_heads, block_size, - self.kv_lora_rank) - k_pe = k_pe.view(-1, self.num_kv_heads, block_size, - self.qk_rope_head_dim) + if self.enable_kv_nz: + nz_fmt_last_dim = 16 + k_nope = k_nope.view(-1, self.num_kv_heads, + self.kv_lora_rank // nz_fmt_last_dim, + block_size, nz_fmt_last_dim) + k_pe = k_pe.view(-1, self.num_kv_heads, + self.qk_rope_head_dim // nz_fmt_last_dim, + block_size, nz_fmt_last_dim) + else: + k_nope = k_nope.view(-1, self.num_kv_heads, block_size, + self.kv_lora_rank) + k_pe = k_pe.view(-1, self.num_kv_heads, block_size, + self.qk_rope_head_dim) + attn_output_shape: tuple | None = None if attn_metadata.attn_state in [ AscendAttentionState.SpecDecoding, AscendAttentionState.ChunkedPrefill, AscendAttentionState.DecodeOnly, ] and self.speculative_config is not None: - # Input shape: [num_tokens, num_heads, dim] - # Output shape: [num_heads, num_tokens, dim] # The right part layout indicates the layout of the attention # output. It is set to NTD to avoid the need for a transpose # operation after attention. input_layout = "TND_NTD" # TODO: If the driver is upgraded later, the contiguous function can be deleted. + # Input shape: [num_tokens, num_heads, dim] q_nope = q_nope.view(num_tokens, self.num_heads, -1).contiguous() q_pe = q_pe.view(num_tokens, self.num_heads, -1) + # Output shape: [num_heads, num_tokens, dim] + attn_output_shape = (self.num_heads, num_tokens, self.kv_lora_rank) sparse_mode = 3 spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore actual_seq_lengths = decode_meta.actual_seq_lengths_q else: - # Input shape: [num_reqs, num_heads, seq_len, dim] - # Output shape: [num_heads, num_reqs, seq_len, dim] # The output layout is set to NBSD to eliminate the need for a # transpose operation after attention. 
- input_layout = "BNSD_NBSD" - q_nope = q_nope.view(num_tokens, self.num_heads, 1, - -1).contiguous() - q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1) + if self.enable_kv_nz: + # Input shape: [num_tokens, seq_len, num_heads, dim] + input_layout = "BSND_NBSD" + q_nope = q_nope.view(num_tokens, 1, self.num_heads, + -1).contiguous() + q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1) + else: + # Input shape: [num_tokens, num_heads, seq_len, dim] + input_layout = "BNSD_NBSD" + q_nope = q_nope.view(num_tokens, self.num_heads, 1, + -1).contiguous() + q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1) + # Output shape: [num_heads, num_tokens, seq_len, dim] + attn_output_shape = (self.num_heads, num_tokens, 1, + self.kv_lora_rank) sparse_mode = 0 spec_attn_mask = None @@ -1215,10 +1236,9 @@ def _forward_decode( else: update_graph_params_workspaces(num_tokens, workspace) - attn_output = torch.empty( - (q_nope.shape[1], q_nope.shape[0], *q_nope.shape[2:]), - dtype=q_nope.dtype, - device=q_nope.device) + attn_output = torch.empty(attn_output_shape, + dtype=q_nope.dtype, + device=q_nope.device) softmax_lse = torch.empty(num_tokens, dtype=q_nope.dtype, device=q_nope.device) @@ -1297,7 +1317,7 @@ def _mla_preprocess_only_decode(self, hidden_states, kv_cache, bias1=self.qb_qt_bias, ctkv_scale=self.ctkv_scale, q_nope_scale=self.q_nope_scale, - cache_mode="krope_ctkv", + cache_mode="nzcache" if self.enable_kv_nz else "krope_ctkv", quant_mode="per_tensor_quant_asymm", q_out0=decode_q_nope, kv_cache_out0=decode_k_nope, diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py index 8f2696102ad..0b2ae66126d 100644 --- a/vllm_ascend/distributed/mooncake_connector.py +++ b/vllm_ascend/distributed/mooncake_connector.py @@ -382,8 +382,10 @@ def _handle_request(self, req_meta: dict[str, Any]): logger.debug( f"Finished transferring KV cache for request {request_id}.") except Exception as e: - logger.error("Failed to transfer KV cache 
for request " - f"{request_id}: {e}") + logger.error( + "Failed to transfer KV cache for request " + f"{request_id}: {e}", + exc_info=True) finally: # Always send the done signal to the remote host to ensure proper # resource cleanup. Failing to do so may cause a memory leak on the @@ -489,97 +491,116 @@ def _transfer_kv_cache(self, req_meta: dict[str, Any]): request_id, req_transfer_elapsed, num_transfer_groups, num_blocks, get_ip(), self.tp_rank, session_id) - # Determine if the current position is the offset position at the end of the KV transmission. + # Determine if the current position is the offset position at the end of + # the KV transmission. is_kv_transfer_end = ( global_offset == tp_num_need_pulls * self._prefill_pp_size - 1) need_cat_cache = tp_num_need_pulls > 1 and is_kv_transfer_end - # need_nz_cache maybe caused error in non-MLA models - if need_cat_cache: - self._cat_kv_cache(grouped_local_block_ids, tp_num_need_pulls) - - def _cat_kv_cache(self, block_ids: list[list[int]], - tp_num_need_pulls: int): + need_nz_cache = get_ascend_config().enable_kv_nz and is_kv_transfer_end + if need_nz_cache or need_cat_cache: + self.reformat_kv_cache(grouped_local_block_ids, tp_num_need_pulls, + need_cat_cache, need_nz_cache) + + def reformat_kv_cache(self, + block_ids: list[list[int]], + tp_num_need_pulls: int, + need_cat_cache: bool = False, + need_nz_cache: bool = False): # Get necessary parameters k_cache = list(self.kv_caches.values())[0][0] dtype = k_cache.dtype device = k_cache.device - head_dim = self.model_config.hf_text_config.head_dim - block_size = self.vllm_config.cache_config.block_size - num_kv_head = max( - self.model_config.hf_text_config.num_key_value_heads // - self.tp_size, 1) flat_block_ids = [item for sublist in block_ids for item in sublist] - block_ids_tensor = torch.tensor(flat_block_ids, dtype=torch.int32) + block_ids_tensor = torch.tensor(flat_block_ids, + dtype=torch.int32, + device=device) num_blocks = len(flat_block_ids) - block_len 
= num_blocks * block_size + num_tokens = num_blocks * self.block_size # Create device tensors for copy operations - block_table = block_ids_tensor.view(1, -1).to(device=device) - block_len_tensor = torch.tensor([block_len], - dtype=torch.int32).to(device=device) - seq_start_tensor = torch.tensor([0], - dtype=torch.int32).to(device=device) + block_table = block_ids_tensor.view(1, -1) + block_len_tensor = torch.tensor([num_tokens], + dtype=torch.int32, + device=device) + seq_start_tensor = torch.tensor([0], dtype=torch.int32, device=device) # Initialize buffers - k_buffer = torch.empty(block_len, - num_kv_head, - head_dim, - dtype=dtype, - device=device) - v_buffer = torch.empty(block_len, - num_kv_head, - head_dim, - dtype=dtype, - device=device) + k_buffer = torch.empty( + (num_tokens, self.num_kv_heads, self.k_head_dim), + dtype=dtype, + device=device) + v_buffer = torch.empty( + (num_tokens, self.num_kv_heads, self.v_head_dim), + dtype=dtype, + device=device) # Create slot mapping for reshape operations - block_offsets = torch.arange(0, block_size, dtype=torch.int32) + block_offsets = torch.arange(0, + self.block_size, + dtype=torch.int32, + device=device) slot_mapping = (block_offsets.reshape( - (1, block_size)) + block_ids_tensor.reshape( - (num_blocks, 1)) * block_size) - slot_mapping = slot_mapping.flatten().to(device=device) + (1, self.block_size)) + block_ids_tensor.reshape( + (num_blocks, 1)) * self.block_size).flatten() + + # FIXME: Right now, if we skip synchronization at this point, the system + # will crash in GQA scenarios. However, we still haven't identified the + # root cause. 
+ torch.npu.synchronize() # Process each layer in the KV cache for _, (k_cache_layer, v_cache_layer) in self.kv_caches.items(): # Load cache data into buffers - torch_npu.atb.npu_paged_cache_load( - k_cache_layer, - v_cache_layer, - block_table, - block_len_tensor, - seq_starts=seq_start_tensor, - key=k_buffer, - value=v_buffer, - ) - - # Transpose KV cache - k_buffer = self._transpose_kv_cache_between_head( - k_buffer, num_blocks, block_size, block_len, num_kv_head, - tp_num_need_pulls) - v_buffer = self._transpose_kv_cache_between_head( - v_buffer, num_blocks, block_size, block_len, num_kv_head, - tp_num_need_pulls) - - # Reshape and cache the processed buffers - torch_npu._npu_reshape_and_cache( - key=k_buffer, - value=v_buffer, - key_cache=k_cache_layer, - value_cache=v_cache_layer, - slot_indices=slot_mapping, - ) - + torch_npu.atb.npu_paged_cache_load(k_cache_layer, + v_cache_layer, + block_table, + block_len_tensor, + seq_starts=seq_start_tensor, + key=k_buffer, + value=v_buffer) + if need_cat_cache: + self._cat_kv_cache(k_cache_layer, v_cache_layer, k_buffer, + v_buffer, tp_num_need_pulls, num_blocks, + num_tokens, slot_mapping) + if need_nz_cache: + self._nz_kv_cache(k_cache_layer, v_cache_layer, k_buffer, + v_buffer, slot_mapping) # Clean up buffers del k_buffer, v_buffer - def _transpose_kv_cache_between_head( - self, buffer: torch.Tensor, num_blocks: int, block_size: int, - block_len: int, num_kv_head: int, - tp_num_need_pulls: int) -> torch.Tensor: - buffer = buffer.view(num_blocks, tp_num_need_pulls, block_size, -1) - buffer.transpose_(1, 2) - return buffer.contiguous().view(block_len, num_kv_head, -1) + def _cat_kv_cache(self, k_cache_layer, v_cache_layer, k_buffer, v_buffer, + tp_num_need_pulls, num_blocks, num_tokens, slot_mapping): + + def _transpose_kv_cache_between_head( + buffer: torch.Tensor) -> torch.Tensor: + buffer = buffer.view(num_blocks, tp_num_need_pulls, + self.block_size, -1) + buffer.transpose_(1, 2) + return 
buffer.contiguous().view(num_tokens, self.num_kv_heads, -1) + + # Transpose KV cache + k_buffer = _transpose_kv_cache_between_head(k_buffer) + v_buffer = _transpose_kv_cache_between_head(v_buffer) + + # Reshape and cache the processed buffers + torch_npu._npu_reshape_and_cache(key=k_buffer, + value=v_buffer, + key_cache=k_cache_layer, + value_cache=v_cache_layer, + slot_indices=slot_mapping) + + def _nz_kv_cache(self, k_cache_layer, v_cache_layer, k_buffer, v_buffer, + slot_mapping): + nz_fmt_last_dim = 16 + k_cache_layer = k_cache_layer.view( + -1, self.k_head_dim * self.num_kv_heads // nz_fmt_last_dim, + self.block_size, nz_fmt_last_dim) + v_cache_layer = v_cache_layer.view( + -1, self.v_head_dim * self.num_kv_heads // nz_fmt_last_dim, + self.block_size, nz_fmt_last_dim) + torch_npu.npu_scatter_pa_kv_cache(k_buffer, v_buffer, k_cache_layer, + v_cache_layer, slot_mapping) def _get_remote_metadata(self, remote_host: str, remote_handshake_port: int) -> None: