3 changes: 0 additions & 3 deletions pyproject.toml
@@ -74,9 +74,6 @@ exclude = [
     "vllm_ascend/worker/v2/**",
     "vllm_ascend/worker/npu_input_batch.py",
     "vllm_ascend/ops/rotary_embedding.py",
-
-    # (11)
-    "vllm_ascend/ops/fused_moe/**",
 ]
 
 [tool.ruff.lint]
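With this exclusion gone, vllm_ascend/ops/fused_moe/** falls under the repository's ruff configuration, which appears to be what drives the mechanical reformatting in the two files below (line wrapping, trailing commas, and Optional[X] rewritten as X | None). Running ruff check vllm_ascend/ops/fused_moe and ruff format --check vllm_ascend/ops/fused_moe from the repo root should now exercise these files.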
43 changes: 17 additions & 26 deletions vllm_ascend/ops/fused_moe/comm_utils.py
@@ -23,11 +23,7 @@
 COMM_STREAM = None
 
 
-def async_all_to_all(input_,
-                     output_split_sizes,
-                     input_split_sizes,
-                     group,
-                     event=None):
+def async_all_to_all(input_, output_split_sizes, input_split_sizes, group, event=None):
     if output_split_sizes is None:
         # Equal split (all2all)
         a2a_out = torch.empty_like(input_)
@@ -43,8 +39,7 @@ def async_all_to_all(input_,
         # multi stream wait event
         global COMM_STREAM
         if COMM_STREAM is None:
-            COMM_STREAM = torch_npu.npu.Stream(
-                device=torch.npu.current_device())
+            COMM_STREAM = torch_npu.npu.Stream(device=torch.npu.current_device())
         with torch_npu.npu.stream(COMM_STREAM):
             event.wait()
             handle = dist.all_to_all_single(
@@ -53,14 +48,17 @@ def async_all_to_all(input_,
                 output_split_sizes=output_split_sizes,
                 input_split_sizes=input_split_sizes,
                 group=group,
-                async_op=True)
+                async_op=True,
+            )
     else:
-        handle = dist.all_to_all_single(a2a_out,
-                                        input_.contiguous(),
-                                        output_split_sizes=output_split_sizes,
-                                        input_split_sizes=input_split_sizes,
-                                        group=group,
-                                        async_op=True)
+        handle = dist.all_to_all_single(
+            a2a_out,
+            input_.contiguous(),
+            output_split_sizes=output_split_sizes,
+            input_split_sizes=input_split_sizes,
+            group=group,
+            async_op=True,
+        )
     return input_, a2a_out, handle
 
 
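As an aside for readers of this diff, here is a minimal usage sketch of the reformatted async_all_to_all; it is illustrative only (not part of the PR) and assumes an initialized torch.distributed process group on NPU. The handle returned by async_op=True must be waited on before a2a_out is read:

    # Illustrative sketch, not part of the PR: equal-split async all-to-all.
    # Assumes dist.init_process_group(...) has run with an NPU-aware backend.
    import torch
    import torch.distributed as dist

    from vllm_ascend.ops.fused_moe.comm_utils import async_all_to_all

    def exchange_tokens(tokens: torch.Tensor, group=None) -> torch.Tensor:
        # output_split_sizes=None takes the equal-split path (torch.empty_like).
        _, a2a_out, handle = async_all_to_all(
            tokens, output_split_sizes=None, input_split_sizes=None, group=group
        )
        handle.wait()  # the all_to_all_single above was issued with async_op=True
        return a2a_out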
@@ -86,19 +84,12 @@ def _gather_along_first_dim(input_, group, output_split_sizes=None):
     if output_split_sizes is None:
         dim_size[0] = dim_size[0] * world_size
 
-        output = torch.empty(dim_size,
-                             dtype=input_.dtype,
-                             device=torch.npu.current_device())
-        torch.distributed.all_gather_into_tensor(output,
-                                                 input_.contiguous(),
-                                                 group=group)
+        output = torch.empty(dim_size, dtype=input_.dtype, device=torch.npu.current_device())
+        torch.distributed.all_gather_into_tensor(output, input_.contiguous(), group=group)
     else:
         dim_size[0] = sum(output_split_sizes)
-        output = torch.empty(dim_size,
-                             dtype=input_.dtype,
-                             device=torch.npu.current_device())
-        output_tensor_list = list(
-            torch.split(output, output_split_sizes, dim=0))
+        output = torch.empty(dim_size, dtype=input_.dtype, device=torch.npu.current_device())
+        output_tensor_list = list(torch.split(output, output_split_sizes, dim=0))
         torch.distributed.all_gather(output_tensor_list, input_, group=group)
 
     return output
@@ -110,4 +101,4 @@ def gather_from_sequence_parallel_region(
     output_split_sizes=None,
 ):
     """Wrapper for autograd function: forward: AG, backward: RS <first dim>"""
-    return _gather_along_first_dim(input_, group, output_split_sizes)
\ No newline at end of file
+    return _gather_along_first_dim(input_, group, output_split_sizes)
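Similarly, a hedged sketch of the shape contract of gather_from_sequence_parallel_region: on the equal-split path a local [N, H] shard becomes [world_size * N, H] after the all-gather, while uneven shards instead yield sum(output_split_sizes) rows:

    # Illustrative sketch, not part of the PR: shape contract of the wrapper.
    import torch
    import torch.distributed as dist

    from vllm_ascend.ops.fused_moe.comm_utils import gather_from_sequence_parallel_region

    def gather_shard(local_shard: torch.Tensor, group=None) -> torch.Tensor:
        gathered = gather_from_sequence_parallel_region(local_shard, group)
        # Equal-split path: the first dimension is multiplied by the world size.
        assert gathered.shape[0] == local_shard.shape[0] * dist.get_world_size(group)
        return gathered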
162 changes: 81 additions & 81 deletions vllm_ascend/ops/fused_moe/experts_selector.py
@@ -14,26 +14,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Callable, Optional
+from collections.abc import Callable
 
 import torch
 
 from vllm_ascend.utils import get_weight_prefetch_method
 
 
-def select_experts(hidden_states: torch.Tensor,
-                   router_logits: torch.Tensor,
-                   top_k: int,
-                   use_grouped_topk: bool,
-                   renormalize: bool,
-                   topk_group: Optional[int] = None,
-                   num_expert_group: Optional[int] = None,
-                   custom_routing_function: Optional[Callable] = None,
-                   scoring_func: str = "softmax",
-                   routed_scaling_factor=1.0,
-                   e_score_correction_bias: Optional[torch.Tensor] = None,
-                   indices_type: Optional[torch.dtype] = None,
-                   global_num_experts: int = -1):
+def select_experts(
+    hidden_states: torch.Tensor,
+    router_logits: torch.Tensor,
+    top_k: int,
+    use_grouped_topk: bool,
+    renormalize: bool,
+    topk_group: int | None = None,
+    num_expert_group: int | None = None,
+    custom_routing_function: Callable | None = None,
+    scoring_func: str = "softmax",
+    routed_scaling_factor=1.0,
+    e_score_correction_bias: torch.Tensor | None = None,
+    indices_type: torch.dtype | None = None,
+    global_num_experts: int = -1,
+):
     """
     Fused experts with select experts.
 
@@ -58,16 +60,16 @@ def select_experts(hidden_states: torch.Tensor,
     # prefetch w1_w3_proj.weight preprocess
     weight_prefetch_method = get_weight_prefetch_method()
     if weight_prefetch_method:
-        weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(
-            hidden_states, "gate_up")
+        weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(hidden_states, "gate_up")
     is_support_npu_moe_gating_top_k = check_npu_moe_gating_top_k(
         hidden_states=hidden_states,
         top_k=top_k,
         renormalize=renormalize,
         topk_group=topk_group,
         num_expert_group=num_expert_group,
         scoring_func=scoring_func,
-        custom_routing_function=custom_routing_function)
+        custom_routing_function=custom_routing_function,
+    )
 
     if is_support_npu_moe_gating_top_k:
         topk_weights, topk_ids = _select_experts_with_fusion_ops(
@@ -81,7 +83,8 @@ def select_experts(hidden_states: torch.Tensor,
             num_expert_group=num_expert_group,
             scoring_func=scoring_func,
             routed_scaling_factor=routed_scaling_factor,
-            global_num_experts=global_num_experts)
+            global_num_experts=global_num_experts,
+        )
     else:
         topk_weights, topk_ids = _native_select_experts(
             hidden_states=hidden_states,
@@ -100,54 +103,55 @@
 
 
 def check_npu_moe_gating_top_k(
-        hidden_states: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        scoring_func: str = "softmax",
-        custom_routing_function: Optional[Callable] = None):
-    if scoring_func == "sigmoid" and not renormalize:  #sigmoid + renorm=0 is not supported in current branch
+    hidden_states: torch.Tensor,
+    top_k: int,
+    renormalize: bool,
+    topk_group: int | None = None,
+    num_expert_group: int | None = None,
+    scoring_func: str = "softmax",
+    custom_routing_function: Callable | None = None,
+):
+    if scoring_func == "sigmoid" and not renormalize:  # sigmoid + renorm=0 is not supported in current branch
         return False
     if custom_routing_function is not None:
         return False
     if scoring_func != "softmax" and scoring_func != "sigmoid":
         return False
     topk_group = topk_group if topk_group is not None else 1
     num_expert_group = num_expert_group if num_expert_group is not None else 1
-    if not (num_expert_group > 0 and hidden_states.shape[-1] % num_expert_group
-            == 0 and hidden_states.shape[-1] // num_expert_group > 2):
+    if not (
+        num_expert_group > 0
+        and hidden_states.shape[-1] % num_expert_group == 0
+        and hidden_states.shape[-1] // num_expert_group > 2
+    ):
         return False
     if topk_group < 1 or topk_group > num_expert_group:
         return False
-    if top_k < 1 or \
-            top_k > (hidden_states.shape[-1] / (num_expert_group * topk_group)):
+    if top_k < 1 or top_k > (hidden_states.shape[-1] / (num_expert_group * topk_group)):
         return False
-    if topk_group * hidden_states.shape[-1] / num_expert_group < top_k:
+    if topk_group * hidden_states.shape[-1] / num_expert_group < top_k:  # noqa: SIM103
         return False
     return True
 
 
 def _native_grouped_topk(
     topk_weights: torch.Tensor,
-    num_expert_group: Optional[int],
-    topk_group: Optional[int],
+    num_expert_group: int | None,
+    topk_group: int | None,
 ):
     topk_group = 0 if topk_group is None else topk_group
     num_expert_group = 0 if num_expert_group is None else num_expert_group
 
     num_token = topk_weights.shape[0]
-    grouped_weights = topk_weights.view(num_token, num_expert_group,
-                                        -1).max(dim=-1).values
-    topk_group_indices = torch.topk(grouped_weights.to(torch.float32),
-                                    k=topk_group,
-                                    dim=-1,
-                                    sorted=False)[1]
+    grouped_weights = topk_weights.view(num_token, num_expert_group, -1).max(dim=-1).values
+    topk_group_indices = torch.topk(grouped_weights.to(torch.float32), k=topk_group, dim=-1, sorted=False)[1]
     topk_group_mask = torch.zeros_like(grouped_weights)
     topk_group_mask.scatter_(1, topk_group_indices, 1)
-    topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand(
-        num_token, num_expert_group,
-        topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1))
+    topk_weight_mask = (
+        topk_group_mask.unsqueeze(-1)
+        .expand(num_token, num_expert_group, topk_weights.shape[-1] // num_expert_group)
+        .reshape(num_token, -1)
+    )
     topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0)
 
     return topk_weights
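To make the masking in _native_grouped_topk concrete, here is a standalone worked example (pure torch, toy values, not part of the PR): experts are viewed as num_expert_group groups, each group is scored by its best expert, and experts outside the top topk_group groups are zeroed so the subsequent per-token top-k can only pick from the winning groups.

    # Worked toy example of the grouped top-k masking above (values made up).
    import torch

    num_token, num_expert_group, topk_group = 1, 4, 2
    weights = torch.tensor([[0.1, 0.9, 0.2, 0.2, 0.8, 0.1, 0.3, 0.05]])  # 4 groups of 2 experts

    # Score each group by its strongest expert: [[0.9, 0.2, 0.8, 0.3]].
    grouped = weights.view(num_token, num_expert_group, -1).max(dim=-1).values
    group_idx = torch.topk(grouped, k=topk_group, dim=-1, sorted=False)[1]  # groups 0 and 2
    group_mask = torch.zeros_like(grouped).scatter_(1, group_idx, 1)
    expert_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, weights.shape[-1] // num_expert_group)
        .reshape(num_token, -1)
    )
    print(weights.masked_fill(~expert_mask.bool(), 0.0))
    # tensor([[0.1000, 0.9000, 0.0000, 0.0000, 0.8000, 0.1000, 0.0000, 0.0000]])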
@@ -163,9 +167,13 @@ def _renormalize_topk_weights(
 
 
 def _select_expert_use_group_topk(
-        topk_weights: torch.Tensor, topk_group: Optional[int],
-        renormalize: bool, top_k: int, num_expert_group: Optional[int],
-        e_score_correction_bias: Optional[torch.Tensor]):
+    topk_weights: torch.Tensor,
+    topk_group: int | None,
+    renormalize: bool,
+    top_k: int,
+    num_expert_group: int | None,
+    e_score_correction_bias: torch.Tensor | None,
+):
     assert topk_group is not None
     assert num_expert_group is not None
 
@@ -177,47 +185,38 @@ def _select_expert_use_group_topk(
 
     # TODO: Change to npu_group_topk when the latest CANN and NNAL is available
     # >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
-    topk_weights = _native_grouped_topk(topk_weights, num_expert_group,
-                                        topk_group)
+    topk_weights = _native_grouped_topk(topk_weights, num_expert_group, topk_group)
     # TODO bfloat16 is not supported in torch.topk with ge graph.
     if e_score_correction_bias is not None:
-        topk_ids = torch.topk(topk_weights.to(torch.float32),
-                              k=top_k,
-                              dim=-1,
-                              sorted=False)[1]
+        topk_ids = torch.topk(topk_weights.to(torch.float32), k=top_k, dim=-1, sorted=False)[1]
         # Use original unbiased scores for the routing weights
         topk_weights = original_weights.gather(1, topk_ids)
     else:
-        topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32),
-                                            k=top_k,
-                                            dim=-1,
-                                            sorted=False)
+        topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32), k=top_k, dim=-1, sorted=False)
     topk_ids = topk_ids.to(torch.int32)
     topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
     return topk_weights, topk_ids
 
 
 def _select_experts_with_fusion_ops(
-        hidden_states: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        use_grouped_topk: bool,
-        renormalize: bool,
-        e_score_correction_bias: Optional[torch.Tensor],
-        topk_group: Optional[int],
-        num_expert_group: Optional[int],
-        scoring_func: str = "softmax",
-        routed_scaling_factor=1.0,
-        global_num_experts: int = -1):
-
+    hidden_states: torch.Tensor,
+    router_logits: torch.Tensor,
+    top_k: int,
+    use_grouped_topk: bool,
+    renormalize: bool,
+    e_score_correction_bias: torch.Tensor | None,
+    topk_group: int | None,
+    num_expert_group: int | None,
+    scoring_func: str = "softmax",
+    routed_scaling_factor=1.0,
+    global_num_experts: int = -1,
+):
     topk_group = topk_group if topk_group is not None else 1
     num_expert_group = num_expert_group if num_expert_group is not None else 1
     renorm = int(renormalize)
     norm_type = 0 if scoring_func == "softmax" else 1
-    if e_score_correction_bias is not None and \
-            e_score_correction_bias.dtype != router_logits.dtype:
-        e_score_correction_bias = e_score_correction_bias.to(
-            router_logits.dtype)
+    if e_score_correction_bias is not None and e_score_correction_bias.dtype != router_logits.dtype:
+        e_score_correction_bias = e_score_correction_bias.to(router_logits.dtype)
     topk_weights, topk_ids, _ = torch.ops._C_ascend.moe_gating_top_k(
         router_logits,
         k=top_k,
@@ -228,7 +227,7 @@ def _select_experts_with_fusion_ops(
         norm_type=norm_type,  # 0: softmax; 1: sigmoid
         out_flag=False,
         routed_scaling_factor=routed_scaling_factor,
-        eps=float(1e-20),
+        eps=1e-20,
         bias_opt=e_score_correction_bias,
     )
 
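Whether this fused torch.ops._C_ascend.moe_gating_top_k path is taken at all is decided by check_npu_moe_gating_top_k earlier in the file. A small probe of those eligibility rules, assuming the module imports in your environment (the checks themselves are plain Python/torch):

    # Illustrative probe of the fused-path eligibility rules (shapes made up).
    import torch

    from vllm_ascend.ops.fused_moe.experts_selector import check_npu_moe_gating_top_k

    ok = check_npu_moe_gating_top_k(
        hidden_states=torch.randn(4, 256),  # 4 tokens, 256 routed experts
        top_k=8,
        renormalize=True,
        topk_group=4,
        num_expert_group=8,
        scoring_func="softmax",
    )
    print(ok)  # True: 256 % 8 == 0, 256 // 8 > 2, 1 <= 4 <= 8, and 8 <= 256 / (8 * 4)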
@@ -241,12 +240,12 @@ def _native_select_experts(
     top_k: int,
     use_grouped_topk: bool,
     renormalize: bool,
-    topk_group: Optional[int] = None,
-    num_expert_group: Optional[int] = None,
-    custom_routing_function: Optional[Callable] = None,
+    topk_group: int | None = None,
+    num_expert_group: int | None = None,
+    custom_routing_function: Callable | None = None,
     scoring_func: str = "softmax",
-    e_score_correction_bias: Optional[torch.Tensor] = None,
-    global_num_experts: Optional[torch.Tensor] = None
+    e_score_correction_bias: torch.Tensor | None = None,
+    global_num_experts: torch.Tensor | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """
     Select top-k experts based on router logits.
@@ -285,15 +284,17 @@
             renormalize=renormalize,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
-            e_score_correction_bias=e_score_correction_bias)
+            e_score_correction_bias=e_score_correction_bias,
+        )
 
     if custom_routing_function is not None:
         topk_weights, topk_ids = custom_routing_function(
             hidden_states=hidden_states,
             gating_output=router_logits,
             topk=top_k,
             renormalize=renormalize,
-            global_num_experts=global_num_experts)
+            global_num_experts=global_num_experts,
+        )
         # Required by npu_moe_init_routing
         topk_ids = topk_ids.to(torch.int32)
         return topk_weights, topk_ids
@@ -318,8 +319,7 @@ def zero_experts_compute(
     if zero_expert_type == "identity":
         zero_expert_mask = expert_indices < num_experts
         zero_expert_scales = expert_scales.clone()
-        zero_expert_scales = torch.where(zero_expert_mask, 0.0,
-                                         zero_expert_scales)
+        zero_expert_scales = torch.where(zero_expert_mask, 0.0, zero_expert_scales)
 
         hidden_states = hidden_states.unsqueeze(1)
         zero_expert_scales = zero_expert_scales.unsqueeze(2)
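The zero_experts_compute hunk above zeroes the scale of every index that maps to a real expert (expert_indices < num_experts), so only "zero expert" slots, which act as identity experts, keep their routing weight. The masking step in isolation, with toy values:

    # Toy illustration of the identity zero-expert masking (values made up).
    import torch

    num_experts = 4  # indices >= num_experts denote zero (identity) experts
    expert_indices = torch.tensor([[1, 5], [3, 2]])
    expert_scales = torch.tensor([[0.7, 0.3], [0.6, 0.4]])

    zero_expert_mask = expert_indices < num_experts
    # Real experts are zeroed here; only zero-expert slots keep their scale.
    zero_expert_scales = torch.where(zero_expert_mask, 0.0, expert_scales)
    print(zero_expert_scales)
    # tensor([[0.0000, 0.3000], [0.0000, 0.0000]])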