diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 1ffefd0bfcc..db701ed2189 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -53,11 +53,11 @@
 from vllm.v1.kv_cache_interface import (
     AttentionSpec,
     EncoderOnlyAttentionSpec,
-    FullAttentionSpec,
     KVCacheConfig,
     KVCacheGroupSpec,
     KVCacheSpec,
     MambaSpec,
+    MLAAttentionSpec,
     UniformTypeKVCacheSpecs,
 )
 from vllm.v1.outputs import (
@@ -2438,12 +2438,11 @@ def _allocate_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str
                 k_tensor_split_factor = 2
                 v_tensor_split_factor = 2
             elif self.use_sparse:
-                # for deepseek v3.2, DSA use FullAttentionSpec
-                # FullAttentionSpec allocate 2 * mla page size bytes,
-                # and we use half of that for k cache in DSA
-                dsa_k_cache_factor = 2
-                k_tensor_split_factor = 2 * head_size / self.model_config.hf_text_config.kv_lora_rank
-                v_tensor_split_factor = 2 * head_size / self.model_config.hf_text_config.qk_rope_head_dim
+                # for deepseek v3.2, split the kv cache in proportion to each component's head size
+                sparse_sum_head_size = sum(self._get_sparse_kv_cache_ratio())
+                k_tensor_split_factor, v_tensor_split_factor, dsa_k_cache_factor = [  # type: ignore
+                    sparse_sum_head_size / ratio for ratio in self._get_sparse_kv_cache_ratio()
+                ]
                 dsa_k_cache_size = int(kv_cache_tensor.size // dsa_k_cache_factor)
             else:
                 # for other deepseek models, use MLAAttentionSpec
@@ -2581,9 +2580,14 @@ def _reshape_kv_cache_tensors(
                 v_cache = raw_v_tensor.view(dtype).view(v_shape)
 
                 if self.use_sparse and raw_dsa_k_tensor is not None:
-                    dsa_k_cache_shape = (num_blocks, kv_cache_spec.block_size, 1, 128)
-                    dsa_k_cache_size = (num_blocks) * kv_cache_spec.block_size * 128 * dtype.itemsize
-                    dsa_k_cache = raw_dsa_k_tensor[:dsa_k_cache_size].view(dtype).view(dsa_k_cache_shape)
+                    index_head_dim = self._get_sparse_kv_cache_ratio()[-1]
+                    dsa_k_cache_shape = (
+                        num_blocks,
+                        kv_cache_spec.block_size,
+                        kv_cache_spec.num_kv_heads,
+                        index_head_dim,
+                    )
+                    dsa_k_cache = raw_dsa_k_tensor.view(dtype).view(dsa_k_cache_shape)
                     kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache)
                 else:
                     kv_caches[layer_name] = (k_cache, v_cache)
@@ -2832,12 +2836,13 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
             if self.use_sparse:
                 # TODO(cmq): This is a hack way to fix deepseek kvcache when
                 # using DSA. Fix the spec in vLLM is the final way.
-                block_size = self.vllm_config.cache_config.block_size
-                kv_cache_spec[layer_name] = FullAttentionSpec(
-                    block_size=block_size,
+                sparse_sum_head_size = sum(self._get_sparse_kv_cache_ratio())
+                kv_cache_spec[layer_name] = MLAAttentionSpec(
+                    block_size=self.block_size,
                     num_kv_heads=1,
-                    head_size=attn_module.head_size,
+                    head_size=sparse_sum_head_size,
                     dtype=self.kv_cache_dtype,
+                    cache_dtype_str=self.vllm_config.cache_config.cache_dtype,
                 )
             elif spec := attn_module.get_kv_cache_spec(self.vllm_config):
                 kv_cache_spec[layer_name] = spec
@@ -2854,6 +2859,16 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
 
         return kv_cache_spec
 
+    def _get_sparse_kv_cache_ratio(self) -> list[int]:
+        # TODO: if C8 is supported, we need to consider the number of bytes occupied
+        # by different dtypes when calculating the ratio, for example:
+        # [kv_lora_rank * torch.int8.itemsize, qk_rope_head_dim * torch.bfloat16.itemsize, ...]
+        return [
+            self.model_config.hf_text_config.kv_lora_rank,
+            self.model_config.hf_text_config.qk_rope_head_dim,
+            self.model_config.hf_text_config.index_head_dim,
+        ]
+
     def _check_and_update_cudagraph_mode(
         self,
         attention_backends: list[set[type[AttentionBackend]]],
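For context on the new split logic, a minimal standalone sketch of how the factors derived from `_get_sparse_kv_cache_ratio()` partition a single KV cache tensor. The DeepSeek V3.2 dimensions below (`kv_lora_rank=512`, `qk_rope_head_dim=64`, `index_head_dim=128`) and the block size are illustrative assumptions, not values read from a real config:

```python
# Sketch of the per-tensor split used in this PR (assumed DeepSeek V3.2 dims).
kv_lora_rank, qk_rope_head_dim, index_head_dim = 512, 64, 128  # assumptions
block_size, bf16_bytes = 128, 2                                # assumptions

ratios = [kv_lora_rank, qk_rope_head_dim, index_head_dim]
sparse_sum_head_size = sum(ratios)  # 704, used as MLAAttentionSpec.head_size

# Same computation as the diff: each factor is sum / component head size.
k_factor, v_factor, dsa_k_factor = [sparse_sum_head_size / r for r in ratios]
# k_factor = 1.375, v_factor = 11.0, dsa_k_factor = 5.5

# One hypothetical per-layer tensor sized for a single block of 704-wide entries.
tensor_size = sparse_sum_head_size * bf16_bytes * block_size

k_bytes = int(tensor_size // k_factor)          # 512/704 of the tensor -> latent (NoPE) cache
v_bytes = int(tensor_size // v_factor)          # 64/704 of the tensor  -> RoPE cache
dsa_k_bytes = int(tensor_size // dsa_k_factor)  # 128/704 of the tensor -> DSA indexer k cache

# The three slices partition the tensor exactly.
assert k_bytes + v_bytes + dsa_k_bytes == tensor_size
```

Because the three slices sum back to the full tensor, setting the `MLAAttentionSpec` `head_size` to `sum(self._get_sparse_kv_cache_ratio())` makes the allocator reserve exactly the bytes that `_allocate_kv_cache_tensors` later splits into the k, v, and DSA k caches.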