
[Feat] flashcomm_v2 optim solution#3232

Merged
weijinqian0 merged 15 commits into vllm-project:main from coder-fny:official-fc2
Nov 10, 2025

Conversation

@Levi-JQ
Contributor

@Levi-JQ Levi-JQ commented Sep 28, 2025

What this PR does / why we need it?

Supports the generalized FlashComm2 optimization. By replacing the AllReduce operations in the Attention module with a pre-AlltoAll and a post-AllGather, it reduces communication overhead, decreases RmsNorm computation, and, when combined with FlashComm1, saves one AllGather step. The feature is enabled during the Prefill phase and is recommended to be used together with FlashComm1; it delivers broad performance improvements, especially in long-sequence scenarios with large tensor-parallel (TP) sizes. Benchmarks show that under a TP16DP1 configuration it improves DeepSeek prefill performance by 8% on top of FlashComm1.
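The rewrite relies on AllReduce being decomposable into an exchange, a local reduce, and an AllGather. The following is a toy NumPy sketch of that collective-algebra identity only; it does not model the actual HCCL kernels, nor where the o_proj matmuls sit in the real implementation, and all names are illustrative:

```python
import numpy as np

def simulate_allreduce(partials):
    # Baseline: AllReduce leaves every rank with the elementwise sum.
    total = np.sum(partials, axis=0)
    return [total.copy() for _ in partials]

def simulate_alltoall_allgather(partials):
    # AlltoAll: split each rank's partial result along the token axis;
    # after the exchange, rank i holds every rank's partial for token chunk i.
    world = len(partials)
    chunks = [np.array_split(p, world, axis=0) for p in partials]
    # Local reduce: rank i sums the partials for its own chunk.
    reduced = [np.sum([chunks[r][i] for r in range(world)], axis=0)
               for i in range(world)]
    # AllGather: every rank reassembles the full, fully reduced tensor.
    full = np.concatenate(reduced, axis=0)
    return [full.copy() for _ in range(world)]
```

Both paths produce identical results on every rank; the win in the real kernel comes from moving work (e.g. RmsNorm) into the sharded region between the two collectives.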

Does this PR introduce any user-facing change?

How was this patch tested?

@github-actions
Contributor

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling in the PR description to help reviewers and future developers understand the change.

If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces the FlashComm2 optimization for tensor parallelism on Ascend NPUs, aiming to improve performance by optimizing communication patterns. The changes span configuration, parallel state management, and operator implementations. My review has identified a few issues: a critical bug in the parallel group initialization that can lead to a crash, a related potential resource leak in the group destruction logic, and incorrect formatting of error messages in the configuration validation. These issues should be addressed to ensure correctness and robustness.

_FLASHCOMM2_OTP = None
_FLASHCOMM2_ODP = get_tp_group()

if flashcomm2_otp_size > 1:
Contributor


critical

The process group creation for FlashComm2 is guarded by if flashcomm2_otp_size > 1:. This causes _FLASHCOMM2_OTP to be None when flashcomm2_oproj_tensor_parallel_size is 1. However, Flashcomm2OProjRowParallelOp is still used in this case, and it attempts to access methods on the _FLASHCOMM2_OTP group, which will lead to a crash. The logic within this if block appears to correctly handle the size == 1 case by creating groups of size 1. The conditional guard should be removed, and its content unindented, to fix this critical bug.

Contributor Author


In the Flashcomm2OProjRowParallelOp that uses _FLASHCOMM2_OTP, a check has been added to determine whether flashcomm2_oproj_tensor_parallel_size is 1, to avoid errors. Incidentally, setting the group to None when flashcomm2_oproj_tensor_parallel_size is 1 avoids creating a redundant communication group, reducing buffer consumption.
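The pattern the author describes can be sketched as follows. This is a hypothetical illustration, not the actual vllm-ascend code: the group stays None when the oproj TP size is 1, and the op checks before touching it, so no size-1 communicator (and its buffers) is ever created:

```python
# Hypothetical sketch of the None-guard pattern described above.
_FLASHCOMM2_OTP = None  # only initialized when flashcomm2_oproj_tensor_parallel_size > 1

def oproj_all_to_all(chunks, otp_size):
    # Size-1 group: the exchange would be a no-op, so skip communication entirely.
    if otp_size == 1 or _FLASHCOMM2_OTP is None:
        return chunks
    return _FLASHCOMM2_OTP.all_to_all(chunks)
```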

Comment on lines +107 to +113
raise AssertionError(
"flashcomm2_oproj_tensor_parallel_size ({self.flashcomm2_oproj_tensor_parallel_size}) cannot exceed global tensor parallel size ({global_tp_size})"
)
if global_tp_size % self.flashcomm2_oproj_tensor_parallel_size != 0:
raise AssertionError(
"Global tensor parallel size ({global_tp_size}) must be divisible by flashcomm2_oproj_tensor_parallel_size ({self.flashcomm2_oproj_tensor_parallel_size})"
)
Contributor


high

The error message strings are not f-strings, so the variables inside the curly braces will not be interpolated. This will result in confusing and unhelpful error messages for users.

Suggested change
raise AssertionError(
"flashcomm2_oproj_tensor_parallel_size ({self.flashcomm2_oproj_tensor_parallel_size}) cannot exceed global tensor parallel size ({global_tp_size})"
)
if global_tp_size % self.flashcomm2_oproj_tensor_parallel_size != 0:
raise AssertionError(
"Global tensor parallel size ({global_tp_size}) must be divisible by flashcomm2_oproj_tensor_parallel_size ({self.flashcomm2_oproj_tensor_parallel_size})"
)
raise AssertionError(
f"flashcomm2_oproj_tensor_parallel_size ({self.flashcomm2_oproj_tensor_parallel_size}) cannot exceed global tensor parallel size ({global_tp_size})"
)
if global_tp_size % self.flashcomm2_oproj_tensor_parallel_size != 0:
raise AssertionError(
f"Global tensor parallel size ({global_tp_size}) must be divisible by flashcomm2_oproj_tensor_parallel_size ({self.flashcomm2_oproj_tensor_parallel_size})"
)
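The pitfall is easy to reproduce in isolation: without the f prefix, the braces are emitted as literal text instead of being interpolated (variable names here are illustrative):

```python
size, limit = 16, 8
plain = "oproj tp size ({size}) cannot exceed ({limit})"   # braces stay literal
fixed = f"oproj tp size ({size}) cannot exceed ({limit})"  # values interpolated
```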

Contributor Author


done

_OTP = None

global _FLASHCOMM2_OTP
if _FLASHCOMM2_OTP and get_ascend_config().flashcomm2_oproj_tensor_parallel_size != 1:
Contributor


high

The condition get_ascend_config().flashcomm2_oproj_tensor_parallel_size != 1 will prevent the _FLASHCOMM2_OTP group from being destroyed when its size is 1. If the initialization logic is fixed to create a group for size 1 (as suggested in another comment), this will cause a resource leak. The group should be destroyed if it was created, regardless of its size.

Suggested change
if _FLASHCOMM2_OTP and get_ascend_config().flashcomm2_oproj_tensor_parallel_size != 1:
if _FLASHCOMM2_OTP:

Contributor Author


_FLASHCOMM2_OTP is None when get_ascend_config().flashcomm2_oproj_tensor_parallel_size == 1

@Levi-JQ Levi-JQ force-pushed the official-fc2 branch 5 times, most recently from 5b6c013 to 1b8cdb3 Compare September 30, 2025 02:38
@github-actions
Contributor

This pull request has conflicts, please resolve those before we can evaluate the pull request.

@Levi-JQ Levi-JQ force-pushed the official-fc2 branch 2 times, most recently from 693c0ec to 7541e94 Compare October 9, 2025 13:02
@Levi-JQ Levi-JQ changed the title from "[main] flashcomm_v2 optim solution" to "[Feat] flashcomm_v2 optim solution" Oct 16, 2025
@Levi-JQ Levi-JQ force-pushed the official-fc2 branch 3 times, most recently from 0f881a4 to d990a18 Compare October 16, 2025 10:17
@zzhx1 zzhx1 force-pushed the official-fc2 branch 2 times, most recently from 834dd41 to 84eacfa Compare October 16, 2025 12:12
@github-actions
Contributor

This pull request has conflicts, please resolve those before we can evaluate the pull request.

Levi-JQ and others added 5 commits November 5, 2025 10:41
2.Rename the environment variable VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE
3.Normalize the enabling logic for sp/fc2
4.add TODO: Normalize the communication domain
Signed-off-by: Levi-JQ <yujinqi2@huawei.com>
Signed-off-by: Levi-JQ <yujinqi2@huawei.com>
Signed-off-by: Levi-JQ <yujinqi2@huawei.com>
Signed-off-by: Levi-JQ <yujinqi2@huawei.com>
Signed-off-by: zzhxx <2783294813@qq.com>
@Levi-JQ Levi-JQ force-pushed the official-fc2 branch 2 times, most recently from 2ece362 to 9229303 Compare November 5, 2025 03:03
@yiz-liu yiz-liu added ready read for review ready-for-test start test by label for PR labels Nov 5, 2025
Signed-off-by: Levi-JQ <yujinqi2@huawei.com>
@zzhx1
Contributor

zzhx1 commented Nov 6, 2025

@wangxiyuan Please check whether this PR can be merged.

@Levi-JQ Levi-JQ requested a review from zzzzwwjj November 7, 2025 06:14
@zzzzwwjj
Collaborator

zzzzwwjj commented Nov 8, 2025

I suggest raising an issue to organize the roles of, and relationships between, the flashcomm, flashcomm2, and enable_sp features, as well as the reuse of the lm_head_tp communication domain.

@zzhx1
Contributor

zzhx1 commented Nov 10, 2025

I suggest raising an issue to organize the roles of, and relationships between, the flashcomm, flashcomm2, and enable_sp features, as well as the reuse of the lm_head_tp communication domain.

We will raise an issue within this week to clarify the connections between these features. The lm_head_tp communication domain cannot be reused, since the use cases differ: one is prefill and the other is decode.

@weijinqian0 weijinqian0 merged commit 0a62e67 into vllm-project:main Nov 10, 2025
24 checks passed
luolun pushed a commit to luolun/vllm-ascend that referenced this pull request Nov 19, 2025
### What this PR does / why we need it?
Supports generalized FlashComm2 optimization, which reduces
communication overhead, decreases RmsNorm computation, and saves one
AllGather step by replacing Allreduce operations in the Attention module
with pre-AlltoAll and post-AllGather operations (used in combination
with FlashComm1). This feature is enabled during the Prefill phase and
is recommended to be used together with FlashComm1, delivering broad
performance improvements, especially in long sequence scenarios with
large tensor parallelism (TP) configurations. Benchmark tests show that
under TP16DP1 configuration, it can improve the prefill performance of
the DeepSeek model by 8% on top of FlashComm1.
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?


- vLLM version: v0.11.0
- vLLM main:
vllm-project/vllm@83f478b

---------

Signed-off-by: zzhxx <2783294813@qq.com>
Signed-off-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: zzhxx <2783294813@qq.com>
Signed-off-by: luolun <luolun1995@cmbchina.com>
hwhaokun pushed a commit to hwhaokun/vllm-ascend that referenced this pull request Nov 19, 2025
Levi-JQ pushed a commit to coder-fny/vllm-ascend that referenced this pull request Nov 20, 2025
Merge branch pick-fc2-1111 of git@code.alipay.com:Theta/vllm-ascend.git into dev-v0.11.0.1111
https://code.alipay.com/Theta/vllm-ascend/pull_requests/594?tab=diff

Reviewed-by: 沧濯 <zhengshoujian.zsj@antgroup.com>


* [Feat] flashcomm_v2 optim solution (vllm-project#3232)
NSDie pushed a commit to NSDie/vllm-ascend that referenced this pull request Nov 24, 2025
Clorist33 pushed a commit to Clorist33/vllm-ascend that referenced this pull request Dec 10, 2025

6 participants