diff --git a/.github/workflows/vllm_ascend_test.yaml b/.github/workflows/vllm_ascend_test.yaml
index 43b88b0807c..d46b4a9fd47 100644
--- a/.github/workflows/vllm_ascend_test.yaml
+++ b/.github/workflows/vllm_ascend_test.yaml
@@ -284,6 +284,7 @@ jobs:
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_alltoallv
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC
+          pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_sp_for_qwen3_moe
           pytest -sv tests/e2e/multicard/test_data_parallel.py
           pytest -sv tests/e2e/multicard/ --ignore=tests/e2e/multicard/test_ilama_lora_tp2.py \
             --ignore=tests/e2e/multicard/test_offline_inference_distributed.py \
diff --git a/tests/e2e/multicard/test_offline_inference_distributed.py b/tests/e2e/multicard/test_offline_inference_distributed.py
index f4d879a9721..459d8269726 100644
--- a/tests/e2e/multicard/test_offline_inference_distributed.py
+++ b/tests/e2e/multicard/test_offline_inference_distributed.py
@@ -234,3 +234,27 @@ def test_models_distributed_DeepSeek_W4A8DYNAMIC():
             },
     ) as vllm_model:
         vllm_model.generate_greedy(prompts, max_tokens)
+
+
+def test_sp_for_qwen3_moe() -> None:
+    example_prompts = [
+        "Hello, my name is",
+    ]
+    sampling_params = SamplingParams(max_tokens=5,
+                                     temperature=0.0,
+                                     top_k=50,
+                                     top_p=0.9)
+
+    with VllmRunner(
+            snapshot_download("Qwen/Qwen3-30B-A3B"),
+            dtype="auto",
+            tensor_parallel_size=2,
+            distributed_executor_backend="mp",
+            compilation_config={
+                "pass_config": {
+                    "enable_sequence_parallelism": True
+                }
+            },
+            enable_expert_parallel=True,
+    ) as vllm_model:
+        vllm_model.generate(example_prompts, sampling_params)
diff --git a/tests/ut/test_platform.py b/tests/ut/test_platform.py
index c22db8b2c53..e6642b5adaa 100644
--- a/tests/ut/test_platform.py
+++ b/tests/ut/test_platform.py
@@ -26,6 +26,7 @@ def setUp(self):
         self.mock_vllm_config.cache_config = MagicMock()
         self.mock_vllm_config.scheduler_config = MagicMock()
         self.mock_vllm_config.speculative_config = None
+        self.mock_vllm_config.compilation_config.pass_config.enable_sequence_parallelism = False
         self.mock_ascend_config = MagicMock()
         self.mock_ascend_config.torchair_graph_config.enabled = False
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 6f7473f6144..6e031ea4eb3 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -151,6 +151,7 @@ class AscendMetadata:
     slot_mapping: torch.Tensor = None
 
     enable_dbo_across_dp: bool = False
+    is_only_prefill: bool = False
 
 
 class AscendAttentionMetadataBuilder:
@@ -166,7 +167,8 @@ def build(self,
               num_reqs,
               num_actual_tokens,
               max_query_len,
-              enable_dbo_across_dp: bool = False):
+              enable_dbo_across_dp: bool = False,
+              is_only_prefill: bool = False):
         block_table = self.runner.input_batch.block_table[0].get_device_tensor(
         )
@@ -203,7 +205,8 @@ def build(self,
             slot_mapping=slot_mapping,
             attn_mask=attn_mask,
             attn_state=attn_state,
-            enable_dbo_across_dp=enable_dbo_across_dp)
+            enable_dbo_across_dp=enable_dbo_across_dp,
+            is_only_prefill=is_only_prefill)
 
         return attn_metadata
diff --git a/vllm_ascend/attention/attention_v1_torchair.py b/vllm_ascend/attention/attention_v1_torchair.py
index fe7eb9dde6c..1e55d510060 100644
--- a/vllm_ascend/attention/attention_v1_torchair.py
+++ b/vllm_ascend/attention/attention_v1_torchair.py
@@ -223,7 +223,9 @@ def build(self,
               num_actual_tokens,
               max_query_len,
               graph_pad_size: int = -1,
-              enable_dbo_across_dp: bool = False):
+              enable_dbo_across_dp: bool = False,
+              *args,
+              **kwargs):
         device = self.runner.device
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index b2b3ad0e596..b09dbe11259 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -372,6 +372,8 @@ def build(
         graph_pad_size: int = -1,
         query_start_loc: torch.Tensor = None,
         enable_dbo_across_dp: bool = False,
+        *args,
+        **kwargs,
     ) -> AscendMLAMetadata:
         assert self._num_decodes + self._num_prefills == num_reqs
diff --git a/vllm_ascend/models/qwen3_moe.py b/vllm_ascend/models/qwen3_moe.py
index f3598cc6237..29ab6755250 100644
--- a/vllm_ascend/models/qwen3_moe.py
+++ b/vllm_ascend/models/qwen3_moe.py
@@ -16,14 +16,15 @@
 # limitations under the License.
 # Adapted from vllm/model_executor/models/qwen3_moe.py
 # This file is a part of the vllm-ascend project.
-from typing import Optional
+
+from typing import Optional, Union
 
 import torch
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, CompilationLevel, VllmConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
                                              get_tp_group)
 from vllm.forward_context import get_forward_context
@@ -44,8 +45,11 @@
 from vllm.model_executor.models.utils import (
     PPMissingLayer, extract_layer_index,
     make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
+from vllm.sequence import IntermediateTensors
 
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
+from vllm_ascend.ops.sequence_parallel import (MetadataForPadding,
+                                               init_metadata_for_sp)
 
 
 class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
@@ -96,6 +100,7 @@ def forward(
         self,
         hidden_states,
         attn_metadata=None,
+        _metadata_for_padding: Optional[MetadataForPadding] = None,
     ):
         if attn_metadata is None:
             attn_metadata = get_forward_context().attn_metadata
@@ -114,6 +119,7 @@ def forward(
             top_k=self.top_k,
             enable_force_load_balance=enable_force_load_balance,
             shared_experts=None,
+            _metadata_for_padding=_metadata_for_padding,
         )
 
         return hidden_states
@@ -155,14 +161,14 @@ def __init__(
         layer_idx = extract_layer_index(prefix)
         mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
                            config.mlp_only_layers)
-        use_aclgraph = (vllm_config is not None
-                        and vllm_config.compilation_config.level
-                        == CompilationLevel.PIECEWISE
-                        and not vllm_config.model_config.enforce_eager)
+        self.use_aclgraph = (vllm_config is not None
+                             and vllm_config.compilation_config.level
+                             == CompilationLevel.PIECEWISE
+                             and not vllm_config.model_config.enforce_eager)
         if (layer_idx not in mlp_only_layers) and (
                 config.num_experts > 0 and
                 (layer_idx + 1) % config.decoder_sparse_step == 0):
-            if not use_aclgraph:
+            if not self.use_aclgraph:
                 # FIXME: custom sparse moe block doesn't work with aclgraph.
                 self.mlp = CustomSparseMoeBlock(config=config,
                                                 quant_config=quant_config,
@@ -182,6 +188,60 @@ def __init__(
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
 
+        self.enable_sequence_parallelism = (
+            vllm_config.compilation_config.pass_config.
+            enable_sequence_parallelism if vllm_config is not None else False)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        _metadata_for_padding: Optional[MetadataForPadding] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+
+        # With SP, o_proj skips its all-reduce only on the prefill path (a
+        # reduce-scatter is used instead); keep it enabled during decode to
+        # avoid precision issues.
+        if not self.enable_sequence_parallelism:
+            self.self_attn.o_proj.reduce_results = True
+        else:
+            self.self_attn.o_proj.reduce_results = not _metadata_for_padding.not_dummy_and_is_prefill if _metadata_for_padding is not None else True
+
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
+                residual = _metadata_for_padding.padding_slice(residual)
+
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
+
+        if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
+            hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
+                hidden_states)
+
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+        )
+
+        if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
+            hidden_states = _metadata_for_padding.padding_aligned_reduce_scatter(
+                hidden_states)
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
+
+        if not self.use_aclgraph:
+            hidden_states = self.mlp(
+                hidden_states, _metadata_for_padding=_metadata_for_padding)
+        else:
+            hidden_states = self.mlp(hidden_states)
+
+        return hidden_states, residual
+
 
 @support_torch_compile
 class CustomQwen3MoeModel(Qwen3MoeModel):
@@ -216,6 +276,45 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))
 
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        _metadata_for_padding: Optional[MetadataForPadding] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
+            residual = None
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+        for i in range(self.start_layer, self.end_layer):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                residual,
+                _metadata_for_padding=_metadata_for_padding)
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
+
+        hidden_states, _ = self.norm(hidden_states, residual)
+
+        if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
+            hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
+                hidden_states)
+
+        return hidden_states
+
 
 class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
     packed_modules_mapping = {
@@ -253,6 +352,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+        self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism
         # Set MoE hyperparameters
         self.expert_weights: list[torch.Tensor] = []
@@ -273,3 +373,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.num_moe_layers = len(self.moe_layers)
         self.num_expert_groups = 1
         self.num_shared_experts = 0
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        _metadata_for_padding = init_metadata_for_sp(
+            input_ids, self.enable_sequence_parallelism)
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
+                                   inputs_embeds, _metadata_for_padding)
+        return hidden_states
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index 04d288b0631..b345c3e1a18 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -47,6 +47,7 @@
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
     MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
+from vllm_ascend.ops.sequence_parallel import MetadataForPadding
 from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
 from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
                                get_all_reduce_merge_state,
@@ -1347,7 +1348,8 @@ def forward(self,
                 top_k: Optional[int] = None,
                 shared_experts: Optional[Any] = None,
                 gate=None,
-                replace_allreduce: bool = False):
+                replace_allreduce: bool = False,
+                _metadata_for_padding: Optional[MetadataForPadding] = None):
 
         assert self.quant_method is not None
@@ -1381,7 +1383,17 @@ def forward(self,
             # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
             shared_hidden_states = shared_experts(hidden_states)
 
+        mc2_mask = forward_context.mc2_mask
+
+        enable_sp = _metadata_for_padding is not None and _metadata_for_padding.not_dummy_and_is_prefill
         tp_size = get_tensor_model_parallel_world_size()
+        if enable_sp:
+            tp_rank = get_tensor_model_parallel_rank()
+            mc2_mask_sp = _metadata_for_padding.mc2_mask if _metadata_for_padding is not None else forward_context.mc2_mask
+            chunk_mc2_mask = torch.tensor_split(mc2_mask_sp, tp_size, dim=0)
+            mc2_mask = chunk_mc2_mask[tp_rank]
+            replace_allreduce = True
+
         if (fused_moe_state not in [
                 FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
                 FusedMoEState.NaiveMulticast
diff --git a/vllm_ascend/ops/sequence_parallel.py b/vllm_ascend/ops/sequence_parallel.py
new file mode 100644
index 00000000000..bfd327b4105
--- /dev/null
+++ b/vllm_ascend/ops/sequence_parallel.py
@@ -0,0 +1,120 @@
+import torch
+from torch.nn import functional as F
+from vllm.distributed import (get_tensor_model_parallel_world_size,
+                              get_tp_group, tensor_model_parallel_all_gather,
+                              tensor_model_parallel_reduce_scatter)
+from vllm.forward_context import get_forward_context
+
+from vllm_ascend.platform import NPUPlatform
+
+
+class MetadataForPadding:
+
+    def __init__(self,
+                 padding_flag=False,
+                 lengths_sum_padding=0,
+                 lengths_sum_unpadding=0,
+                 pad_size=0,
+                 not_dummy_and_is_prefill=False):
+        self.padding_flag = padding_flag
+        self.not_dummy_and_is_prefill = not_dummy_and_is_prefill
+
+        self.lengths_sum_padding = lengths_sum_padding
+        self.lengths_sum_unpadding = lengths_sum_unpadding
+        self.pad_size = pad_size
+
+        self.tp_size = get_tp_group().world_size
+        self.tp_rank_in_group = get_tp_group().rank_in_group
+
+        assert self.lengths_sum_padding % self.tp_size == 0
+        self.slice_size = self.lengths_sum_padding // self.tp_size
+
+        # Marks the real (non-padding) tokens; per-rank chunks of this mask
+        # are later used as the MC2 dispatch mask.
+        self.mc2_mask = torch.zeros(
+            self.lengths_sum_padding,
+            dtype=torch.bool,
+            device=NPUPlatform.device_type,
+        )
+        self.mc2_mask[:lengths_sum_unpadding] = True
+
+    def padding_aligned_reduce_scatter(self,
+                                       data: torch.Tensor) -> torch.Tensor:
+        # Pad the token dimension to a multiple of tp_size, then
+        # reduce-scatter the padded tensor across the TP group.
+        if self.padding_flag:
+            pad_size = self.pad_size
+            padded_data = F.pad(data, (0, 0, 0, pad_size))
+        else:
+            padded_data = data
+        padded_data_reduce_scatter = tensor_model_parallel_reduce_scatter(
+            padded_data, 0)
+
+        return padded_data_reduce_scatter
+
+    def allgather_unpadding_aligned(self,
+                                    padded_data: torch.Tensor) -> torch.Tensor:
+        # All-gather the per-rank slices and strip the padding tokens.
+        padded_data_allgather = tensor_model_parallel_all_gather(
+            padded_data, 0)
+        if self.padding_flag:
+            lengths_sum_unpadding = self.lengths_sum_unpadding
+            unpadding_data = padded_data_allgather[:lengths_sum_unpadding]
+        else:
+            unpadding_data = padded_data_allgather
+        return unpadding_data
+
+    def padding_slice(self, data: torch.Tensor) -> torch.Tensor:
+        # Pad the token dimension and return this rank's contiguous slice.
+        padded_data = F.pad(data, (0, 0, 0, self.pad_size))
+        start = self.tp_rank_in_group * self.slice_size
+        end = start + self.slice_size
+        slice_data = padded_data[start:end]
+
+        return slice_data
+
+    def padding_aligned_scatter(self, data: torch.Tensor) -> torch.Tensor:
+        # Pad the token dimension and return this rank's chunk (no reduction).
+        if self.padding_flag:
+            pad_size = self.pad_size
+            padded_data = F.pad(data, (0, 0, 0, pad_size))
+        else:
+            padded_data = data
+        chunks = torch.tensor_split(padded_data, self.tp_size, dim=0)
+
+        return chunks[self.tp_rank_in_group]
+
+
+def init_metadata_for_sp(input_ids, enable_sequence_parallelism):
+    if not enable_sequence_parallelism:
+        return MetadataForPadding(padding_flag=False,
+                                  not_dummy_and_is_prefill=False)
+
+    is_prefill = False
+    attn_metadata = get_forward_context().attn_metadata
+    tp_size = get_tensor_model_parallel_world_size()
+    if attn_metadata is not None:
+        if hasattr(attn_metadata,
+                   'is_only_prefill') and attn_metadata.is_only_prefill:
+            is_prefill = True
+        if hasattr(attn_metadata,
+                   'num_prefills') and attn_metadata.num_prefills > 0:
+            is_prefill = True
+
+    if is_prefill:
+        lengths_sum_unpadding = input_ids.shape[0]
+        lengths_sum_padding = (
+            (lengths_sum_unpadding + tp_size - 1) // tp_size) * tp_size
+        if lengths_sum_unpadding == lengths_sum_padding:
+            padding_flag = False
+        else:
+            padding_flag = True
+        pad_size = lengths_sum_padding - lengths_sum_unpadding
+        _metadata_for_padding = MetadataForPadding(
+            lengths_sum_unpadding=lengths_sum_unpadding,
+            lengths_sum_padding=lengths_sum_padding,
+            padding_flag=padding_flag,
+            pad_size=pad_size,
+            not_dummy_and_is_prefill=True)
+
+        return _metadata_for_padding
+
+    return MetadataForPadding(padding_flag=False,
+                              not_dummy_and_is_prefill=False)
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index fa369ac10d0..eb7ea8276cc 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -195,6 +195,12 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 ascend_config.ascend_scheduler_config)
             vllm_config.scheduler_config = ascend_scheduler_config
 
+        if compilation_config.pass_config.enable_sequence_parallelism:
+            if not parallel_config.enable_expert_parallel or vllm_config.model_config.hf_config.model_type != "qwen3_moe":
+                raise NotImplementedError(
+                    "Sequence parallelism is currently only supported for Qwen3 MoE with expert parallel enabled (MC2, AllToAll, and AllToAllV)."
+                )
+
         # register Ascend CustomOp
         register_ascend_customop()
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index b2f730a1b65..de90fe35c94 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -1143,6 +1143,10 @@ def _process_reqs(
         with_prefill = attn_state not in [
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]
+
+        # The batch is prefill-only when every scheduled request has more than
+        # one valid token (i.e. there are no single-token decode steps).
+        is_only_prefill = bool(np.all(num_valid_tokens != 1))
+        extra_builder_kwargs['is_only_prefill'] = is_only_prefill
+
         enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
                                               attn_state,
                                               total_num_scheduled_tokens)
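Usage note (not part of the patch): the sketch below shows how the sequence-parallelism path added here would be enabled from vLLM's offline API, mirroring the `test_sp_for_qwen3_moe` test added above. The model path, parallel sizes, and sampling values are illustrative examples only.

```python
# Illustrative sketch, assuming a 2-NPU setup and the same kwargs exercised by
# the e2e test in this change; values are examples, not recommendations.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen3-30B-A3B",
    tensor_parallel_size=2,
    enable_expert_parallel=True,  # required by the new platform check
    compilation_config={
        "pass_config": {
            # Enables the prefill-only padding / all-gather / reduce-scatter path.
            "enable_sequence_parallelism": True
        }
    },
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=5))
print(outputs[0].outputs[0].text)
```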