vllm_ascend/torchair/models/qwen3_moe.py: 7 additions & 5 deletions
```diff
@@ -56,6 +56,7 @@
 from vllm_ascend.torchair.ops.sequence_parallel import (MetadataForPadding,
                                                         init_metadata_for_sp)
 from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE
+from vllm_ascend.utils import vllm_version_is


 class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
```
```diff
@@ -313,10 +314,11 @@ def __init__(
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
-
-        self.enable_sequence_parallelism = (
-            vllm_config.compilation_config.pass_config.enable_sp
-            if vllm_config is not None else False)
+        self.enable_sequence_parallelism = False
+        if vllm_config is not None:
+            self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism if vllm_version_is(
+                "0.12.0"
+            ) else vllm_config.compilation_config.pass_config.enable_sp

     def forward(
         self,
```
```diff
@@ -488,7 +490,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)

-        self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sp
+        self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism
```
Contributor
critical

This change appears to be incorrect. It replaces `enable_sp` with `enable_sequence_parallelism` without the necessary compatibility logic for different vLLM versions. This will likely cause a failure on vLLM versions newer than 0.12.0. To ensure compatibility, you should apply the same version-checking logic that is used in `CustomQwen3MoeDecoderLayer.__init__` and `vllm_ascend/utils.py`.

Suggested change:

```diff
-        self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism
+        self.enable_sequence_parallelism = (
+            vllm_config.compilation_config.pass_config.enable_sequence_parallelism
+            if vllm_version_is("0.12.0") else vllm_config.compilation_config.pass_config.enable_sp
+        )
```
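Since this branch now appears at several call sites, the version check could live in one place. A minimal sketch of that idea, assuming only that `vllm_version_is` reports whether the installed vLLM matches the given version string, as used in the diff above; the helper name `_pass_config_enable_sp` is hypothetical and not part of this PR:

```python
from vllm_ascend.utils import vllm_version_is


def _pass_config_enable_sp(vllm_config) -> bool:
    # Hypothetical helper: vLLM 0.12.0 exposes the flag as
    # `enable_sequence_parallelism`, while later releases rename it to
    # `enable_sp`. Resolving the attribute once keeps call sites such as
    # CustomQwen3MoeDecoderLayer.__init__ version-agnostic.
    pass_config = vllm_config.compilation_config.pass_config
    if vllm_version_is("0.12.0"):
        return pass_config.enable_sequence_parallelism
    return pass_config.enable_sp
```

Both `CustomQwen3MoeDecoderLayer.__init__` and the model `__init__` above could then call such a helper instead of duplicating the ternary.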

```diff
         # Set MoE hyperparameters
         self.expert_weights: list[torch.Tensor] = []
```
vllm_ascend/utils.py: 3 additions & 2 deletions
```diff
@@ -752,9 +752,10 @@ def enable_sp(vllm_config=None, enable_shared_expert_dp: bool = False) -> bool:
     if vllm_config is None:
         from vllm.config import get_current_vllm_config
         vllm_config = get_current_vllm_config()
+    vllm_enable_sp = vllm_config.compilation_config.pass_config.enable_sequence_parallelism if vllm_version_is(
+        "0.12.0") else vllm_config.compilation_config.pass_config.enable_sp
     _ENABLE_SP = (
-        vllm_config.compilation_config.pass_config.enable_sp
-        or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1
+        vllm_enable_sp or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1
         # Flash comm 1 should be enabled by env VLLM_ASCEND_ENABLE_FLASHCOMM1
         # We retain the env VLLM_ASCEND_ENABLE_FLASHCOMM here for backward compatibility.
         or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))))
```