18 changes: 15 additions & 3 deletions vllm_ascend/ops/common_fused_moe.py
@@ -27,6 +27,7 @@
     FusedMoEParallelConfig  # isort: skip
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
+from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
 
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import get_mc2_group
@@ -415,15 +416,15 @@ def _load_w2(self,
         expert_data.copy_(loaded_weight)
 
 
-class AscendSharedFusedMoE(AscendFusedMoE):
+class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
 
     def __init__(
         self,
         shared_experts: torch.nn.Module,
         use_overlapped: bool = True,
         **kwargs,
     ):
-        super().__init__(**kwargs)
+        AscendFusedMoE.__init__(self, **kwargs)
         self._shared_experts = shared_experts
         self.use_overlapped = use_overlapped
         self.shared_expert_stream = None
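Because both SharedFusedMoE and AscendFusedMoE derive from FusedMoE, the MRO of the new class places SharedFusedMoE before AscendFusedMoE, so a plain super().__init__(**kwargs) would dispatch to SharedFusedMoE.__init__ first; the explicit AscendFusedMoE.__init__(self, **kwargs) skips that and runs only the Ascend initialization path. A toy reproduction of the inheritance shape (illustrative names, not from this diff):

# Minimal sketch of the diamond: Shared and Ascend both subclass Base,
# and the combined class lists Shared first, mirroring AscendSharedFusedMoE.
class Base: ...
class Shared(Base): ...
class Ascend(Base): ...
class AscendShared(Shared, Ascend): ...

print([c.__name__ for c in AscendShared.__mro__])
# ['AscendShared', 'Shared', 'Ascend', 'Base', 'object']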
@@ -452,7 +453,8 @@ def forward(
                 if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
                     shared_out = tensor_model_parallel_all_reduce(shared_out)
 
-        fused_out = super().forward(
+        _, fused_out = AscendFusedMoE.forward(
+            self,
             hidden_states=hidden_states,
             router_logits=router_logits,
         )
@@ -461,6 +463,16 @@
             torch.npu.current_stream().wait_stream(self.shared_expert_stream)
         return shared_out, fused_out
 
+    def forward_impl(self, hidden_states: torch.Tensor,
+                     router_logits: torch.Tensor):
+        shared_output = torch.empty(1)
+        fused_output = AscendFusedMoE.forward_impl(
+            self,
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+        )
+        return shared_output, fused_output
+
 
 UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
 UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading
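Both forward and forward_impl now return a (shared_output, fused_output) pair, matching SharedFusedMoE's contract: forward_impl hands back a placeholder shared output (torch.empty(1)) alongside the routed result, while forward computes the real shared-expert output, optionally overlapped on shared_expert_stream. A minimal usage sketch from a caller's point of view, assuming a DeepSeek-style MoE block whose attribute names (self.experts, self.gate) are illustrative:

# The shared experts live inside the fused layer, so one call yields both paths.
shared_out, routed_out = self.experts(
    hidden_states=hidden_states,              # [num_tokens, hidden_size]
    router_logits=self.gate(hidden_states),
)
# Combine the shared-expert path with the routed-expert path.
hidden_states = shared_out + routed_out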
1 change: 0 additions & 1 deletion vllm_ascend/patch/worker/patch_common/__init__.py
@@ -18,4 +18,3 @@
 import vllm_ascend.patch.worker.patch_common.patch_distributed  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_logits  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_minicpm  # noqa
-import vllm_ascend.patch.worker.patch_common.patch_shared_fused_moe  # noqa
21 changes: 0 additions & 21 deletions vllm_ascend/patch/worker/patch_common/patch_shared_fused_moe.py

This file was deleted.

4 changes: 3 additions & 1 deletion vllm_ascend/utils.py
@@ -498,7 +498,8 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):

     from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention
     from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
-    from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
+    from vllm_ascend.ops.common_fused_moe import (AscendFusedMoE,
+                                                  AscendSharedFusedMoE)
     from vllm_ascend.ops.layernorm import AscendQuantRMSNorm, AscendRMSNorm
     from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
                                         AscendMergedColumnParallelLinear,
@@ -525,6 +526,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
"LogitsProcessor": AscendLogitsProcessor,
"RMSNorm": AscendRMSNorm,
"FusedMoE": AscendFusedMoE,
"SharedFusedMoE": AscendSharedFusedMoE,
"MultiHeadLatentAttention": AscendMultiHeadLatentAttention,
}
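With "SharedFusedMoE" added to this mapping, register_ascend_customop lets vLLM substitute AscendSharedFusedMoE wherever a model constructs the upstream SharedFusedMoE layer on Ascend, so model code is unchanged. A rough sketch of such a call site, assuming SharedFusedMoE takes shared_experts plus the usual FusedMoE keyword arguments (the config field names below are illustrative):

# Sketch of a model-side call site; on Ascend the registration above swaps in
# AscendSharedFusedMoE, with no change to this code.
self.experts = SharedFusedMoE(
    shared_experts=self.shared_experts,   # nn.Module computing the shared path
    num_experts=config.n_routed_experts,
    top_k=config.num_experts_per_tok,
    hidden_size=config.hidden_size,
    intermediate_size=config.moe_intermediate_size,
    reduce_results=False,
)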
