Merged
43 changes: 29 additions & 14 deletions vllm_ascend/worker/model_runner_v1.py
@@ -53,11 +53,11 @@
from vllm.v1.kv_cache_interface import (
AttentionSpec,
EncoderOnlyAttentionSpec,
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
KVCacheSpec,
MambaSpec,
MLAAttentionSpec,
UniformTypeKVCacheSpecs,
)
from vllm.v1.outputs import (
@@ -2438,12 +2438,11 @@ def _allocate_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str
k_tensor_split_factor = 2
v_tensor_split_factor = 2
elif self.use_sparse:
# for deepseek v3.2, DSA use FullAttentionSpec
# FullAttentionSpec allocate 2 * mla page size bytes,
# and we use half of that for k cache in DSA
dsa_k_cache_factor = 2
k_tensor_split_factor = 2 * head_size / self.model_config.hf_text_config.kv_lora_rank
v_tensor_split_factor = 2 * head_size / self.model_config.hf_text_config.qk_rope_head_dim
# for deepseek v3.2, we split the kv cache according to the corresponding ratio
sparse_sum_head_size = sum(self._get_sparse_kv_cache_ratio())
k_tensor_split_factor, v_tensor_split_factor, dsa_k_cache_factor = [ # type: ignore
sparse_sum_head_size / ratio for ratio in self._get_sparse_kv_cache_ratio()
]
dsa_k_cache_size = int(kv_cache_tensor.size // dsa_k_cache_factor)
else:
# for other deepseek models, use MLAAttentionSpec
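The three split factors above are just the inverse shares of each component in the combined per-token footprint. A minimal standalone sketch of that arithmetic, assuming the usual DeepSeek V3.2 dimensions (kv_lora_rank=512, qk_rope_head_dim=64, index_head_dim=128) rather than reading them from hf_text_config:

```python
# Placeholder dims; the runner reads the real values from hf_text_config.
ratios = [512, 64, 128]        # kv_lora_rank, qk_rope_head_dim, index_head_dim
total = sum(ratios)            # 704 combined elements per token

k_factor, v_factor, dsa_k_factor = (total / r for r in ratios)

# Each sub-cache receives tensor_size // factor bytes, i.e. a share
# proportional to its per-token width.
tensor_size = 128 * total * 2                   # one 128-token block in bf16, for illustration
k_bytes = int(tensor_size // k_factor)          # 512/704 of the tensor
v_bytes = int(tensor_size // v_factor)          #  64/704 of the tensor
dsa_k_bytes = int(tensor_size // dsa_k_factor)  # 128/704 of the tensor
assert k_bytes + v_bytes + dsa_k_bytes == tensor_size
```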
@@ -2581,9 +2580,14 @@ def _reshape_kv_cache_tensors(
v_cache = raw_v_tensor.view(dtype).view(v_shape)

if self.use_sparse and raw_dsa_k_tensor is not None:
dsa_k_cache_shape = (num_blocks, kv_cache_spec.block_size, 1, 128)
dsa_k_cache_size = (num_blocks) * kv_cache_spec.block_size * 128 * dtype.itemsize
dsa_k_cache = raw_dsa_k_tensor[:dsa_k_cache_size].view(dtype).view(dsa_k_cache_shape)
index_head_dim = self._get_sparse_kv_cache_ratio()[-1]
dsa_k_cache_shape = (
num_blocks,
kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
index_head_dim,
)
dsa_k_cache = raw_dsa_k_tensor.view(dtype).view(dsa_k_cache_shape)
kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache)
else:
kv_caches[layer_name] = (k_cache, v_cache)
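For context, the reshape above just reinterprets the byte slice reserved for the DSA index-k cache as a (num_blocks, block_size, num_kv_heads, index_head_dim) tensor. A self-contained sketch with placeholder sizes, not the runner's real configuration:

```python
import torch

num_blocks, block_size, num_kv_heads, index_head_dim = 4, 128, 1, 128
dtype = torch.bfloat16

# Flat byte buffer standing in for the slice carved out of kv_cache_tensor.
raw_dsa_k_tensor = torch.empty(
    num_blocks * block_size * num_kv_heads * index_head_dim * dtype.itemsize,
    dtype=torch.uint8,
)
dsa_k_cache = raw_dsa_k_tensor.view(dtype).view(
    num_blocks, block_size, num_kv_heads, index_head_dim
)
assert dsa_k_cache.shape == (num_blocks, block_size, num_kv_heads, index_head_dim)
```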
@@ -2832,12 +2836,13 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
if self.use_sparse:
# TODO(cmq): This is a hack way to fix deepseek kvcache when
# using DSA. Fix the spec in vLLM is the final way.
block_size = self.vllm_config.cache_config.block_size
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
sparse_sum_head_size = sum(self._get_sparse_kv_cache_ratio())
kv_cache_spec[layer_name] = MLAAttentionSpec(
block_size=self.block_size,
num_kv_heads=1,
head_size=attn_module.head_size,
head_size=sparse_sum_head_size,
dtype=self.kv_cache_dtype,
cache_dtype_str=self.vllm_config.cache_config.cache_dtype,
)
elif spec := attn_module.get_kv_cache_spec(self.vllm_config):
kv_cache_spec[layer_name] = spec
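Declaring the layer as an MLAAttentionSpec whose head_size is the summed ratio means the block manager sizes each page for the combined footprint, which _allocate_kv_cache_tensors later carves back into the three sub-caches. A rough page-size check under the same assumed DeepSeek V3.2 dims:

```python
block_size = 128
num_kv_heads = 1
head_size = 512 + 64 + 128   # kv_lora_rank + qk_rope_head_dim + index_head_dim
dtype_bytes = 2              # bf16

page_size_bytes = block_size * num_kv_heads * head_size * dtype_bytes
print(page_size_bytes)       # 180224 bytes per block, split later by ratio
```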
@@ -2854,6 +2859,16 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:

return kv_cache_spec

def _get_sparse_kv_cache_ratio(self) -> list[int]:
# TODO: If C8 is supported, we need to consider the number of bytes occupied by different dtypes
# when calculating the ratio, for example:
# [kv_lora_rank * torch.int8.itemsize, qk_rope_head_dim * torch.bfloat16.itemsize, ...]
return [
self.model_config.hf_text_config.kv_lora_rank,
self.model_config.hf_text_config.qk_rope_head_dim,
self.model_config.hf_text_config.index_head_dim,
]
Contributor:

it's better to return a tuple instead of a list.
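If that suggestion were taken, the change would be small; a sketch assuming the same hf_text_config fields, not code from this PR:

```python
def _get_sparse_kv_cache_ratio(self) -> tuple[int, int, int]:
    # Same fields as the list version above, returned as an immutable tuple.
    cfg = self.model_config.hf_text_config
    return (cfg.kv_lora_rank, cfg.qk_rope_head_dim, cfg.index_head_dim)
```

The sum() and unpacking call sites would work unchanged with a tuple.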


def _check_and_update_cudagraph_mode(
self,
attention_backends: list[set[type[AttentionBackend]]],