1 change: 1 addition & 0 deletions .github/workflows/vllm_ascend_test.yaml
@@ -284,6 +284,7 @@ jobs:
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_alltoallv
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_sp_for_qwen3_moe
pytest -sv tests/e2e/multicard/test_data_parallel.py
pytest -sv tests/e2e/multicard/ --ignore=tests/e2e/multicard/test_ilama_lora_tp2.py \
--ignore=tests/e2e/multicard/test_offline_inference_distributed.py \
24 changes: 24 additions & 0 deletions tests/e2e/multicard/test_offline_inference_distributed.py
@@ -234,3 +234,27 @@ def test_models_distributed_DeepSeek_W4A8DYNAMIC():
},
) as vllm_model:
vllm_model.generate_greedy(prompts, max_tokens)


def test_sp_for_qwen3_moe() -> None:
example_prompts = [
"Hello, my name is",
]
sampling_params = SamplingParams(max_tokens=5,
temperature=0.0,
top_k=50,
top_p=0.9)

with VllmRunner(
snapshot_download("Qwen/Qwen3-30B-A3B"),
dtype="auto",
tensor_parallel_size=4,

[Review thread on `tensor_parallel_size=4`]
Collaborator: There are only 2 cards on the CI machine; let's reduce the TP size to 2.
Contributor (author): Done, thanks.
(A sketch of the adjusted call follows this test.)

distributed_executor_backend="mp",
compilation_config={
"pass_config": {
"enable_sequence_parallelism": True
}
},
enable_expert_parallel=True,
) as vllm_model:
vllm_model.generate(example_prompts, sampling_params)
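
Per the review thread above, the merged test presumably runs with `tensor_parallel_size=2` to fit the two-card CI machine. A minimal sketch of the adjusted call, assuming every other argument stays exactly as in the diff:

```python
# Sketch only: tensor_parallel_size reduced to 2 per the review thread;
# every other argument is assumed unchanged from the diff above.
with VllmRunner(
        snapshot_download("Qwen/Qwen3-30B-A3B"),
        dtype="auto",
        tensor_parallel_size=2,
        distributed_executor_backend="mp",
        compilation_config={
            "pass_config": {
                "enable_sequence_parallelism": True
            }
        },
        enable_expert_parallel=True,
) as vllm_model:
    vllm_model.generate(example_prompts, sampling_params)
```
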
1 change: 1 addition & 0 deletions tests/ut/test_platform.py
@@ -26,6 +26,7 @@ def setUp(self):
self.mock_vllm_config.cache_config = MagicMock()
self.mock_vllm_config.scheduler_config = MagicMock()
self.mock_vllm_config.speculative_config = None
self.mock_vllm_config.compilation_config.pass_config.enable_sequence_parallelism = False

self.mock_ascend_config = MagicMock()
self.mock_ascend_config.torchair_graph_config.enabled = False
7 changes: 5 additions & 2 deletions vllm_ascend/attention/attention_v1.py
@@ -151,6 +151,7 @@ class AscendMetadata:
slot_mapping: torch.Tensor = None

enable_dbo_across_dp: bool = False
is_only_prefill: bool = False


class AscendAttentionMetadataBuilder:
@@ -166,7 +167,8 @@ def build(self,
num_reqs,
num_actual_tokens,
max_query_len,
enable_dbo_across_dp: bool = False):
enable_dbo_across_dp: bool = False,
is_only_prefill: bool = False):

block_table = self.runner.input_batch.block_table[0].get_device_tensor(
)
@@ -203,7 +205,8 @@ def build(self,
slot_mapping=slot_mapping,
attn_mask=attn_mask,
attn_state=attn_state,
enable_dbo_across_dp=enable_dbo_across_dp)
enable_dbo_across_dp=enable_dbo_across_dp,
is_only_prefill=is_only_prefill)
return attn_metadata


4 changes: 3 additions & 1 deletion vllm_ascend/attention/attention_v1_torchair.py
@@ -223,7 +223,9 @@ def build(self,
num_actual_tokens,
max_query_len,
graph_pad_size: int = -1,
enable_dbo_across_dp: bool = False):
enable_dbo_across_dp: bool = False,
*args,
**kwargs):

device = self.runner.device

2 changes: 2 additions & 0 deletions vllm_ascend/attention/mla_v1.py
@@ -372,6 +372,8 @@ def build(
graph_pad_size: int = -1,
query_start_loc: torch.Tensor = None,
enable_dbo_across_dp: bool = False,
*args,
**kwargs,
) -> AscendMLAMetadata:
assert self._num_decodes + self._num_prefills == num_reqs

114 changes: 112 additions & 2 deletions vllm_ascend/models/qwen3_moe.py
@@ -16,14 +16,15 @@
# limitations under the License.
# Adapted from vllm/model_executor/models/qwen3_moe.py
# This file is a part of the vllm-ascend project.
from typing import Optional

from typing import Optional, Union

import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, CompilationLevel, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
get_tp_group)
from vllm.forward_context import get_forward_context
@@ -44,8 +45,11 @@
from vllm.model_executor.models.utils import (
PPMissingLayer, extract_layer_index,
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
from vllm.sequence import IntermediateTensors

from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.ops.sequence_parallel import (MetadataForPadding,
init_metadata_for_sp)


class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
@@ -96,6 +100,7 @@ def forward(
self,
hidden_states,
attn_metadata=None,
_metadata_for_padding: Optional[MetadataForPadding] = None,
):
if attn_metadata is None:
attn_metadata = get_forward_context().attn_metadata
@@ -114,6 +119,7 @@
top_k=self.top_k,
enable_force_load_balance=enable_force_load_balance,
shared_experts=None,
_metadata_for_padding=_metadata_for_padding,
)

return hidden_states
@@ -182,6 +188,57 @@ def __init__(
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)

self.enable_sequence_parallelism = (
vllm_config.compilation_config.pass_config.
enable_sequence_parallelism if vllm_config is not None else False)

def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
_metadata_for_padding: Optional[MetadataForPadding] = None,
) -> torch.Tensor:

# To prevent precision issues during the decoder phase when only prefilling enables SP
if not self.enable_sequence_parallelism:
self.self_attn.o_proj.reduce_results = True
else:
self.self_attn.o_proj.reduce_results = not _metadata_for_padding.not_dummy_and_is_prefill if _metadata_for_padding is not None else True

# Self Attention
if residual is None:
residual = hidden_states
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
residual = _metadata_for_padding.padding_slice(residual)

hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)

if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
hidden_states)

hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)

if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
hidden_states = _metadata_for_padding.padding_aligned_reduce_scatter(
hidden_states)

# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)

hidden_states = self.mlp(hidden_states,
_metadata_for_padding=_metadata_for_padding)

return hidden_states, residual


@support_torch_compile
class CustomQwen3MoeModel(Qwen3MoeModel):
@@ -216,6 +273,45 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
_metadata_for_padding: Optional[MetadataForPadding] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
residual,
_metadata_for_padding=_metadata_for_padding)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})

hidden_states, _ = self.norm(hidden_states, residual)

if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
hidden_states)

return hidden_states


class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
packed_modules_mapping = {
@@ -253,6 +349,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism
# Set MoE hyperparameters
self.expert_weights: list[torch.Tensor] = []

@@ -273,3 +370,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.num_moe_layers = len(self.moe_layers)
self.num_expert_groups = 1
self.num_shared_experts = 0

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
_metadata_for_padding = init_metadata_for_sp(
input_ids, self.enable_sequence_parallelism)
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds, _metadata_for_padding)
return hidden_states
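
For context on the layer changes above: during prefill with sequence parallelism enabled, each TP rank keeps a 1/tp_size slice of the (padded) token dimension, all-gathers the full sequence before self-attention, and reduce-scatters the attention output back into per-rank slices before the MoE MLP, which is why `o_proj.reduce_results` is turned off on that path. A self-contained sketch of that communication pattern using plain `torch.distributed` collectives; the helper names (`sp_pad_and_slice`, `sp_all_gather`, `sp_reduce_scatter`) are hypothetical illustrations, not the `MetadataForPadding` API:

```python
# Illustrative sequence-parallel communication pattern only; this is not the
# vllm-ascend MetadataForPadding implementation, and the helper names are
# hypothetical.
import torch
import torch.distributed as dist


def sp_pad_and_slice(x: torch.Tensor, tp_size: int, tp_rank: int) -> torch.Tensor:
    """Pad the token dim to a multiple of tp_size and keep this rank's slice."""
    pad = (-x.shape[0]) % tp_size
    if pad:
        # pad the token (first) dim of a [tokens, hidden] tensor
        x = torch.nn.functional.pad(x, (0, 0, 0, pad))
    return torch.tensor_split(x, tp_size, dim=0)[tp_rank]


def sp_all_gather(x_local: torch.Tensor, tp_group) -> torch.Tensor:
    """Rebuild the full (padded) sequence before self-attention."""
    world = dist.get_world_size(tp_group)
    gathered = [torch.empty_like(x_local) for _ in range(world)]
    dist.all_gather(gathered, x_local, group=tp_group)
    return torch.cat(gathered, dim=0)


def sp_reduce_scatter(x_full: torch.Tensor, tp_group) -> torch.Tensor:
    """Sum partial TP results and hand each rank its token slice."""
    world = dist.get_world_size(tp_group)
    chunks = list(torch.tensor_split(x_full, world, dim=0))
    out = torch.empty_like(chunks[0])
    dist.reduce_scatter(out, chunks, group=tp_group)
    return out
```

The reduce-scatter stands in for the usual tensor-parallel all-reduce while leaving each rank only its slice of the tokens, which matches the `reduce_results = False` branch in the decoder layer above.
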
14 changes: 13 additions & 1 deletion vllm_ascend/ops/fused_moe.py
@@ -47,6 +47,7 @@
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
from vllm_ascend.ops.sequence_parallel import MetadataForPadding
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
get_all_reduce_merge_state,
@@ -1347,7 +1348,8 @@ def forward(self,
top_k: Optional[int] = None,
shared_experts: Optional[Any] = None,
gate=None,
replace_allreduce: bool = False):
replace_allreduce: bool = False,
_metadata_for_padding: Optional[MetadataForPadding] = None):

assert self.quant_method is not None

@@ -1381,7 +1383,17 @@
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
shared_hidden_states = shared_experts(hidden_states)

mc2_mask = forward_context.mc2_mask

enable_sp = _metadata_for_padding is not None and _metadata_for_padding.not_dummy_and_is_prefill
tp_size = get_tensor_model_parallel_world_size()
if enable_sp:
tp_rank = get_tensor_model_parallel_rank()
mc2_mask_sp = _metadata_for_padding.mc2_mask if _metadata_for_padding is not None else forward_context.mc2_mask
chunk_mc2_mask = torch.tensor_split(mc2_mask_sp, tp_size, dim=0)
mc2_mask = chunk_mc2_mask[tp_rank]
replace_allreduce = True

if (fused_moe_state not in [
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
FusedMoEState.NaiveMulticast
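
The `mc2_mask` handling above mirrors that token split: when SP is active during prefill, the padded mask from `_metadata_for_padding` is cut into `tp_size` equal chunks along the token dimension and each rank keeps the chunk matching its slice. A standalone illustration of the slicing (the mask values and sizes are made up for the example):

```python
import torch

# Hypothetical example: 6 real tokens padded to 8 so the mask splits evenly
# across tp_size = 4 ranks; True marks real (non-padding) tokens.
mc2_mask_sp = torch.tensor([True] * 6 + [False] * 2)
tp_size = 4

chunk_mc2_mask = torch.tensor_split(mc2_mask_sp, tp_size, dim=0)
for tp_rank, mc2_mask in enumerate(chunk_mc2_mask):
    print(tp_rank, mc2_mask.tolist())
# 0 [True, True]
# 1 [True, True]
# 2 [True, True]
# 3 [False, False]   <- this rank's slice holds only padding tokens
```
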