1 change: 1 addition & 0 deletions tests/ut/attention/test_mla_v1.py
@@ -302,6 +302,7 @@ def setUp(self, ascend_config, get_current_vllm_config, mock_get_tp_size,
"v_head_dim": 128,
"rotary_emb": MagicMock(),
"q_proj": MagicMock(),
"q_b_proj": MagicMock(),
"kv_b_proj": MagicMock(),
"o_proj": MagicMock(),
"kv_a_proj_with_mqa": MagicMock(),
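A minimal, hypothetical sketch (not the repository's test code) of why the extra mock is needed: an MLA implementation that reads `q_proj` when `q_lora_rank` is `None` and `q_b_proj` otherwise will look up the `q_b_proj` key, so the kwargs dict has to provide both mocks. Key names follow the diff; the concrete values are assumptions.

```python
from unittest.mock import MagicMock

# Illustrative kwargs only; anything beyond the keys shown in the diff is assumed.
kwargs = {
    "q_lora_rank": 16,        # set -> the low-rank (q_a/q_b) path is taken
    "q_proj": MagicMock(),    # used when q_lora_rank is None
    "q_b_proj": MagicMock(),  # used when q_lora_rank is not None
}

# Selection mirroring the updated mla_v1.py constructor in this diff.
q_proj = kwargs["q_proj"] if kwargs["q_lora_rank"] is None else kwargs["q_b_proj"]
```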
16 changes: 1 addition & 15 deletions tests/ut/models/conftest.py
@@ -90,25 +90,11 @@ def mock_distributed():
mock_vllm_config.scheduler_config = Mock(max_num_seqs=256)
mock_vllm_config.model_config = Mock(max_model_len=2048, quant_config=None)

with patch("vllm_ascend.models.deepseek_v2.get_tensor_model_parallel_rank", return_value=0), \
patch("vllm_ascend.models.deepseek_v2.get_tensor_model_parallel_world_size", return_value=1), \
patch("vllm_ascend.models.deepseek_v2.get_tp_group", return_value=tp_group), \
patch("vllm_ascend.models.deepseek_v2.get_pp_group", return_value=pp_group), \
patch("vllm_ascend.models.deepseek_v2.get_pp_group",
return_value=Mock(is_first_rank=False, is_last_rank=False)), \
patch("vllm_ascend.ops.common_fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
with patch("vllm_ascend.ops.common_fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
patch("vllm_ascend.ops.moe.token_dispatcher.torch.distributed.get_rank", return_value=0), \
patch("vllm_ascend.ops.moe.token_dispatcher.get_ascend_soc_version", return_value=None), \
patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group,
_PP=pp_group), \
patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group), \
patch("torch.npu.current_device", return_value=0):
yield


@pytest.fixture
def mock_forward_context():
forward_context = Mock(in_profile_run=False, with_prefill=False)
with patch("vllm_ascend.models.deepseek_v2.get_forward_context",
return_value=forward_context):
yield
2 changes: 0 additions & 2 deletions tests/ut/models/test_deepseek_mtp.py
@@ -37,8 +37,6 @@ def setup_mtp_layer(self, mocker: MockerFixture, vllm_config: VllmConfig,
mocker.patch(
"vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
return_value=None)
mocker.patch("vllm_ascend.models.deepseek_v2.get_ascend_config",
return_value=mocker.Mock())

mtp_layer = CustomDeepSeekMultiTokenPredictorLayer(config, "0", None)
mocker_deepseek_v2_decode_layer.assert_called_once()
130 changes: 0 additions & 130 deletions tests/ut/models/test_deepseek_v2.py

This file was deleted.

1 change: 1 addition & 0 deletions tests/ut/torchair/test_torchair_mla.py
@@ -525,6 +525,7 @@ def setUp(self, ascend_config, vllm_config, mock_get_tp_size, mock_tp):
"v_head_dim": 128,
"rotary_emb": MagicMock(),
"q_proj": MagicMock(),
"q_b_proj": MagicMock(),
"kv_b_proj": MagicMock(),
"o_proj": MagicMock(),
"kv_a_proj_with_mqa": MagicMock(),
72 changes: 46 additions & 26 deletions vllm_ascend/attention/mla_v1.py
@@ -536,12 +536,13 @@ def __init__(
self.qk_head_dim = kwargs['qk_head_dim']
self.v_head_dim = kwargs['v_head_dim']
self.rotary_emb = kwargs['rotary_emb']
self.q_proj = kwargs['q_proj']
self.fused_qkv_a_proj = kwargs.get('fused_qkv_a_proj', None)
self.q_proj = kwargs['q_proj'] if self.q_lora_rank is None else kwargs[
'q_b_proj']
self.kv_b_proj = kwargs['kv_b_proj']
self.o_proj = kwargs['o_proj']
self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
self.q_a_proj = kwargs.get('q_a_proj', None)
self.q_a_layernorm = kwargs.get('q_a_layernorm', None)
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.tp_size = get_tensor_model_parallel_world_size()
@@ -648,36 +649,46 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
self._process_weights_for_fused_mlapo(act_dtype)

def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
kv_a_proj_wt = self.kv_a_proj_with_mqa.weight.data
kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[
..., self.q_lora_rank:].contiguous()
q_a_proj_wt = self.fused_qkv_a_proj.weight.data[
..., :self.q_lora_rank].contiguous()
kv_a_proj_wt = kv_a_proj_wt.contiguous()
kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim)
kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
wd_qkv = torch.cat((kv_a_proj_wt, self.q_a_proj.weight.data), dim=-1)
kv_a_proj_wt = kv_a_proj_wt.contiguous()
wd_qkv = torch.cat((kv_a_proj_wt, q_a_proj_wt), dim=-1)
wd_qkv = wd_qkv.t().contiguous()
wd_qkv = transdata(wd_qkv,
block_size=(16, 32)).unsqueeze(0).contiguous()
self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29)

kv_a_proj_deq_scl = self.kv_a_proj_with_mqa.deq_scale
kv_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[
self.q_lora_rank:].contiguous()
q_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[:self.
q_lora_rank].contiguous(
)
kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape(
self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
kv_a_proj_deq_scl = trans_rope_weight(kv_a_proj_deq_scl,
self.qk_rope_head_dim)
kv_a_proj_deq_scl = kv_a_proj_deq_scl.view(
self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
self.deq_scale_qkv = torch.cat(
(kv_a_proj_deq_scl, self.q_a_proj.deq_scale), dim=-1).contiguous()

kv_a_proj_qt_bias = self.kv_a_proj_with_mqa.quant_bias
self.deq_scale_qkv = torch.cat((kv_a_proj_deq_scl, q_a_proj_deq_scl),
dim=-1).contiguous()

kv_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[
self.q_lora_rank:].contiguous()
q_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[:self.
q_lora_rank].contiguous(
)
kv_a_proj_qt_bias = kv_a_proj_qt_bias.reshape(
self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
kv_a_proj_qt_bias = trans_rope_weight(kv_a_proj_qt_bias,
self.qk_rope_head_dim)
kv_a_proj_qt_bias = kv_a_proj_qt_bias.view(
self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
self.quant_bias_qkv = torch.cat(
(kv_a_proj_qt_bias, self.q_a_proj.quant_bias),
dim=-1).contiguous()
self.quant_bias_qkv = torch.cat((kv_a_proj_qt_bias, q_a_proj_qt_bias),
dim=-1).contiguous()

wu_q = self.q_proj.weight.data
wu_q = wu_q.t().reshape(self.num_heads,
@@ -704,22 +715,22 @@ def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
self.qb_qt_bias = qb_qt_bias.reshape(
self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))

device = self.q_a_proj.weight.device
device = self.q_proj.weight.device
self.gamma0 = torch.ones(
[self.q_a_proj.weight.shape[-1]],
[self.fused_qkv_a_proj.weight.shape[-1]],
dtype=act_dtype,
device=device,
)
self.beta0 = torch.zeros(
[self.q_a_proj.weight.shape[-1]],
[self.fused_qkv_a_proj.weight.shape[-1]],
dtype=act_dtype,
device=device,
)
self.gamma1 = self.q_a_layernorm.weight.data
self.beta1 = self.q_a_layernorm.bias.data
self.gamma2 = self.kv_a_layernorm.weight.data
self.quant_scale0 = self.q_a_proj.input_scale.data
self.quant_offset0 = self.q_a_proj.input_offset.data
self.quant_scale0 = self.fused_qkv_a_proj.input_scale.data
self.quant_offset0 = self.fused_qkv_a_proj.input_offset.data
self.quant_scale1 = self.q_proj.input_scale.data
self.quant_offset1 = self.q_proj.input_offset.data
self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device)
@@ -1122,21 +1133,26 @@ def _mla_preprocess(self, layer_name, hidden_states, kv_cache,
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens
num_actual_tokens = attn_metadata.num_actual_tokens
if self.q_a_proj is not None:
maybe_npu_prefetch(inputs=self.q_a_proj.weight,
if self.fused_qkv_a_proj is not None:
maybe_npu_prefetch(inputs=self.fused_qkv_a_proj.weight,
dependency=hidden_states,
enabled=self.enable_prefetch)
ckq = self.q_a_proj(hidden_states)[0]
q_c = self.q_a_layernorm(ckq)
qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
q_c, kv_no_split = qkv_lora.split(
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
dim=-1,
)
q_c = self.q_a_layernorm(q_c)
else:
q_c = hidden_states
kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]

kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]
# Process for Flash Comm V1
q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
q_c, need_gather_q_kv)
kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
kv_no_split, need_gather_q_kv)

decode_preprocess_res = None
prefill_preprocess_res = None
if has_prefill:
@@ -1264,14 +1280,18 @@ def forward(
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch)

output[...] = self.o_proj(o_proj_input)[0]
output[...] = self.o_proj(o_proj_input,
is_prefill=prefill_preprocess_res
is not None)[0]
else:
with torch.npu.stream(current_ms_metadata.comm_stream):
maybe_npu_prefetch(inputs=self.o_proj.weight,
dependency=o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch)
output[...] = self.o_proj(o_proj_input)[0]
output[...] = self.o_proj(o_proj_input,
is_prefill=prefill_preprocess_res
is not None)[0]
current_ms_metadata.after_comm_event.record()
del o_proj_input

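A minimal sketch of the fused low-rank projection path this file now uses, with toy dimensions (the real sizes come from the model config): a single `fused_qkv_a_proj`-style matmul produces both the query latent and the joint KV latent plus rope part, and the output is split by `q_lora_rank`, as in the new `_mla_preprocess`. Everything other than the split sizes taken from the diff is an assumption.

```python
import torch

# Toy dimensions for illustration only; real values come from the DeepSeek config.
q_lora_rank, kv_lora_rank, qk_rope_head_dim, hidden_size = 8, 16, 4, 32
num_tokens = 2

# Stand-in for fused_qkv_a_proj: one projection covering q_a and kv_a outputs.
fused_qkv_a_proj = torch.nn.Linear(
    hidden_size, q_lora_rank + kv_lora_rank + qk_rope_head_dim, bias=False)

hidden_states = torch.randn(num_tokens, hidden_size)
qkv_lora = fused_qkv_a_proj(hidden_states)

# Same split the new _mla_preprocess applies to the fused projection output.
q_c, kv_no_split = qkv_lora.split(
    [q_lora_rank, kv_lora_rank + qk_rope_head_dim], dim=-1)

assert q_c.shape == (num_tokens, q_lora_rank)
assert kv_no_split.shape == (num_tokens, kv_lora_rank + qk_rope_head_dim)
```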
8 changes: 0 additions & 8 deletions vllm_ascend/models/__init__.py
@@ -29,14 +29,6 @@ def register_model():
"vllm_ascend.models.qwen2_5_vl_without_padding:AscendQwen2_5_VLForConditionalGeneration_Without_Padding"
)

ModelRegistry.register_model(
"DeepseekV2ForCausalLM",
"vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM")

ModelRegistry.register_model(
"DeepseekV3ForCausalLM",
"vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM")

ModelRegistry.register_model(
"DeepseekV32ForCausalLM",
"vllm_ascend.models.deepseek_v3_2:CustomDeepseekV3ForCausalLM")