diff --git a/vllm_ascend/torchair/models/qwen3_moe.py b/vllm_ascend/torchair/models/qwen3_moe.py
index 8338946f132..67155a86094 100644
--- a/vllm_ascend/torchair/models/qwen3_moe.py
+++ b/vllm_ascend/torchair/models/qwen3_moe.py
@@ -56,6 +56,7 @@
 from vllm_ascend.torchair.ops.sequence_parallel import (MetadataForPadding,
                                                         init_metadata_for_sp)
 from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE
+from vllm_ascend.utils import vllm_version_is
 
 
 class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
@@ -313,10 +314,13 @@ def __init__(
             eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
-
-        self.enable_sequence_parallelism = (
-            vllm_config.compilation_config.pass_config.enable_sp
-            if vllm_config is not None else False)
+        self.enable_sequence_parallelism = False
+        if vllm_config is not None:
+            # vllm 0.12.0 exposes the flag as `enable_sequence_parallelism`;
+            # newer vllm renamed it to `enable_sp`.
+            self.enable_sequence_parallelism = (
+                vllm_config.compilation_config.pass_config.enable_sequence_parallelism
+                if vllm_version_is("0.12.0") else
+                vllm_config.compilation_config.pass_config.enable_sp)
 
     def forward(
         self,
@@ -488,7 +490,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
-        self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sp
+        # Same version guard as above: `enable_sequence_parallelism` only
+        # exists on vllm 0.12.0; unconditional access would raise
+        # AttributeError on newer vllm where the flag is `enable_sp`.
+        self.enable_sequence_parallelism = (
+            vllm_config.compilation_config.pass_config.enable_sequence_parallelism
+            if vllm_version_is("0.12.0") else
+            vllm_config.compilation_config.pass_config.enable_sp)
 
         # Set MoE hyperparameters
         self.expert_weights: list[torch.Tensor] = []
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 00014c802c3..b7d7ab9cac4 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -752,9 +752,12 @@ def enable_sp(vllm_config=None, enable_shared_expert_dp: bool = False) -> bool:
     if vllm_config is None:
         from vllm.config import get_current_vllm_config
         vllm_config = get_current_vllm_config()
+    # vllm 0.12.0 names the flag `enable_sequence_parallelism`; newer vllm
+    # renamed it to `enable_sp`.
+    vllm_enable_sp = (
+        vllm_config.compilation_config.pass_config.enable_sequence_parallelism
+        if vllm_version_is("0.12.0") else
+        vllm_config.compilation_config.pass_config.enable_sp)
     _ENABLE_SP = (
-        vllm_config.compilation_config.pass_config.enable_sp
-        or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1
+        vllm_enable_sp or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1
         # Flash comm 1 should be enabled by env VLLM_ASCEND_ENABLE_FLASHCOMM1
         # We retain the env VLLM_ASCEND_ENABLE_FLASHCOMM here for backward compatibility.
         or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))))