Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/ut/ops/test_moe_comm_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
AlltoAllCommImpl,
MC2CommImpl)
from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType
from vllm_ascend.quantization.methods.base import QuantType
from vllm_ascend.ops.fused_moe.token_dispatcher import (TokenCombineResult,
TokenDispatchResult)

Expand Down
21 changes: 6 additions & 15 deletions vllm_ascend/ops/fused_moe/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from vllm_ascend.flash_common3_context import get_flash_common3_context, set_flash_common3_context
from vllm_ascend.ops.fused_moe.experts_selector import select_experts, zero_experts_compute
from vllm_ascend.ops.fused_moe.moe_comm_method import AllGatherCommImpl, FusedExpertsResult, setup_moe_comm_method
from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType
from vllm_ascend.quantization.methods.base import QuantType
from vllm_ascend.utils import (
enable_sp,
maybe_trans_nz,
Expand Down Expand Up @@ -235,22 +235,13 @@ def __init__(self, *args, **kwargs):
self.quant_type = self._get_quant_type()

def _get_quant_type(self) -> QuantType:
quant_method = self.quant_method
if not hasattr(quant_method, "quant_method") or quant_method.quant_method is None:
return QuantType.NONE
quant_type = QuantType.NONE
method = getattr(self.quant_method, "quant_method", None)

method = quant_method.quant_method
if method is not None:
quant_type = getattr(method, "quant_type", QuantType.NONE)

if hasattr(method, "quant_type"):
from vllm_ascend.quantization.methods.base import QuantType as SchemeQuantType

scheme_quant_type = method.quant_type
if scheme_quant_type == SchemeQuantType.W8A8:
return QuantType.W8A8
elif scheme_quant_type == SchemeQuantType.W4A8:
return QuantType.W4A8

return QuantType.NONE
return quant_type

def update_expert_map(self, new_expert_map):
self._expert_map = new_expert_map
Expand Down
2 changes: 1 addition & 1 deletion vllm_ascend/ops/fused_moe/moe_comm_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@
PrepareAndFinalizeWithAll2All,
PrepareAndFinalizeWithAllGather,
PrepareAndFinalizeWithMC2,
QuantType,
)
from vllm_ascend.ops.fused_moe.token_dispatcher import (
MoETokenDispatcher,
TokenDispatcherWithAll2AllV,
TokenDispatcherWithAllGather,
TokenDispatcherWithMC2,
)
from vllm_ascend.quantization.methods.base import QuantType

_MoECommMethods: dict[MoECommType | None, MoECommMethod] = {}

Expand Down
8 changes: 1 addition & 7 deletions vllm_ascend/ops/fused_moe/prepare_finalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# This file is a part of the vllm-ascend project.

from abc import ABC, abstractmethod
from enum import Enum

import torch
import torch.distributed as dist
Expand All @@ -32,15 +31,10 @@

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.utils import fc3_all_gather_and_maybe_unpad_impl
from vllm_ascend.quantization.methods.base import QuantType
from vllm_ascend.utils import enable_sp, npu_stream_switch, prefill_context_parallel_enable


class QuantType(Enum):
    """Enumerates the quantization schemes recognized by the MoE
    prepare/finalize path.

    Member values are stable integers; callers compare members by
    identity (e.g. ``quant_type == QuantType.W8A8``).
    """

    # No quantization configured for this layer.
    NONE = 0
    # Scheme named "W8A8" — presumably 8-bit weights / 8-bit activations.
    W8A8 = 1
    # Scheme named "W4A8" — presumably 4-bit weights / 8-bit activations.
    W4A8 = 2


class PrepareAndFinalize(ABC):
"""
Abstract base class for MoE (Mixture-of-Experts) tensor preparation and finalization
Expand Down