Skip to content
1 change: 1 addition & 0 deletions tests/singlecard/test_ascend_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,5 +114,6 @@ def test_ascend_config_load_error():
},
}
with VllmRunner("facebook/opt-125m",
enforce_eager=False,
additional_config=input_additional_config_fake_2):
pass
4 changes: 3 additions & 1 deletion vllm_ascend/ascend_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def __init__(self, torchair_graph_config):
"graph_batch_sizes", [])
self.graph_batch_sizes_init = torchair_graph_config.get(
"graph_batch_sizes_init", False)
self.enable_multistream_shared_expert = torchair_graph_config.get(
"enable_multistream_shared_expert", False)

if not isinstance(self.graph_batch_sizes, list):
raise TypeError("graph_batch_sizes must be list[int]")
Expand Down Expand Up @@ -105,7 +107,7 @@ def check_ascend_config(vllm_config, enforce_eager):
ascend_config = get_ascend_config()

# Both for V0 and V1 Engine, torchair_graph cannot be enabled with eager mode.
if ascend_config.torchair_graph_config.enabled and not enforce_eager:
if ascend_config.torchair_graph_config.enabled and enforce_eager:
raise RuntimeError(
"Can't enable graph mode and eager mode at the same time. Please set `enforce_eager=False` if you attempt to enable NPU graph mode."
)
Expand Down
21 changes: 19 additions & 2 deletions vllm_ascend/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,8 @@ def __init__(

ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_multistream_shared_expert = \
ascend_config.torchair_graph_config.enable_multistream_shared_expert

def forward(
self,
Expand All @@ -238,6 +240,8 @@ def forward(

num_tokens, hidden_size = hidden_states.shape

multistream = self.enable_multistream_shared_expert and not is_prefill

old_hidden_states = hidden_states.clone()

if self.tp_size > 1:
Expand All @@ -259,13 +263,25 @@ def forward(
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)

kwargs = {}
if multistream:
kwargs.update({
"shared_experts": self.shared_experts,
"shared_hidden_states": old_hidden_states
})

hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
is_prefill=is_prefill,
top_k=CustomDeepseekV2MoE.top_k,
enable_force_load_balance=enable_force_load_balance,
) * self.routed_scaling_factor
**kwargs)

if multistream:
hidden_states, shared_output = hidden_states

hidden_states = hidden_states * self.routed_scaling_factor

if self.tp_size > 1:
if self.torchair_graph_enabled:
Expand All @@ -288,7 +304,8 @@ def forward(
hidden_states = hidden_states[:-num_padding_tokens]

if self.n_shared_experts is not None:
shared_output = self.shared_experts(old_hidden_states)
if not multistream:
shared_output = self.shared_experts(old_hidden_states)

if shared_output is not None:
hidden_states = hidden_states + shared_output
Expand Down
47 changes: 28 additions & 19 deletions vllm_ascend/ops/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,18 @@
USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM


def fused_experts_with_mc2(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
expert_map: torch.Tensor = None,
moe_all_to_all_group_name: Optional[str] = None,
) -> torch.Tensor:
def fused_experts_with_mc2(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
expert_map: torch.Tensor = None,
moe_all_to_all_group_name: Optional[str] = None,
**kwargs) -> torch.Tensor:
global_bs = 0
moe_expert_num = len(expert_map)
kwargs = {
kwargs_mc2 = {
"x": hidden_states,
"expert_ids": topk_ids,
"expert_shard_type": 0,
Expand Down Expand Up @@ -81,9 +80,9 @@ def fused_experts_with_mc2(
"tp_world_size": tp_size,
"tp_rank_id": tp_rank,
}
kwargs.update(stage1_kwargs)
kwargs_mc2.update(stage1_kwargs)

output = torch_npu.npu_moe_distribute_dispatch(**kwargs)
output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
# comm_stream.wait_stream(torch.npu.current_stream())
expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
0:5]
Expand Down Expand Up @@ -119,7 +118,7 @@ def fused_experts_with_mc2(
down_out_list = torch.cat(down_out_list, dim=0)

# moeCombine
kwargs = {
kwargs_mc2 = {
"expand_x": down_out_list,
"expert_ids": topk_ids,
"expand_idx": expand_idx,
Expand All @@ -141,9 +140,9 @@ def fused_experts_with_mc2(
"tp_world_size": tp_size,
"tp_rank_id": tp_rank,
}
kwargs.update(stage3_kwargs)
kwargs_mc2.update(stage3_kwargs)

hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs)
hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)

return hidden_states

Expand Down Expand Up @@ -675,7 +674,8 @@ def apply(
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map,
moe_all_to_all_group_name=self.moe_all_to_all_group_name)
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
**kwargs)
elif self.torchair_graph_enabled or get_ep_group().world_size == 1:
return fused_experts(hidden_states=x,
w1=layer.w13_weight,
Expand Down Expand Up @@ -772,6 +772,8 @@ def __init__(

ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_multistream_shared_expert = \
ascend_config.torchair_graph_config.enable_multistream_shared_expert

if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError("Only softmax scoring function is supported for "
Expand Down Expand Up @@ -818,7 +820,8 @@ def forward(self,
router_logits: torch.Tensor,
is_prefill: bool,
enable_force_load_balance: bool = False,
top_k=None):
top_k=None,
**kwargs):
assert self.quant_method is not None

if top_k:
Expand Down Expand Up @@ -862,7 +865,11 @@ def forward(self,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
is_prefill=is_prefill,
enable_force_load_balance=enable_force_load_balance)
enable_force_load_balance=enable_force_load_balance,
**kwargs)

if self.enable_multistream_shared_expert and not is_prefill:
hidden_states, shared_output = hidden_states

if self.dp_size > 1:
if VLLM_ENABLE_MC2 and not is_prefill:
Expand All @@ -886,4 +893,6 @@ def forward(self,
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
hidden_states = tensor_model_parallel_all_reduce(hidden_states)

if self.enable_multistream_shared_expert and not is_prefill:
return hidden_states, shared_output
return hidden_states
2 changes: 1 addition & 1 deletion vllm_ascend/quantization/quant_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def apply(
layer, x, router_logits, top_k, renormalize, use_grouped_topk,
global_num_experts, expert_map, topk_group, num_expert_group,
custom_routing_function, scoring_func, e_score_correction_bias,
is_prefill, enable_force_load_balance)
is_prefill, enable_force_load_balance, **kwargs)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if hasattr(self.quant_method, "process_weights_after_loading"):
Expand Down
103 changes: 81 additions & 22 deletions vllm_ascend/quantization/w8a8_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
import torch
import torch.distributed as dist
import torch_npu
from vllm.distributed import GroupCoordinator
import torchair as tng # type: ignore
from vllm.distributed import GroupCoordinator, tensor_model_parallel_all_reduce

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
Expand All @@ -38,7 +39,8 @@ def apply_mlp(hidden_states: torch.Tensor,
w2_scale: torch.Tensor,
group_list: torch.Tensor,
dynamic_scale: torch.Tensor = None,
group_list_type: int = 1) -> torch.Tensor:
group_list_type: int = 1,
**kwargs) -> torch.Tensor:
"""
apply MLP: gate_up_proj -> swiglu -> down_proj

Expand Down Expand Up @@ -72,6 +74,23 @@ def apply_mlp(hidden_states: torch.Tensor,
else:
pertoken_scale = dynamic_scale

shared_experts = kwargs.get('shared_experts', None)
if shared_experts:
shared_gate_up = kwargs.get('shared_gate_up', None)
shared_dynamic_scale = kwargs.get('shared_dynamic_scale', None)
with tng.scope.npu_stream_switch('cv'):
tng.scope.npu_wait_tensor(shared_gate_up, hidden_states)
shared_x, shared_dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
x=shared_gate_up,
weight_scale=shared_experts.gate_up_proj.weight_scale_fp32,
activation_scale=shared_dynamic_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=None,
activate_left=True,
quant_mode=1)

# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
Expand Down Expand Up @@ -100,25 +119,39 @@ def apply_mlp(hidden_states: torch.Tensor,
group_type=0,
group_list=group_list,
output_dtype=w2_scale.dtype)[0]

if shared_experts:
with tng.scope.npu_stream_switch('cv'):
tng.scope.npu_wait_tensor(shared_x, hidden_states)
shared_output = torch_npu.npu_quant_matmul(
shared_x,
shared_experts.down_proj.weight,
shared_experts.down_proj.weight_scale,
pertoken_scale=shared_dynamic_scale,
output_dtype=torch.bfloat16,
)
if shared_experts.down_proj.reduce_results and shared_experts.down_proj.tp_size > 1:
shared_output = tensor_model_parallel_all_reduce(shared_output)
if shared_experts:
return hidden_states, shared_output
return hidden_states


def fused_experts_with_mc2(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
expert_map: torch.Tensor = None,
moe_all_to_all_group_name: str = "",
) -> torch.Tensor:
def fused_experts_with_mc2(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
expert_map: torch.Tensor = None,
moe_all_to_all_group_name: str = "",
**kwargs) -> torch.Tensor:
global_bs = 0
moe_expert_num = len(expert_map)
# hidden_states = hidden_states.bfloat16()
kwargs = {
kwargs_mc2 = {
"x": hidden_states,
"expert_ids": topk_ids,
"expert_shard_type": 0,
Expand Down Expand Up @@ -149,9 +182,27 @@ def fused_experts_with_mc2(
"tp_world_size": tp_size,
"tp_rank_id": tp_rank,
}
kwargs.update(stage1_kwargs)
kwargs_mc2.update(stage1_kwargs)

shared_experts = kwargs.get('shared_experts', None)
if shared_experts:
shared_hidden_states = kwargs.get('shared_hidden_states', None)
with tng.scope.npu_stream_switch('cv'):
tng.scope.npu_wait_tensor(shared_hidden_states, hidden_states)
shared_x, shared_dynamic_scale = torch_npu.npu_dynamic_quant(
shared_hidden_states)
shared_gate_up = torch_npu.npu_quant_matmul(
shared_x,
shared_experts.gate_up_proj.weight,
shared_experts.gate_up_proj.weight_scale,
output_dtype=torch.int32,
)
kwargs.update({
"shared_gate_up": shared_gate_up,
"shared_dynamic_scale": shared_dynamic_scale,
})

output = torch_npu.npu_moe_distribute_dispatch(**kwargs)
output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
# comm_stream.wait_stream(torch.npu.current_stream())
expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
0:5]
Expand All @@ -166,10 +217,15 @@ def fused_experts_with_mc2(
w2,
w2_scale,
expert_token_nums,
dynamic_scale=dynamic_scale)
dynamic_scale=dynamic_scale,
**kwargs)

multi_stream = isinstance(down_out_list, tuple)
if multi_stream:
down_out_list, shared_output = down_out_list

# moeCombine
kwargs = {
kwargs_mc2 = {
"expand_x": down_out_list,
"expert_ids": topk_ids,
"expand_idx": expand_idx,
Expand All @@ -193,10 +249,12 @@ def fused_experts_with_mc2(
"tp_world_size": tp_size,
"tp_rank_id": tp_rank,
}
kwargs.update(stage3_kwargs)
kwargs_mc2.update(stage3_kwargs)

hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs)
hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)

if multi_stream:
return hidden_states, shared_output
return hidden_states


Expand Down Expand Up @@ -634,7 +692,8 @@ def apply(
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map,
moe_all_to_all_group_name=self.moe_all_to_all_group_name)
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
**kwargs)
elif self.torchair_graph_enabled or self.ep_group.world_size == 1:
return fused_experts(hidden_states=x,
w1=layer.w13_weight,
Expand Down