Merged
vllm_ascend/models/__init__.py: 2 changes (1 addition, 1 deletion)
@@ -39,7 +39,7 @@ def register_model():

    ModelRegistry.register_model(
        "DeepseekV32ForCausalLM",
-       "vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM")
+       "vllm_ascend.models.deepseek_v3_2:CustomDeepseekV3ForCausalLM")
Collaborator:
deepseek_v2 is not supported now?

    ModelRegistry.register_model(
        "DeepSeekMTPModel",
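For context, ModelRegistry.register_model() accepts a lazy "module:ClassName" string target, so the architecture name declared by a checkpoint's config decides which vllm-ascend module is imported at load time. Below is a minimal sketch of the out-of-tree registration after this change; only the DeepseekV32ForCausalLM entry is taken from the diff, and the lazy-import comment describes assumed vLLM behavior rather than code in this PR.

from vllm import ModelRegistry

def register_model():
    # After this PR, the V3.2 architecture string resolves to the new
    # dedicated module instead of vllm_ascend.models.deepseek_v2.
    ModelRegistry.register_model(
        "DeepseekV32ForCausalLM",
        "vllm_ascend.models.deepseek_v3_2:CustomDeepseekV3ForCausalLM")
    # Assumed behavior: the string target is resolved lazily, roughly
    #   importlib.import_module("vllm_ascend.models.deepseek_v3_2")
    #   getattr(module, "CustomDeepseekV3ForCausalLM")
    # only when a model whose config lists "DeepseekV32ForCausalLM" is loaded.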
vllm_ascend/models/deepseek_v2.py: 195 changes (3 additions, 192 deletions)
@@ -62,12 +62,9 @@
PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
from vllm.model_executor.utils import set_weight_attrs
- from vllm.platforms import current_platform

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.models.layers.mla import AscendMLAModules
- from vllm_ascend.models.layers.sfa import (AscendSFAModules,
-                                            AscendSparseFlashAttention, Indexer)
from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
from vllm_ascend.ops.linear import AscendLinearBase

@@ -84,16 +81,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.config = config

self.vocab_size = config.vocab_size
- self.is_v32 = hasattr(config, "index_topk")
- if self.is_v32:
-     topk_tokens = config.index_topk
-     topk_indices_buffer = torch.empty(
-         vllm_config.scheduler_config.max_num_batched_tokens,
-         topk_tokens,
-         dtype=torch.int32,
-         device=current_platform.device_type)
- else:
-     topk_indices_buffer = None
+ topk_indices_buffer = None

if get_pp_group().is_first_rank:
self.embed_tokens = VocabParallelEmbedding(
@@ -332,7 +320,7 @@ def __init__(
o_proj=self.o_proj,
rotary_emb=self.rotary_emb,
indexer=None,
- is_sparse=hasattr(config, "index_topk"),
+ is_sparse=False,
)

self.mla_attn = MultiHeadLatentAttention(
@@ -365,180 +353,6 @@ def forward(
return self.mla_attn(positions, hidden_states, kv_cache, attn_metadata)


class CustomDeepseekV2SFAAttention(DeepseekV2MLAAttention):

def __init__(
self,
config: PretrainedConfig,
hidden_size: int,
num_heads: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: Optional[int],
kv_lora_rank: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
nn.Module.__init__(self)
self.hidden_size = hidden_size
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
self.v_head_dim = v_head_dim

self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank

self.num_heads = num_heads
self.tp_size = get_tensor_model_parallel_world_size()
assert num_heads % self.tp_size == 0
self.num_local_heads = num_heads // self.tp_size
self.layers = config.num_hidden_layers
self.first_k_dense_replace = config.first_k_dense_replace

self.scaling = self.qk_head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings

self.prefix = prefix
self.debug_layer_idx = int(self.prefix.split(".")[-2])

ascend_config = get_ascend_config()
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp

if self.q_lora_rank is not None:
self.q_a_proj = ReplicatedLinear(
self.hidden_size,
self.q_lora_rank,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_a_proj",
return_bias=False,
)
self.q_a_layernorm = RMSNorm(self.q_lora_rank,
eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear(
q_lora_rank,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_b_proj",
return_bias=False,
)
else:
self.q_proj = ColumnParallelLinear(
self.hidden_size,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_proj",
return_bias=False,
)

self.kv_a_proj_with_mqa = ReplicatedLinear(
self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_a_proj_with_mqa",
return_bias=False,
)
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_b_proj",
return_bias=False,
)
self.o_proj = CustomDeepseekV2RowParallelLinear(
self.num_heads * self.v_head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
return_bias=False,
)

if rope_scaling:
rope_scaling["rope_type"] = 'deepseek_yarn'
self.rotary_emb = get_rope(qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=False)
if rope_scaling:
mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
scaling_factor = rope_scaling["factor"]
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale

self.dim: int = config.hidden_size # 7168
# TODO(zzzzwwjj): wait transformers add these params
self.n_heads: int = 64 # 64
self.head_dim: int = 128 # 128
self.index_topk: int = 2048 # 2048
self.indexer = Indexer(
config,
quant_config=quant_config,
dim=self.dim,
n_heads=self.n_heads,
head_dim=self.head_dim,
index_topk=self.index_topk,
prefix=f"{prefix}.indexer",
)

sfa_modules = AscendSFAModules(
q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None,
q_a_layernorm=self.q_a_layernorm
if self.q_lora_rank is not None else None,
q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
kv_a_layernorm=self.kv_a_layernorm,
kv_b_proj=self.kv_b_proj,
o_proj=self.o_proj,
rotary_emb=self.rotary_emb,
indexer=self.indexer)

self.sfa_attn = AscendSparseFlashAttention(
self.hidden_size,
self.enable_shared_expert_dp,
self.debug_layer_idx,
self.first_k_dense_replace,
self.tp_size,
sfa_modules,
self.num_local_heads,
self.scaling,
self.layers,
self.kv_lora_rank,
self.qk_rope_head_dim,
self.q_lora_rank,
self.qk_nope_head_dim,
self.qk_head_dim,
self.v_head_dim,
cache_config,
quant_config,
prefix,
)
self.prefix = prefix

def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
return self.sfa_attn(positions, hidden_states, kv_cache, attn_metadata)


class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):

def __init__(self,
@@ -566,10 +380,7 @@ def __init__(self,
self.tp_rank = get_tp_group().rank_in_group
# TODO: enable mla in vllm-ascend
if model_config.use_mla:
- if hasattr(model_config.hf_config, "index_topk"):
-     attn_cls = CustomDeepseekV2SFAAttention
- else:
-     attn_cls = CustomDeepseekV2MLAAttention
+ attn_cls = CustomDeepseekV2MLAAttention
else:
attn_cls = DeepseekV2Attention
self.self_attn = attn_cls(
vllm_ascend/models/deepseek_v3.py: empty file removed