Merged
1 change: 1 addition & 0 deletions .github/workflows/vllm_ascend_test.yaml
@@ -357,6 +357,7 @@ jobs:
           pytest -sv tests/e2e/multicard/test_ilama_lora_tp2.py
           # Fixme: run VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py will raise error.
           # To avoid oom, we need to run the test in a single process.
+          pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_multistream_moe
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_topk
26 changes: 26 additions & 0 deletions tests/e2e/multicard/test_offline_inference_distributed.py
@@ -47,6 +47,32 @@ def test_models_distributed_QwQ():
         vllm_model.generate_greedy(example_prompts, max_tokens)
 
 
+def test_models_distributed_DeepSeek_multistream_moe():
+    example_prompts = [
+        "Hello, my name is",
+    ]
+    dtype = "half"
+    max_tokens = 5
+    with VllmRunner(
+            "vllm-ascend/DeepSeek-V3-Pruning",
+            dtype=dtype,
+            tensor_parallel_size=2,
+            distributed_executor_backend="mp",
+            additional_config={
+                "torchair_graph_config": {
+                    "enabled": True,
+                    "enable_multistream_moe": True,
+                },
+                "ascend_scheduler_config": {
+                    "enabled": True,
+                },
+                "refresh": True,
+            },
+            enforce_eager=False,
+    ) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)
+
+
 def test_models_distributed_DeepSeek():
     example_prompts = [
         "Hello, my name is",
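For reference, the configuration exercised by the new test maps directly onto a standalone offline-inference run. The sketch below mirrors the test's settings; it assumes an installed vllm + vllm-ascend where the `additional_config` engine argument is forwarded to the Ascend plugin (as `VllmRunner` relies on above), and drops the test-only `"refresh": True` key:

```python
# Minimal offline-inference sketch mirroring the new test's configuration.
# Assumes vllm and vllm-ascend are installed and two Ascend cards are visible;
# `additional_config` is forwarded to the plugin via vLLM's engine arguments.
from vllm import LLM, SamplingParams

llm = LLM(
    model="vllm-ascend/DeepSeek-V3-Pruning",
    dtype="half",
    tensor_parallel_size=2,
    distributed_executor_backend="mp",
    enforce_eager=False,
    additional_config={
        "torchair_graph_config": {
            "enabled": True,
            "enable_multistream_moe": True,
        },
        "ascend_scheduler_config": {"enabled": True},
    },
)

# Greedy decoding for 5 tokens, matching generate_greedy(max_tokens=5) above.
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=5))
print(outputs[0].outputs[0].text)
```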
8 changes: 5 additions & 3 deletions vllm_ascend/attention/mla_v1.py
@@ -9,6 +9,7 @@
                                               MLAAttentionImpl)
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import get_current_vllm_config
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.utils import cdiv, round_down
@@ -557,6 +558,7 @@
         self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
         self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+        self.tp_size = get_tensor_model_parallel_world_size()
[Codecov: added line #L561 was not covered by tests]
 
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
@@ -586,7 +588,7 @@
         x = torch.bmm(x, self.W_UV)
         # Convert from (N, B, V) to (B, N * V)
         x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
-        return self.o_proj(x)[0]
+        return self.o_proj(x, is_prefill=False)[0]
[Codecov: added line #L591 was not covered by tests]
 
     # Return `ql_nope`, `q_pe`
     def _q_proj_and_k_up_proj(self, x):
@@ -847,12 +849,12 @@
 
         current_ms_metadata = get_multistream_comm_context()
         if current_ms_metadata is None:
-            return self.o_proj(attn_output)[0]
+            return self.o_proj(attn_output, is_prefill=True)[0]
[Codecov: added line #L852 was not covered by tests]
         else:
             current_ms_metadata.before_comm_event.record()
             with torch.npu.stream(current_ms_metadata.comm_stream):
                 current_ms_metadata.before_comm_event.wait()
-                return self.o_proj(attn_output)[0]
+                return self.o_proj(attn_output, is_prefill=True)[0]
[Codecov: added line #L857 was not covered by tests]
 
     def exec_kv(
         self,
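The prefill path above routes `o_proj` through a dedicated communication stream whenever a multistream context is active. A stripped-down sketch of that event/stream handoff pattern follows; the helper name is hypothetical, and the `torch.npu` streams and events come from the torch_npu Ascend adapter (mirroring the `torch.cuda` API):

```python
# Illustrative sketch of the compute/communication overlap used above.
# The real metadata object lives in vllm_ascend.multistream.context; the
# names here are hypothetical.
import torch
import torch_npu  # noqa: F401  # registers the torch.npu backend

comm_stream = torch.npu.Stream()
before_comm_event = torch.npu.Event()


def overlapped_o_proj(o_proj, attn_output):
    # Mark the point on the current (compute) stream where attn_output
    # becomes ready.
    before_comm_event.record()
    with torch.npu.stream(comm_stream):
        # The comm stream waits only on that event, so later compute kernels
        # can keep running while o_proj (and its TP all-reduce) execute here.
        before_comm_event.wait()
        return o_proj(attn_output, is_prefill=True)[0]
```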
17 changes: 9 additions & 8 deletions vllm_ascend/models/deepseek_dbo.py
@@ -42,8 +42,7 @@
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               ReplicatedLinear,
-                                               RowParallelLinear)
+                                               ReplicatedLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -64,7 +63,8 @@
 
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.models.deepseek_v2 import CustomDeepseekV2MLP
+from vllm_ascend.models.deepseek_v2 import (CustomDeepseekV2MLP,
+                                            CustomDeepseekV2RowParallelLinear)
 from vllm_ascend.multistream.base import MSEventKey
 from vllm_ascend.multistream.context import (
     advance_step_multistream_layer_context, get_multistream_comm_context,
@@ -325,11 +325,12 @@
             bias=False,
             quant_config=quant_config,
             prefix=f"{prefix}.kv_b_proj")
-        self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
-                                        self.hidden_size,
-                                        bias=False,
-                                        quant_config=quant_config,
-                                        prefix=f"{prefix}.o_proj")
+        self.o_proj = CustomDeepseekV2RowParallelLinear(
+            self.num_heads * self.v_head_dim,
+            self.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.o_proj")
[Codecov: added line #L328 was not covered by tests]
 
         if rope_scaling:
             rope_scaling["rope_type"] = 'deepseek_yarn'
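Both files now depend on `CustomDeepseekV2RowParallelLinear` accepting an `is_prefill` keyword (see the `o_proj(..., is_prefill=...)` calls in `mla_v1.py`). Below is a rough sketch of what such a subclass could look like; the actual class in `vllm_ascend/models/deepseek_v2.py` may differ in both signature and behavior:

```python
# Hypothetical sketch only; the real CustomDeepseekV2RowParallelLinear in
# vllm_ascend.models.deepseek_v2 may differ.
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.model_executor.layers.linear import RowParallelLinear


class SketchRowParallelLinear(RowParallelLinear):
    """RowParallelLinear whose forward accepts an `is_prefill` flag, letting
    callers such as the MLA attention paths tell the layer which phase it is
    serving so it can pick a phase-appropriate communication path."""

    def forward(self, input_, is_prefill=True):
        # Local shard matmul; bias is applied after the reduction, if at all.
        output_parallel = self.quant_method.apply(self, input_, bias=None)
        if self.reduce_results and self.tp_size > 1:
            # A real implementation might select different collectives (or
            # defer the reduction to a comm stream) based on `is_prefill`.
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
```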