@@ -166,7 +166,7 @@ def test_sp_for_qwen3_moe() -> None:
 @pytest.mark.parametrize("enforce_eager", [True, False])
 @pytest.mark.parametrize("model", QWEN_DENSE_MODELS)
 @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE": "1"})
-@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM": "1"})
+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
 def test_models_distributed_Qwen_Dense_with_flashcomm_v1(model, enforce_eager):
     example_prompts = [
         "Hello, my name is",
5 changes: 4 additions & 1 deletion tests/ut/attention/test_mla_v1.py
@@ -500,9 +500,12 @@ def test_forward_decode_without_graph(self,
         mock_up_proj.assert_called_once()
         mock_npu_fused_infer_attention_score.assert_called_once()

+    @patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
     @patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
-    def test_mla_preprocess(self, magic_npu_fetch):
+    def test_mla_preprocess(self, magic_npu_fetch,
+                            mock_maybe_all_gather_and_maybe_unpad):
         magic_npu_fetch.return_value = MagicMock()
+        mock_maybe_all_gather_and_maybe_unpad.side_effect = lambda x, label: x
         batch_size = 4
         seq_len = 8
         hidden_size = 1024
5 changes: 5 additions & 0 deletions tests/ut/models/test_deepseek_v2.py
@@ -42,9 +42,11 @@ def test_row_parallel_linear(cls, mock_distributed):
     assert output[0].shape == (2, 4, 64)


+@patch("vllm_ascend.models.layers.mla.get_forward_context")
 @patch("torch.ops.vllm.mla_forward")
 @patch("torch_npu.npu_rms_norm")
 def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_mla_forward,
+                                          mock_forward_context,
                                           mock_distributed, base_config):
     mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
     # Make a fake ascend config because of the AscendLinearBase
@@ -54,6 +56,9 @@ def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_mla_forward,
     vllm_config.parallel_config.tensor_parallel_size = 1
     vllm_config.kv_transfer_config = None
     ascend_config.init_ascend_config(vllm_config)
+    dummy_forward_context = MagicMock()
+    dummy_forward_context.sp_enabled = False
+    mock_forward_context.return_value = dummy_forward_context

     attn = CustomDeepseekV2MLAAttention(config=base_config,
                                         hidden_size=128,
23 changes: 17 additions & 6 deletions vllm_ascend/ascend_forward_context.py
@@ -11,7 +11,7 @@
                                   set_forward_context)

 import vllm_ascend.envs as envs_ascend
-from vllm_ascend.utils import enable_sp
+from vllm_ascend.utils import enable_sp, is_moe_model

 if TYPE_CHECKING:
     from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
@@ -112,15 +112,20 @@ def set_ascend_forward_context(
     # Currently, it is an empirical value. In normal scenarios, if the concurrency exceeds this threshold,
     # the performance benefits can be maximized. Conversely, if the concurrency is below the threshold,
     # the performance may degrade due to the switching of communication methods.
-    sp_enabled = enable_sp(vllm_config) and \
-                 tp_world_size > 1 and \
-                 num_tokens is not None and num_tokens > 1000
+    if is_moe_model(vllm_config):
+        sp_enabled = enable_sp(vllm_config) and \
+                     tp_world_size > 1
+    else:
+        sp_enabled = enable_sp(vllm_config) and \
+                     tp_world_size > 1 and \
+                     num_tokens is not None and num_tokens > 1000

     if sp_enabled:
         pad_size = (tp_world_size -
                     (num_tokens % tp_world_size)) % tp_world_size
         forward_context.pad_size = pad_size
     forward_context.sp_enabled = sp_enabled
+    forward_context.num_tokens = num_tokens

     # set this for rope forward_oot using
     forward_context.is_first_layer = True
@@ -169,8 +174,14 @@ def set_ascend_forward_context(

         dp_world_size = get_dp_group().world_size
         if dp_world_size > 1 and forward_context.dp_metadata is not None:
-            max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item(
-            )
+            max_tokens_across_dp = \
+                forward_context.dp_metadata.max_tokens_across_dp_cpu.item()
+            if sp_enabled:
+                padded_length = (max_tokens_across_dp + tp_world_size -
+                                 1) // tp_world_size * tp_world_size
+                pad_size = padded_length - num_tokens
+                forward_context.padded_length = padded_length
+                forward_context.pad_size = pad_size
         else:
             max_tokens_across_dp = num_tokens
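Reviewer note: the empirical 1000-token threshold above now gates SP only for dense models; for MoE models SP is enabled whenever TP > 1. Once SP is on, both padding branches reduce to simple round-up arithmetic. A standalone restatement follows (function names hypothetical, not part of the PR):

# Standalone restatement of the two padding formulas above (illustration only).
def sp_pad_no_dp(num_tokens: int, tp: int) -> int:
    # Pad this rank's token count up to the next multiple of the TP world size.
    return (tp - num_tokens % tp) % tp

def sp_pad_with_dp(num_tokens: int, max_tokens_across_dp: int, tp: int):
    # All DP ranks must agree on one padded length, so round the DP-wide
    # maximum up to a multiple of the TP world size first.
    padded_length = (max_tokens_across_dp + tp - 1) // tp * tp
    return padded_length, padded_length - num_tokens

assert sp_pad_no_dp(1000, 8) == 0                   # 1000 = 125 * 8, already aligned
assert sp_pad_with_dp(1000, 1030, 8) == (1032, 32)  # round 1030 up, pad 1000 to 1032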
27 changes: 11 additions & 16 deletions vllm_ascend/attention/mla_v1.py
@@ -9,7 +9,7 @@
                             AttentionMetadata,
                             MLAAttentionImpl)
 from vllm.config import VllmConfig, get_current_vllm_config
-from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
@@ -1128,10 +1128,11 @@ def _mla_preprocess(self, layer_name, hidden_states, kv_cache,
             q_c = hidden_states

         kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]
-        # Process for shared_expert_dp
-        if need_gather_q_kv:
-            q_c = get_tp_group().all_gather(q_c, 0)
-            kv_no_split = get_tp_group().all_gather(kv_no_split, 0)
+        # Process for Flash Comm V1
+        q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+            q_c, need_gather_q_kv)
+        kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+            kv_no_split, need_gather_q_kv)
         decode_preprocess_res = None
         prefill_preprocess_res = None
         if has_prefill:
@@ -1200,8 +1201,7 @@ def forward(
         num_decode_tokens = attn_metadata.num_decode_tokens
         # Inputs and outputs may be padded for CUDA graphs
         output_padded = output
-        output = output[:num_actual_tokens, ...]
-        o_proj_input_shape = (num_actual_tokens,
+        o_proj_input_shape = (get_forward_context().num_tokens,
                               self.num_heads * self.v_head_dim)
         o_proj_input = torch.empty(o_proj_input_shape,
                                    dtype=hidden_states.dtype,
@@ -1248,7 +1248,8 @@
                     o_proj_input[num_decode_tokens:] = output_prefill
                     current_ms_metadata.after_comm_event.record()
             else:
-                o_proj_input[num_decode_tokens:] = output_prefill
+                o_proj_input[
+                    num_decode_tokens:num_actual_tokens] = output_prefill
         # O proj
         current_ms_metadata = get_multistream_comm_context()
         MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
@@ -1258,20 +1259,14 @@
                                max_size=MAX_O_PROJ_PREFETCH_SIZE,
                                enabled=self.enable_prefetch)

-            output[...] = self.o_proj(
-                o_proj_input,
-                is_prefill=prefill_preprocess_res is not None,
-                is_force_scatter=self.enable_shared_expert_dp)[0]
+            output[...] = self.o_proj(o_proj_input)[0]
         else:
             with torch.npu.stream(current_ms_metadata.comm_stream):
                 maybe_npu_prefetch(inputs=self.o_proj.weight,
                                    dependency=o_proj_input,
                                    max_size=MAX_O_PROJ_PREFETCH_SIZE,
                                    enabled=self.enable_prefetch)
-                output[...] = self.o_proj(
-                    o_proj_input,
-                    is_prefill=prefill_preprocess_res is not None,
-                    is_force_scatter=self.enable_shared_expert_dp)[0]
+                output[...] = self.o_proj(o_proj_input)[0]
                 current_ms_metadata.after_comm_event.record()
         del o_proj_input
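Reviewer note: the implementation of torch.ops.vllm.maybe_all_gather_and_maybe_unpad is not part of this diff. From its call sites here and the unit-test stub (side_effect = lambda x, label: x), its apparent contract is: pass the tensor through when the flag is off, otherwise all-gather along the token dimension across the TP group and strip the sequence-parallel padding. The sketch below is an illustrative reading of that contract, not the op's actual code:

# Hedged sketch of the op's apparent behavior (illustration only).
import torch
from vllm.distributed import get_tp_group
from vllm.forward_context import get_forward_context

def maybe_all_gather_and_maybe_unpad_sketch(
        x: torch.Tensor, need_gather: bool) -> torch.Tensor:
    # Identity when the flag is off -- exactly what the test stubs in.
    if not need_gather:
        return x
    ctx = get_forward_context()
    x = get_tp_group().all_gather(x, 0)      # rebuild the full token dimension
    pad_size = getattr(ctx, "pad_size", 0)
    return x[:-pad_size] if pad_size > 0 else x  # drop trailing SP pad rows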
4 changes: 2 additions & 2 deletions vllm_ascend/envs.py
@@ -133,8 +133,8 @@
     lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", '0'))),
     # Whether to enable FlashComm optimization when tensor parallel is enabled.
     # This feature will get better performance when concurrency is large.
-    "VLLM_ASCEND_ENABLE_FLASHCOMM":
-    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))),
+    "VLLM_ASCEND_ENABLE_FLASHCOMM1":
+    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", '0'))),
     # Whether to enable MLP weight prefetch, only used in small concurrency.
     "VLLM_ASCEND_ENABLE_PREFETCH_MLP":
     lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))),
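Reviewer note: this rename is user-facing. Both the lookup key and the variable it reads change, so the old spelling is silently ignored after this PR. Enabling the feature now looks like the sketch below (set the variable before the engine starts):

# Enable FlashComm v1 under the name introduced by this PR.
import os
os.environ["VLLM_ASCEND_ENABLE_FLASHCOMM1"] = "1"
# The previous key is no longer read anywhere after this change:
os.environ.pop("VLLM_ASCEND_ENABLE_FLASHCOMM", None)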
11 changes: 5 additions & 6 deletions vllm_ascend/models/deepseek_v2.py
@@ -250,12 +250,11 @@ def __init__(
             bias=False,
             quant_config=quant_config,
             prefix=f"{prefix}.kv_b_proj")
-        self.o_proj = CustomDeepseekV2RowParallelLinear(
-            self.num_heads * self.v_head_dim,
-            self.hidden_size,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.o_proj")
+        self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
+                                        self.hidden_size,
+                                        bias=False,
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.o_proj")

         if rope_scaling:
             rope_scaling["rope_type"] = 'deepseek_yarn'
15 changes: 2 additions & 13 deletions vllm_ascend/models/layers/mla.py
@@ -120,19 +120,8 @@ def forward(
             hidden_states: torch.Tensor,
             kv_cache: Optional[torch.Tensor] = None,
             attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
-        num_tokens = hidden_states.shape[0]
-        need_gather_q_kv = False
-        if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
-            # Simulate all gather to calculate output shape
-            num_tokens = num_tokens * self.tp_size
-            need_gather_q_kv = True
-        if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace:
-            output_shape = hidden_states.shape
-        else:
-            rows = num_tokens // self.tp_size
-            if num_tokens % self.tp_size:
-                rows += 1
-            output_shape = (rows, hidden_states.shape[1])
+        need_gather_q_kv = get_forward_context().sp_enabled
+        output_shape = hidden_states.shape
         # FIXME: This does not seem right, should make sure the buffer is fixed
         output = torch.empty(output_shape,
                              dtype=hidden_states.dtype,
12 changes: 9 additions & 3 deletions vllm_ascend/ops/common_fused_moe.py
@@ -37,8 +37,9 @@
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.ops.moe.experts_selector import select_experts
 from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
-from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, is_310p, is_enable_nz,
-                               npu_stream_switch)
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, enable_sp, is_310p,
+                               is_enable_nz, npu_stream_switch,
+                               shared_expert_dp_enabled)


 class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
@@ -400,6 +401,10 @@ def __init__(
         self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert
         if self.multistream_overlap_shared_expert:
             self.shared_expert_stream = torch.npu.Stream()
+        if enable_sp():
+            logger.info_once(
+                "Sequence parallelism is enabled, shared experts are replicated for best performance."
+            )

     def forward(
         self,
@@ -427,7 +432,8 @@ def forward_impl(self, hidden_states: torch.Tensor,
         # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
         forward_context = get_forward_context()
         moe_comm_type = forward_context.moe_comm_type
-        if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}:
+        if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
+                and not shared_expert_dp_enabled():
            shared_out = tensor_model_parallel_all_reduce(shared_out)
         fused_output = AscendFusedMoE.forward_impl(
             self,
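Reviewer note: the new guard reflects the log message added above. With shared-expert DP the shared experts are replicated, so every rank already holds the complete shared-expert output and the TP all-reduce is skipped; the all-to-all and MC2 comm types are the ones that otherwise leave a partial result per rank. A condensed restatement (illustration only; strings stand in for the MoECommType members):

# Condensed restatement of the forward_impl guard (not the PR's code).
def needs_shared_out_all_reduce(moe_comm_type: str,
                                shared_expert_dp: bool) -> bool:
    partial_per_rank = moe_comm_type in {"alltoall", "mc2"}
    # Replicated shared experts already hold the full result on every rank.
    return partial_per_rank and not shared_expert_dp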
5 changes: 3 additions & 2 deletions vllm_ascend/ops/linear_op.py
@@ -49,7 +49,7 @@
                                  get_otp_group)
 from vllm_ascend.utils import (dense_optim_enable, enable_sp,
                                matmul_allreduce_enable, mlp_tp_enable,
-                               oproj_tp_enable)
+                               oproj_tp_enable, shared_expert_dp_enabled)


 class CustomLinearOp:
@@ -418,7 +418,8 @@ def _get_row_parallel_op(


 def get_parallel_op(disable_tp, prefix, layer, direct):
-    if disable_tp:
+    if disable_tp or ("shared_experts" in prefix
+                      and shared_expert_dp_enabled()):
         return None, 0, 1
     custom_op: Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp,
                               MLPRowParallelOp, OProjRowParallelOp,