Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/source/user_guide/configuration/additional_config.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ The following table lists additional configuration options available in vLLM Asc
| `num_wait_worker_iterations` | int | `30` | The forward iterations when the EPLB worker will finish CPU tasks. In our test default value 30 can cover most cases. |
| `expert_map_record_path` | str | `None` | Save the expert load calculation results to a new expert table in the specified directory. |
| `init_redundancy_expert` | int | `0` | Specify redundant experts during initialization. |
| `enable_kv_nz` | bool | `False` | Whether to enable kvcache NZ layout. This option only takes effect on models using MLA (e.g., DeepSeek). |

The details of each configuration option are as follows:

Expand Down Expand Up @@ -105,7 +106,8 @@ An example of additional configuration is as follows:
"embedding_tensor_parallel_size": 8,
"mlp_tensor_parallel_size": 8,
},
"enable_kv_nz": False,
"multistream_overlap_shared_expert": True,
"refresh": False
}
```
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import gc

import pytest
import torch
import torch_npu

Expand All @@ -8,8 +9,9 @@
enable_custom_op()


@pytest.mark.parametrize("cache_mode", ["krope_ctkv", "nzcache"])
@torch.inference_mode()
def test_mla_preprocess_kernel():
def test_mla_preprocess_kernel(cache_mode: str):
token_num = 1
head_num = 2
N_7168 = 7168
Expand Down Expand Up @@ -98,7 +100,7 @@ def test_mla_preprocess_kernel():
bias1=bias1,
ctkv_scale=ctkv_scale,
q_nope_scale=qnope_scale,
cache_mode="krope_ctkv",
cache_mode=cache_mode,
quant_mode="per_tensor_quant_asymm",
enable_inner_out=False,
q_out0=q_nope_out,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import gc

import pytest
import torch
import torch_npu

Expand All @@ -8,8 +9,9 @@
enable_custom_op()


@pytest.mark.parametrize("cache_mode", ["krope_ctkv", "nzcache"])
@torch.inference_mode()
def test_mla_preprocess_kernel():
def test_mla_preprocess_kernel(cache_mode: str):
token_num = 1
head_num = 2
N_7168 = 7168
Expand Down Expand Up @@ -82,7 +84,7 @@ def test_mla_preprocess_kernel():
None,
None,
None,
cache_mode="krope_ctkv",
cache_mode=cache_mode,
quant_mode="no_quant",
enable_inner_out=False,
q_out0=q_nope_out,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import gc

import pytest
import torch
import torch_npu

Expand All @@ -8,8 +9,9 @@
enable_custom_op()


@pytest.mark.parametrize("cache_mode", ["krope_ctkv", "nzcache"])
@torch.inference_mode()
def test_mla_preprocess_kernel():
def test_mla_preprocess_kernel(cache_mode: str):
token_num = 1
head_num = 2
N_7168 = 7168
Expand Down Expand Up @@ -99,7 +101,7 @@ def test_mla_preprocess_kernel():
bias1=bias1,
ctkv_scale=ctkv_scale,
q_nope_scale=qnope_scale,
cache_mode="krope_ctkv",
cache_mode=cache_mode,
quant_mode="per_tensor_quant_asymm",
enable_inner_out=True,
q_out0=q_nope_out,
Expand Down
3 changes: 3 additions & 0 deletions tests/ut/test_ascend_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def test_init_ascend_config_without_additional_config(self):
ascend_config = init_ascend_config(test_vllm_config)
self.assertIsNone(ascend_config.expert_map_path)
self.assertFalse(ascend_config.multistream_overlap_shared_expert)
self.assertFalse(ascend_config.enable_kv_nz)

ascend_compilation_config = ascend_config.ascend_compilation_config
self.assertTrue(ascend_compilation_config.fuse_norm_quant)
Expand All @@ -53,6 +54,7 @@ def test_init_ascend_config_with_additional_config(self):
"multistream_overlap_shared_expert": True,
"expert_map_path": "test_expert_map_path",
"refresh": True,
"enable_kv_nz": False
}
ascend_config = init_ascend_config(test_vllm_config)
self.assertEqual(ascend_config.expert_map_path, "test_expert_map_path")
Expand All @@ -61,6 +63,7 @@ def test_init_ascend_config_with_additional_config(self):

ascend_compilation_config = ascend_config.ascend_compilation_config
self.assertFalse(ascend_compilation_config.fuse_norm_quant)
self.assertFalse(ascend_config.enable_kv_nz)

@_clean_up_ascend_config
def test_init_ascend_config_enable_npugraph_ex(self):
Expand Down
20 changes: 18 additions & 2 deletions vllm_ascend/ascend_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import TYPE_CHECKING, Optional

from vllm.logger import logger
from vllm.triton_utils import HAS_TRITON

if TYPE_CHECKING:
from vllm.config import VllmConfig


class AscendConfig:
"""
Configuration Object for additional_config from vllm.configs.
"""

def __init__(self, vllm_config):
def __init__(self, vllm_config: "VllmConfig"):
additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}

xlite_graph_config = additional_config.get("xlite_graph_config", {})
Expand Down Expand Up @@ -121,6 +124,19 @@ def __init__(self, vllm_config):
self.enable_async_exponential = bool(
additional_config.get("enable_async_exponential", False))

self.enable_kv_nz = additional_config.get("enable_kv_nz", False)
if self.enable_kv_nz:
use_sparse = hasattr(vllm_config.model_config.hf_config,
"index_topk")
if not vllm_config.model_config.is_deepseek_mla or use_sparse:
raise RuntimeError(
"enable_kv_nz is only supported for mla currently.")
if vllm_config.kv_transfer_config is None \
or not vllm_config.kv_transfer_config.is_kv_consumer:
raise NotImplementedError(
"enable_kv_nz is only supported in pd scenario and can "
"only be used in D node.")


class FinegrainedTPConfig:
"""
Expand Down
56 changes: 38 additions & 18 deletions vllm_ascend/attention/mla_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,6 +745,7 @@ def __init__(
ascend_config = get_ascend_config()
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
self.enable_kv_nz = ascend_config.enable_kv_nz

self.ring_mla_mask_size = 512

Expand Down Expand Up @@ -1073,7 +1074,7 @@ def exec_kv_decode(
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
kv_no_split = kv_no_split.view(
B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
cache_mode = "PA"
cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
kv_no_split,
self.kv_a_layernorm.weight,
Expand Down Expand Up @@ -1143,37 +1144,57 @@ def _forward_decode(
# shape of knope/k_pe for npu graph mode should be:
# [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
actual_seq_lengths = None
k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
self.kv_lora_rank)
k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
self.qk_rope_head_dim)
if self.enable_kv_nz:
nz_fmt_last_dim = 16
k_nope = k_nope.view(-1, self.num_kv_heads,
self.kv_lora_rank // nz_fmt_last_dim,
block_size, nz_fmt_last_dim)
k_pe = k_pe.view(-1, self.num_kv_heads,
self.qk_rope_head_dim // nz_fmt_last_dim,
block_size, nz_fmt_last_dim)
else:
k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
self.kv_lora_rank)
k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
self.qk_rope_head_dim)

attn_output_shape: tuple | None = None
if attn_metadata.attn_state in [
AscendAttentionState.SpecDecoding,
AscendAttentionState.ChunkedPrefill,
AscendAttentionState.DecodeOnly,
] and self.speculative_config is not None:
# Input shape: [num_tokens, num_heads, dim]
# Output shape: [num_heads, num_tokens, dim]
# The right part layout indicates the layout of the attention
# output. It is set to NTD to avoid the need for a transpose
# operation after attention.
input_layout = "TND_NTD"
# TODO: If the driver is upgraded later, the contiguous function can be deleted.
# Input shape: [num_tokens, num_heads, dim]
q_nope = q_nope.view(num_tokens, self.num_heads, -1).contiguous()
q_pe = q_pe.view(num_tokens, self.num_heads, -1)
# Output shape: [num_heads, num_tokens, dim]
attn_output_shape = (self.num_heads, num_tokens, self.kv_lora_rank)
sparse_mode = 3
spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore
actual_seq_lengths = decode_meta.actual_seq_lengths_q
else:
# Input shape: [num_reqs, num_heads, seq_len, dim]
# Output shape: [num_heads, num_reqs, seq_len, dim]
# The output layout is set to NBSD to eliminate the need for a
# transpose operation after attention.
input_layout = "BNSD_NBSD"
q_nope = q_nope.view(num_tokens, self.num_heads, 1,
-1).contiguous()
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
if self.enable_kv_nz:
# Input shape: [num_tokens, seq_len, num_heads, dim]
input_layout = "BSND_NBSD"
q_nope = q_nope.view(num_tokens, 1, self.num_heads,
-1).contiguous()
q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
else:
# Input shape: [num_tokens, num_heads, seq_len, dim]
input_layout = "BNSD_NBSD"
q_nope = q_nope.view(num_tokens, self.num_heads, 1,
-1).contiguous()
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
# Output shape: [num_heads, num_tokens, seq_len, dim]
attn_output_shape = (self.num_heads, num_tokens, 1,
self.kv_lora_rank)
sparse_mode = 0
spec_attn_mask = None

Expand Down Expand Up @@ -1215,10 +1236,9 @@ def _forward_decode(
else:
update_graph_params_workspaces(num_tokens, workspace)

attn_output = torch.empty(
(q_nope.shape[1], q_nope.shape[0], *q_nope.shape[2:]),
dtype=q_nope.dtype,
device=q_nope.device)
attn_output = torch.empty(attn_output_shape,
dtype=q_nope.dtype,
device=q_nope.device)
softmax_lse = torch.empty(num_tokens,
dtype=q_nope.dtype,
device=q_nope.device)
Expand Down Expand Up @@ -1297,7 +1317,7 @@ def _mla_preprocess_only_decode(self, hidden_states, kv_cache,
bias1=self.qb_qt_bias,
ctkv_scale=self.ctkv_scale,
q_nope_scale=self.q_nope_scale,
cache_mode="krope_ctkv",
cache_mode="nzcache" if self.enable_kv_nz else "krope_ctkv",
quant_mode="per_tensor_quant_asymm",
q_out0=decode_q_nope,
kv_cache_out0=decode_k_nope,
Expand Down
Loading
Loading