diff --git a/tests/ut/ops/test_moe_comm_method.py b/tests/ut/ops/test_moe_comm_method.py
index c8e0eef46a4..40d768af9fd 100644
--- a/tests/ut/ops/test_moe_comm_method.py
+++ b/tests/ut/ops/test_moe_comm_method.py
@@ -7,7 +7,7 @@
 from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
                                                        AlltoAllCommImpl,
                                                        MC2CommImpl)
-from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType
+from vllm_ascend.quantization.methods.base import QuantType
 from vllm_ascend.ops.fused_moe.token_dispatcher import (TokenCombineResult,
                                                         TokenDispatchResult)
diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py
index a53a368e742..26c59953ed1 100644
--- a/vllm_ascend/ops/fused_moe/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe/fused_moe.py
@@ -35,7 +35,7 @@
 from vllm_ascend.flash_common3_context import get_flash_common3_context, set_flash_common3_context
 from vllm_ascend.ops.fused_moe.experts_selector import select_experts, zero_experts_compute
 from vllm_ascend.ops.fused_moe.moe_comm_method import AllGatherCommImpl, FusedExpertsResult, setup_moe_comm_method
-from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType
+from vllm_ascend.quantization.methods.base import QuantType
 from vllm_ascend.utils import (
     enable_sp,
     maybe_trans_nz,
@@ -235,22 +235,13 @@ def __init__(self, *args, **kwargs):
         self.quant_type = self._get_quant_type()
 
     def _get_quant_type(self) -> QuantType:
-        quant_method = self.quant_method
-        if not hasattr(quant_method, "quant_method") or quant_method.quant_method is None:
-            return QuantType.NONE
+        quant_type = QuantType.NONE
+        method = getattr(self.quant_method, "quant_method", None)
 
-        method = quant_method.quant_method
+        if method is not None:
+            quant_type = getattr(method, "quant_type", QuantType.NONE)
 
-        if hasattr(method, "quant_type"):
-            from vllm_ascend.quantization.methods.base import QuantType as SchemeQuantType
-
-            scheme_quant_type = method.quant_type
-            if scheme_quant_type == SchemeQuantType.W8A8:
-                return QuantType.W8A8
-            elif scheme_quant_type == SchemeQuantType.W4A8:
-                return QuantType.W4A8
-
-        return QuantType.NONE
+        return quant_type
 
     def update_expert_map(self, new_expert_map):
         self._expert_map = new_expert_map
diff --git a/vllm_ascend/ops/fused_moe/moe_comm_method.py b/vllm_ascend/ops/fused_moe/moe_comm_method.py
index e1c31520055..59721d1ca99 100644
--- a/vllm_ascend/ops/fused_moe/moe_comm_method.py
+++ b/vllm_ascend/ops/fused_moe/moe_comm_method.py
@@ -30,7 +30,6 @@
     PrepareAndFinalizeWithAll2All,
     PrepareAndFinalizeWithAllGather,
     PrepareAndFinalizeWithMC2,
-    QuantType,
 )
 from vllm_ascend.ops.fused_moe.token_dispatcher import (
     MoETokenDispatcher,
@@ -38,6 +37,7 @@
     TokenDispatcherWithAllGather,
     TokenDispatcherWithMC2,
 )
+from vllm_ascend.quantization.methods.base import QuantType
 
 
 _MoECommMethods: dict[MoECommType | None, MoECommMethod] = {}
diff --git a/vllm_ascend/ops/fused_moe/prepare_finalize.py b/vllm_ascend/ops/fused_moe/prepare_finalize.py
index ce467cf6243..cced7ae6022 100644
--- a/vllm_ascend/ops/fused_moe/prepare_finalize.py
+++ b/vllm_ascend/ops/fused_moe/prepare_finalize.py
@@ -15,7 +15,6 @@
 # This file is a part of the vllm-ascend project.
 
 from abc import ABC, abstractmethod
-from enum import Enum
 
 import torch
 import torch.distributed as dist
@@ -32,15 +31,10 @@
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.utils import fc3_all_gather_and_maybe_unpad_impl
+from vllm_ascend.quantization.methods.base import QuantType
 from vllm_ascend.utils import enable_sp, npu_stream_switch, prefill_context_parallel_enable
 
 
-class QuantType(Enum):
-    NONE = 0
-    W8A8 = 1
-    W4A8 = 2
-
-
 class PrepareAndFinalize(ABC):
     """
     Abstract base class for MoE (Mixture-of-Experts) tensor preparation and finalization