70 changes: 35 additions & 35 deletions vllm_ascend/ascend_forward_context.py
@@ -26,7 +26,7 @@ class MoECommType(Enum):
ALLGATHER = 0
MC2 = 1
ALLTOALL = 2
FUSED_ALLTOALL = 3
FUSED_MC2 = 3


@contextmanager
@@ -62,11 +62,8 @@ def set_ascend_forward_context(

from vllm_ascend.ops.fused_moe.moe_comm_method import \
get_moe_comm_method
moe_comm_type = select_moe_comm_method(num_tokens, vllm_config)
# TODO: remove this after moe_comm_type selection logic is finalized
if is_mtp_model:
moe_comm_type = (MoECommType.ALLTOALL if moe_comm_type
== MoECommType.FUSED_ALLTOALL else moe_comm_type)
moe_comm_type = select_moe_comm_method(num_tokens, vllm_config,
is_mtp_model)
forward_context.moe_comm_type = moe_comm_type
forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)

@@ -93,7 +90,7 @@
forward_context.mmrs_fusion = mmrs_fusion
forward_context.num_tokens = num_tokens
forward_context.sp_enabled = sp_enabled
#TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
# TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
forward_context.flashcomm_v2_enabled = flashcomm2_enable(
) and tp_world_size > 1 and num_tokens is not None

@@ -210,29 +207,30 @@ def get_mc2_mask():


def select_moe_comm_method(num_tokens: int,
vllm_config: VllmConfig) -> Optional[MoECommType]:
"""1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all
are designed for expert parallelism.
2. If expert parallel is enabled, we need to consider the soc version and the
number of tokens. This is based on the observation that all-gather is more
efficient than all-to-all when running on A2.

a. For A2, we choose from MC2 and all-gather.

b. For A3, we choose from MC2 and all-to-all.

In both cases, we use MC2 when the number of tokens is smaller than
its capacity threshold.

Args:
num_tokens (int): The number of tokens in the current batch.

Raises:
ValueError: If the soc version is unsupported.

Returns:
MoECommType: The selected MoE communication method.
"""
vllm_config: VllmConfig,
is_mtp_model=False) -> Optional[MoECommType]:
"""Select the MoE communication method according to parallel settings,
device generation, token count, and quantization.

1. Non-MoE models return `None`.
2. Without expert parallel, fall back to all-gather.
3. On A2 with expert parallel, pick MC2 when tokens fit the MC2 capacity
and the DP size is large enough; otherwise use all-gather.
4. On A3 with expert parallel, prefer fused MC2 when `VLLM_ASCEND_ENABLE_FUSED_MC2`
is set and the model uses w8a8_dynamic quantization with a small EP size,
no dynamic_eplb, and is not in MTP mode; otherwise use MC2 within capacity
or all-to-all.

Args:
num_tokens (int): The number of tokens in the current batch.
vllm_config (VllmConfig): Runtime configuration for the model.
is_mtp_model (bool): Whether the model runs in MTP mode (disables fused MC2).

Raises:
ValueError: If the soc version is unsupported.

Returns:
MoECommType | None: The selected MoE communication method.
"""
if not is_moe_model(vllm_config):
return None
mc2_tokens_capacity = get_mc2_tokens_capacity()
@@ -255,11 +253,13 @@ def select_moe_comm_method(num_tokens: int,
ascend_config = get_ascend_config()
dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
# TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
fused_all2all_enable = quant_type == "w8a8_dynamic" and get_ep_group(
).world_size <= 16 and (not dynamic_eplb)
moe_comm_type = (MoECommType.MC2 if num_tokens <= mc2_tokens_capacity
else MoECommType.FUSED_ALLTOALL
if fused_all2all_enable else MoECommType.ALLTOALL)
fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic" and get_ep_group(
).world_size <= 16 and (not dynamic_eplb) and (not is_mtp_model)
Collaborator review comment: change is_mtp_model to check model type
if num_tokens <= mc2_tokens_capacity:
moe_comm_type = MoECommType.FUSED_MC2 if fused_mc2_enable else MoECommType.MC2
else:
moe_comm_type = MoECommType.FUSED_MC2 if fused_mc2_enable else MoECommType.ALLTOALL

else:
raise ValueError(f"Unsupported soc_version: {soc_version}")
return moe_comm_type
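Reviewer note: a minimal, self-contained sketch of the selection rules described in the docstring above. All names here (`soc`, `ep_enabled`, `dp_large_enough`, `fused_mc2_enable`) are illustrative stand-ins, not the actual `vllm_config` fields or helpers; `fused_mc2_enable` is assumed to already fold in the env-var, quantization, EP-size, dynamic_eplb, and MTP guards shown in the diff.

```python
# Illustrative sketch only; the real logic lives in select_moe_comm_method above.
def sketch_select(soc: str, ep_enabled: bool, num_tokens: int,
                  mc2_capacity: int, dp_large_enough: bool,
                  fused_mc2_enable: bool) -> str:
    if not ep_enabled:
        # Without expert parallel, MC2/all-to-all bring no benefit.
        return "ALLGATHER"
    if soc == "A2":
        # All-gather outperforms all-to-all on A2; MC2 only within capacity.
        return "MC2" if (num_tokens <= mc2_capacity and dp_large_enough) else "ALLGATHER"
    if soc == "A3":
        if fused_mc2_enable:
            # Fused MC2 (dispatch_ffn_combine) takes priority when its guards pass.
            return "FUSED_MC2"
        return "MC2" if num_tokens <= mc2_capacity else "ALLTOALL"
    raise ValueError(f"Unsupported soc_version: {soc}")

assert sketch_select("A3", True, 128, 512, True, fused_mc2_enable=True) == "FUSED_MC2"
assert sketch_select("A3", True, 4096, 512, True, fused_mc2_enable=False) == "ALLTOALL"
```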
3 changes: 3 additions & 0 deletions vllm_ascend/envs.py
@@ -132,6 +132,9 @@
# Whether to enable dynamic EPLB
"DYNAMIC_EPLB":
lambda: os.getenv("DYNAMIC_EPLB", "false").lower(),
# Whether to enable fused MC2 (dispatch_gmm_combine_decode/dispatch_ffn_combine operators)
"VLLM_ASCEND_ENABLE_FUSED_MC2":
lambda: int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", '0')),
}

# end-env-vars-definition
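Usage sketch (hypothetical launch snippet, not part of the diff): the switch is read through `os.getenv`, so it has to be present in the serving process's environment before the MoE layers pick a communication method.

```python
# Hypothetical example: enable the fused MC2 path (dispatch_ffn_combine).
# 0 (default) keeps fused MC2 off; 1 selects the dispatch_ffn_combine operator.
import os

os.environ.setdefault("VLLM_ASCEND_ENABLE_FUSED_MC2", "1")
# ...then start vLLM / import vllm_ascend so the value is visible via os.getenv.
```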
2 changes: 1 addition & 1 deletion vllm_ascend/ops/fused_moe/fused_moe.py
@@ -533,7 +533,7 @@ def forward_impl(self, hidden_states: torch.Tensor,
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
forward_context = get_forward_context()
moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL} \
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \
and not shared_expert_dp_enabled():
shared_out = tensor_model_parallel_all_reduce(shared_out)
else:
64 changes: 35 additions & 29 deletions vllm_ascend/ops/fused_moe/moe_comm_method.py
@@ -22,6 +22,7 @@
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe import FusedMoEConfig

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
from vllm_ascend.ops.fused_moe.prepare_finalize import (
@@ -43,8 +44,7 @@ def setup_moe_comm_method(moe_config):
_MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl(moe_config)
_MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl(moe_config)
_MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config)
_MoECommMethods[MoECommType.FUSED_ALLTOALL] = FusedAlltoAllCommImpl(
moe_config)
_MoECommMethods[MoECommType.FUSED_MC2] = FusedMC2CommImpl(moe_config)


class MoECommMethod(ABC):
@@ -241,30 +241,27 @@ def _get_prepare_finalize(self):
return PrepareAndFinalizeWithAll2All(self.moe_config)


class FusedAlltoAllCommImpl(MoECommMethod):
class FusedMC2CommImpl(MoECommMethod):
"""This implementation is for the scenarios listed below:
1. `enable_expert_parallel=True`.
2. `npu_grouped_matmul` is available.

This implementation uses all-to-all communication to exchange tokens
between data parallel ranks before and after the MLP computation. It should
have better performance than AllGatherCommImpl when DP size > 1.
2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are available.
3. `enable_expert_parallel=False` is not supported.

This implementation uses the MC2 communication method, which is optimized for
overlapping communication and computation on Ascend devices.
"""

def _get_token_dispatcher(self):
return TokenDispatcherWithAll2AllV(
top_k=self.moe_config.experts_per_token,
num_experts=self.moe_config.num_experts,
num_local_experts=self.moe_config.num_local_experts)
return TokenDispatcherWithMC2()

def _get_prepare_finalize(self):
return PrepareAndFinalizeWithAll2All(self.moe_config)
return PrepareAndFinalizeWithMC2(self.moe_config)

def fused_experts(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1: torch.Tensor | list[torch.Tensor],
w2: torch.Tensor | list[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
@@ -274,8 +271,8 @@ def fused_experts(
use_int4_w4a16: bool = False,
global_num_experts: Optional[int] = None,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_scale: Optional[list[torch.Tensor]] = None,
w2_scale: Optional[list[torch.Tensor]] = None,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
w1_offset: Optional[torch.Tensor] = None,
@@ -291,18 +288,27 @@ def fused_experts(
dynamic_eplb: bool = False,
mc2_mask: torch.Tensor = None,
pertoken_scale: Optional[torch.Tensor] = None):
assert not (
w1_scale is None or w2_scale is None
), "w1_scale and w2_scale cannot be None for FusedMC2CommImpl."
out = torch.empty_like(hidden_states)

torch.ops._C_ascend.dispatch_ffn_combine(
x=hidden_states,
weight1=w1,
weight2=w2,
expert_idx=topk_ids,
scale1=w1_scale,
scale2=w2_scale,
probs=topk_weights.to(torch.float32),
group=self.token_dispatcher.moe_all_to_all_group_name,
max_output_size=65536,
out=out,
)
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
torch.ops._C_ascend.dispatch_ffn_combine(
x=hidden_states,
weight1=w1[0],
weight2=w2[0],
expert_idx=topk_ids,
scale1=w1_scale[0],
scale2=w2_scale[0],
probs=topk_weights.to(torch.float32),
group=self.token_dispatcher.moe_all_to_all_group_name,
max_output_size=65536,
out=out,
)
elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
raise NotImplementedError()
else:
raise ValueError(
f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}")
return out
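Reviewer note: a toy, framework-independent sketch of the list-wrapping convention used above. Both the fused and generic paths now accept list-wrapped weights/scales so they share one signature, and the fused kernel consumes only the first element (`w1[0]`, `w1_scale[0]`). The names and tensors below are illustrative, not the real expert-grouped NPU tensors.

```python
# Illustrative only; the real inputs are the quantized expert weights shown above.
import torch

def fused_path(w1: list[torch.Tensor], w1_scale: list[torch.Tensor]) -> torch.Tensor:
    # Mirrors weight1=w1[0], scale1=w1_scale[0] in FusedMC2CommImpl.fused_experts.
    return w1[0] * w1_scale[0]

def generic_path(w1: list[torch.Tensor], w1_scale: list[torch.Tensor]) -> torch.Tensor:
    # Other comm impls can keep treating the arguments as per-group lists.
    return torch.stack([w * s for w, s in zip(w1, w1_scale)]).sum(dim=0)

w1 = [torch.ones(4, 8)]
w1_scale = [torch.full((4, 1), 0.5)]
assert torch.equal(fused_path(w1, w1_scale), generic_path(w1, w1_scale))
```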
12 changes: 7 additions & 5 deletions vllm_ascend/ops/register_custom_ops.py
@@ -130,8 +130,9 @@ def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,

with torch.npu.stream(prefetch_stream):
mlp_gate_up_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE
torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight, \
x_dependency, mlp_gate_up_prefetch_size)
torch_npu.npu_prefetch(
model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight,
x_dependency, mlp_gate_up_prefetch_size)
return


@@ -185,8 +186,9 @@ def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None:

with torch.npu.stream(prefetch_stream):
mlp_down_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE
torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.down_proj.weight, \
x_dependency, mlp_down_prefetch_size)
torch_npu.npu_prefetch(
model_instance.model.layers[layer_idx].mlp.down_proj.weight,
x_dependency, mlp_down_prefetch_size)
forward_context.layer_idx += 1
return

@@ -250,7 +252,7 @@ def _maybe_all_reduce_tensor_model_parallel_impl(
forward_context = get_forward_context()
moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {
MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL
MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2
} or forward_context.sp_enabled:
return final_hidden_states
else:
14 changes: 8 additions & 6 deletions vllm_ascend/quantization/w8a8_dynamic.py
@@ -23,6 +23,7 @@
from vllm.distributed import get_ep_group
from vllm.forward_context import get_forward_context

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.distributed.parallel_state import get_mc2_group
@@ -246,15 +247,16 @@ def apply(
w2 = [layer.w2_weight]
w2_scale = [layer.w2_weight_scale]

fused_flag = get_forward_context(
).moe_comm_type == MoECommType.FUSED_ALLTOALL
fused_scale_flag = (get_forward_context().moe_comm_type
== MoECommType.FUSED_MC2
and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1)
return moe_comm_method.fused_experts(
hidden_states=x,
pertoken_scale=pertoken_scale,
w1=w1[0] if fused_flag else w1,
w1_scale=layer.fused_w1_scale if fused_flag else w1_scale,
w2=w2[0] if fused_flag else w2,
w2_scale=layer.fused_w2_scale if fused_flag else w2_scale,
w1=w1,
w1_scale=[layer.fused_w1_scale] if fused_scale_flag else w1_scale,
w2=w2,
w2_scale=[layer.fused_w2_scale] if fused_scale_flag else w2_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
use_int8_w8a8=True,
7 changes: 4 additions & 3 deletions vllm_ascend/worker/model_runner_v1.py
@@ -430,7 +430,8 @@ def _skip_all_reduce_acorss_dp_group(self) -> bool:
# moe_comm_method of each rank is MC2 and recomputation would never happen in D
# nodes. So here we check whether recompute_scheduler_enable is True.
return self.is_kv_consumer and self.ascend_config.recompute_scheduler_enable and select_moe_comm_method(
potential_max_num_tokens, self.vllm_config) == MoECommType.MC2
potential_max_num_tokens,
self.vllm_config) in {MoECommType.MC2, MoECommType.FUSED_MC2}

def _sync_metadata_across_dp(
self, num_tokens: int,
@@ -1057,7 +1058,7 @@ def _prepare_inputs(
# (num_reqs_d + num_reqs_p, max_num_blocks),
# flattened block_table: [d0, d0, d1, d1, p0, p1, p2]
# (num_reqs_d * decode_threshold + num_reqs_p, max_num_blocks),
ori_query_lens = self.query_start_loc_pcp_full.cpu[1:num_reqs+1] - \
ori_query_lens = self.query_start_loc_pcp_full.cpu[1:num_reqs + 1] - \
self.query_start_loc_pcp_full.cpu[:num_reqs]
num_prefill_reqs = (ori_query_lens
> self.decode_threshold).sum().item()
@@ -2200,7 +2201,7 @@ def _dummy_sampler_run(
def profile_run(self) -> None:
mc2_tokens_capacity = get_mc2_tokens_capacity()
if self.max_num_tokens > mc2_tokens_capacity and \
select_moe_comm_method(mc2_tokens_capacity, self.vllm_config) == MoECommType.MC2:
select_moe_comm_method(mc2_tokens_capacity, self.vllm_config) in {MoECommType.MC2, MoECommType.FUSED_MC2}:
self._dummy_run(mc2_tokens_capacity,
with_prefill=True,
is_profile=True)