Conversation

Contributor
lbk-sys commented Aug 5, 2025

What this PR does / why we need it?

This PR adds sequence parallelism (SP) support for Qwen3 MoE. In the AlltoAll, AlltoAllv, and MC2 communication scenarios, replacing AllReduce with Reduce-Scatter plus AllGather lets the norm operations run on sharded tokens, yielding a computational gain while also saving one AllGather communication. The feature is enabled during the prefill (P) phase and delivers notable gains on long sequences (e.g., 16k–25k tokens), with performance improvements of 5%–10%.

Does this PR introduce any user-facing change?

How was this patch tested?

```
compilation_config={
    "pass_config":{
        "enable_sequence_parallelism": True
    }
},
enable_expert_parallel=True,
```
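For context, here is a minimal offline sketch of how these settings might be passed to vLLM's `LLM` entry point; the model name, TP size, and prompt are illustrative and not part of this PR:

```python
# Hedged sketch: assumes the standard vLLM offline API; only the
# compilation_config / enable_expert_parallel settings come from this PR's test.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen3-30B-A3B",   # illustrative model
    tensor_parallel_size=4,        # illustrative TP size
    enable_expert_parallel=True,
    compilation_config={
        "pass_config": {"enable_sequence_parallelism": True},
    },
)
outputs = llm.generate(["<a long prompt so the prefill phase benefits from SP>"],
                       SamplingParams(max_tokens=8))
print(outputs[0].outputs[0].text)
```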

Signed-off-by: libaokui <[email protected]>

github-actions bot commented Aug 5, 2025

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling in the PR description to help reviewers and future developers understand the change.

If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.

Signed-off-by: libaokui <[email protected]>
Signed-off-by: libaokui <[email protected]>
Signed-off-by: libaokui <[email protected]>
```python
    get_tp_group, tensor_model_parallel_all_gather,
    tensor_model_parallel_reduce_scatter)
from vllm.forward_context import get_forward_context
from vllm.platforms import current_platform
```
Collaborator

Do not use current_platform in vllm-ascend; import vllm_ascend.platform directly.
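A sketch of the direction this comment suggests; the `NPUPlatform` name is an assumption about what `vllm_ascend/platform.py` exposes:

```python
# Sketch: import the Ascend platform directly instead of relying on
# vllm.platforms.current_platform. The NPUPlatform class name is assumed here.
from vllm_ascend.platform import NPUPlatform

device_type = NPUPlatform.device_type  # rather than current_platform.device_type
```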

```python
        )
        self.mc2_mask[:lengths_sum_unpadding] = True

    def padding_aligned_reduce_scatter(self,
```
Collaborator

These functions duplicate the pad and unpad functions in flashcommv1; can we consolidate them?

Contributor Author

Thank you for your suggestion. We chose not to adopt the flashcomm1 implementation from the 091 branch for two reasons:

  1. The existing flashcomm1 implementations for Qwen2 and Qwen3 in the repository are inconsistent. We've created a model-level interface here to minimize the migration effort for sparse models (SP).

  2. flashcomm1's graph-mode support for sparse models such as Qwen2 and Qwen3 isn't available on the main branch yet, and merging the two would impact graph-mode performance, so we're keeping them separate for now. Note that merging Qwen3 MoE's SP implementation doesn't affect the current status.
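For readers without the full diff, here is a conceptual sketch of what a padding-aligned reduce-scatter helper does; this is not the PR's implementation, and the shapes and the `torch.distributed` call are illustrative (only `get_tp_group` is taken from the imports shown above):

```python
# Conceptual sketch only: pad the token dimension to a multiple of the TP world
# size, then reduce-scatter so each rank keeps an equal slice of summed tokens.
import torch
import torch.distributed as dist
import torch.nn.functional as F
from vllm.distributed import get_tp_group  # import location assumed


def padding_aligned_reduce_scatter(hidden_states: torch.Tensor) -> torch.Tensor:
    tp = get_tp_group()
    num_tokens = hidden_states.shape[0]
    pad = (-num_tokens) % tp.world_size  # rows needed to reach the next multiple
    if pad:
        # F.pad pads the last dim first: (hidden_lo, hidden_hi, token_lo, token_hi)
        hidden_states = F.pad(hidden_states, (0, 0, 0, pad))
    out = torch.empty(hidden_states.shape[0] // tp.world_size,
                      *hidden_states.shape[1:],
                      dtype=hidden_states.dtype,
                      device=hidden_states.device)
    dist.reduce_scatter_tensor(out, hidden_states, group=tp.device_group)
    return out
```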

Signed-off-by: libaokui <[email protected]>
Signed-off-by: libaokui <[email protected]>
Signed-off-by: libaokui <[email protected]>

github-actions bot commented Aug 5, 2025

This pull request has conflicts; please resolve them before we can evaluate the pull request.

Signed-off-by: libaokui <[email protected]>
Signed-off-by: libaokui <[email protected]>

github-actions bot commented Aug 6, 2025

This pull request has conflicts; please resolve them before we can evaluate the pull request.

Signed-off-by: libaokui <[email protected]>
Signed-off-by: libaokui <[email protected]>
Signed-off-by: libaokui <[email protected]>
Signed-off-by: libaokui <[email protected]>

xueliangyang-oeuler commented Aug 6, 2025

@lbk-sys can you list out your testing CLI (e.g., vllm serve) and your testing datasets? Thanks.

Signed-off-by: libaokui <[email protected]>
Signed-off-by: libaokui <[email protected]>

codecov bot commented Aug 6, 2025

Codecov Report

❌ Patch coverage is 19.84733% with 105 lines in your changes missing coverage. Please review.
✅ Project coverage is 76.09%. Comparing base (e31b31f) to head (c76cd00).
⚠️ Report is 621 commits behind head on main.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| vllm_ascend/ops/sequence_parallel.py | 17.91% | 55 Missing ⚠️ |
| vllm_ascend/models/qwen3_moe.py | 14.00% | 43 Missing ⚠️ |
| vllm_ascend/ops/fused_moe.py | 44.44% | 5 Missing ⚠️ |
| vllm_ascend/platform.py | 33.33% | 2 Missing ⚠️ |
Additional details and impacted files
```
@@            Coverage Diff             @@
##             main    #2209      +/-   ##
==========================================
- Coverage   76.65%   76.09%   -0.56%     
==========================================
  Files         113      114       +1     
  Lines       12763    13103     +340     
==========================================
+ Hits         9783     9971     +188     
- Misses       2980     3132     +152     
```

| Flag | Coverage Δ |
| --- | --- |
| unittests | 76.09% <19.84%> (-0.56%) ⬇️ |

Signed-off-by: libaokui <[email protected]>
Contributor Author

lbk-sys commented Aug 6, 2025

> @lbk-sys can you list out your testing CLI (e.g., vllm serve) and your testing datasets? Thanks.

Thank you for your attention. For accuracy testing we ran the AIME dataset; for performance testing we ran benchmarks using vLLM and offline scripts. The model input is laid out as t*h (tokens × hidden size), so as long as the number of input tokens in the P-phase is large enough, the gains are realized regardless of the dataset.

Contributor Author

lbk-sys commented Aug 6, 2025

> @lbk-sys can you list out your testing CLI (e.g., vllm serve) and your testing datasets? Thanks.

In addition to the above, we also tested with the deepscaler dataset.

```python
with VllmRunner(
        snapshot_download("Qwen/Qwen3-30B-A3B"),
        dtype="auto",
        tensor_parallel_size=4,
```
Collaborator

There are only 2 cards on the CI machine; let's reduce the TP size to 2.

Contributor Author

Done, thanks.
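For reference, a sketch of how the adjusted test invocation might look after dropping the TP size to 2; the keyword set of `VllmRunner`, the import paths, and the generate call are assumptions based on the snippet above, not the exact test code:

```python
# Hedged sketch: import paths and the generate_greedy helper are assumed.
from modelscope import snapshot_download       # path assumed
from tests.e2e.conftest import VllmRunner      # path assumed

with VllmRunner(
        snapshot_download("Qwen/Qwen3-30B-A3B"),
        dtype="auto",
        tensor_parallel_size=2,  # CI machines only have 2 cards
        enable_expert_parallel=True,
        compilation_config={
            "pass_config": {"enable_sequence_parallelism": True}
        },
) as runner:
    runner.generate_greedy(["Hello, my name is"], max_tokens=8)
```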

Signed-off-by: libaokui <[email protected]>
Signed-off-by: libaokui <[email protected]>
Signed-off-by: libaokui <[email protected]>
Signed-off-by: libaokui <[email protected]>
wangxiyuan merged commit c611291 into vllm-project:main Aug 7, 2025
25 checks passed
zzhx1 pushed a commit to lidenghui1110/vllm-ascend that referenced this pull request Aug 11, 2025
### What this PR does / why we need it?
Qwen3 MoE supports SP. In scenarios like AlltoAll, AlltoAllv, and MC2,
replacing AllReduce with Reduce-Scatter and AllGather achieves
computational benefits in norm operations while saving one AllGather
communication. This feature is enabled during the P-phase and delivers
notable gains in long-sequence scenarios (e.g., 16k–25k), with
performance improvements reaching 5%–10%.
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
``` 
compilation_config={
    "pass_config":{
        "enable_sequence_parallelism": True
    }
},
enable_expert_parallel=True,
```

- vLLM version: v0.10.0
- vLLM main:
vllm-project/vllm@9edd1db

---------

Signed-off-by: libaokui <[email protected]>
Co-authored-by: libaokui <[email protected]>
chopper0126 pushed a commit to chopper0126/vllm-ascend that referenced this pull request Sep 26, 2025
Angazenn pushed a commit to Angazenn/vllm-ascend that referenced this pull request Oct 21, 2025