[Feat] flashcomm_v2 optim solution#3232
Conversation
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request introduces the FlashComm2 optimization for tensor parallelism on Ascend NPUs, aiming to improve performance by optimizing communication patterns. The changes span configuration, parallel state management, and operator implementations. My review has identified a few issues: a critical bug in the parallel group initialization that can lead to a crash, a related potential resource leak in the group destruction logic, and incorrect formatting of error messages in the configuration validation. These issues should be addressed to ensure correctness and robustness.
```python
_FLASHCOMM2_OTP = None
_FLASHCOMM2_ODP = get_tp_group()
# ...
if flashcomm2_otp_size > 1:
```
The process group creation for FlashComm2 is guarded by `if flashcomm2_otp_size > 1:`. This causes `_FLASHCOMM2_OTP` to be `None` when `flashcomm2_oproj_tensor_parallel_size` is 1. However, `Flashcomm2OProjRowParallelOp` is still used in this case, and it attempts to access methods on the `_FLASHCOMM2_OTP` group, which will lead to a crash. The logic within this `if` block appears to correctly handle the size == 1 case by creating groups of size 1. The conditional guard should be removed, and its content unindented, to fix this critical bug.
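A minimal sketch of why removing the guard also covers the trivial case (`build_otp_rank_groups` is a hypothetical helper for illustration; the real code builds device groups through vLLM's parallel-state utilities):

```python
# Hypothetical sketch of the suggested fix: build the FlashComm2 OTP rank
# groups unconditionally, so flashcomm2_oproj_tensor_parallel_size == 1
# still yields valid singleton groups instead of leaving _FLASHCOMM2_OTP
# as None.
def build_otp_rank_groups(world_size: int, otp_size: int) -> list[list[int]]:
    """Partition ranks [0, world_size) into contiguous groups of otp_size."""
    if world_size % otp_size != 0:
        raise ValueError("world_size must be divisible by otp_size")
    return [
        list(range(start, start + otp_size))
        for start in range(0, world_size, otp_size)
    ]
```

With `otp_size == 1` this produces one singleton group per rank, so the op can always call group methods safely.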
In `Flashcomm2OProjRowParallelOp`, which uses `_FLASHCOMM2_OTP`, a check has been added to determine whether `flashcomm2_oproj_tensor_parallel_size` is 1, so no error occurs. Incidentally, setting the group to `None` avoids creating a redundant communication group when `flashcomm2_oproj_tensor_parallel_size` is 1, reducing buffer consumption.
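The alternative described here can be illustrated with a hypothetical stand-in for the op's communication step (names and the `all_to_all` method are assumptions for illustration, not the actual vLLM Ascend API):

```python
# Hypothetical illustration: the o_proj op checks the configured size and
# bypasses collective communication entirely when
# flashcomm2_oproj_tensor_parallel_size == 1, so _FLASHCOMM2_OTP may
# legitimately remain None in that mode.
def oproj_maybe_all_to_all(x, otp_group, otp_size: int):
    if otp_size == 1:
        # No group was created for the trivial case; return the input as-is.
        assert otp_group is None
        return x
    # Otherwise the group must exist and is used for the collective.
    return otp_group.all_to_all(x)
```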
vllm_ascend/ascend_config.py
Outdated
```python
raise AssertionError(
    "flashcomm2_oproj_tensor_parallel_size ({self.flashcomm2_oproj_tensor_parallel_size}) cannot exceed global tensor parallel size ({global_tp_size})"
)
if global_tp_size % self.flashcomm2_oproj_tensor_parallel_size != 0:
    raise AssertionError(
        "Global tensor parallel size ({global_tp_size}) must be divisible by flashcomm2_oproj_tensor_parallel_size ({self.flashcomm2_oproj_tensor_parallel_size})"
    )
```
The error message strings are not f-strings, so the variables inside the curly braces will not be interpolated. This will result in confusing and unhelpful error messages for users.
```diff
-raise AssertionError(
-    "flashcomm2_oproj_tensor_parallel_size ({self.flashcomm2_oproj_tensor_parallel_size}) cannot exceed global tensor parallel size ({global_tp_size})"
-)
-if global_tp_size % self.flashcomm2_oproj_tensor_parallel_size != 0:
-    raise AssertionError(
-        "Global tensor parallel size ({global_tp_size}) must be divisible by flashcomm2_oproj_tensor_parallel_size ({self.flashcomm2_oproj_tensor_parallel_size})"
-    )
+raise AssertionError(
+    f"flashcomm2_oproj_tensor_parallel_size ({self.flashcomm2_oproj_tensor_parallel_size}) cannot exceed global tensor parallel size ({global_tp_size})"
+)
+if global_tp_size % self.flashcomm2_oproj_tensor_parallel_size != 0:
+    raise AssertionError(
+        f"Global tensor parallel size ({global_tp_size}) must be divisible by flashcomm2_oproj_tensor_parallel_size ({self.flashcomm2_oproj_tensor_parallel_size})"
+    )
```
```python
_OTP = None
# ...
global _FLASHCOMM2_OTP
if _FLASHCOMM2_OTP and get_ascend_config().flashcomm2_oproj_tensor_parallel_size != 1:
```
The condition get_ascend_config().flashcomm2_oproj_tensor_parallel_size != 1 will prevent the _FLASHCOMM2_OTP group from being destroyed when its size is 1. If the initialization logic is fixed to create a group for size 1 (as suggested in another comment), this will cause a resource leak. The group should be destroyed if it was created, regardless of its size.
```diff
-if _FLASHCOMM2_OTP and get_ascend_config().flashcomm2_oproj_tensor_parallel_size != 1:
+if _FLASHCOMM2_OTP:
```
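The intended teardown symmetry can be sketched as follows (a hypothetical stand-in for illustration; the real code goes through vLLM's group-coordinator destruction path):

```python
# Hypothetical sketch: destroy the FlashComm2 OTP group whenever one was
# created, independent of its configured size, then drop the reference so
# a later re-initialization starts from a clean state.
def destroy_flashcomm2_otp(state: dict) -> None:
    group = state.get("otp_group")
    if group is not None:
        group.destroyed = True      # stand-in for the real destroy call
        state["otp_group"] = None   # release the reference
```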
`_FLASHCOMM2_OTP` is `None` when `get_ascend_config().flashcomm2_oproj_tensor_parallel_size == 1`.
This pull request has conflicts, please resolve those before we can evaluate the pull request.
2. Rename the environment variable VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE
3. Normalize the enabling logic for sp/fc2
4. Add TODO: normalize the communication domain

Signed-off-by: Levi-JQ <yujinqi2@huawei.com>
@wangxiyuan Please check whether this PR can be merged.
I suggest raising an issue to organize the roles and relationships among the flashcomm, flashcomm2, and enable_sp features, and the comm-domain reuse of lm_head_tp.
Within this week, we will raise an issue to clarify the connections between these features.
### What this PR does / why we need it?
Supports generalized FlashComm2 optimization, which reduces communication overhead, decreases RmsNorm computation, and saves one AllGather step by replacing Allreduce operations in the Attention module with pre-AlltoAll and post-AllGather operations (used in combination with FlashComm1). This feature is enabled during the Prefill phase and is recommended to be used together with FlashComm1, delivering broad performance improvements, especially in long sequence scenarios with large tensor parallelism (TP) configurations. Benchmark tests show that under TP16DP1 configuration, it can improve the prefill performance of the DeepSeek model by 8% on top of FlashComm1.
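As a generic illustration of the kind of collective decomposition involved (a NumPy simulation under simplified assumptions, not the Ascend kernels): summing partial attention outputs with an AllReduce is mathematically equivalent to an AlltoAll that redistributes sequence slices of each rank's partial result, a local reduction, and a later AllGather that restores the full sequence.

```python
import numpy as np

def all_reduce(shards):
    # Baseline: every rank ends up with the full summed tensor.
    total = np.sum(shards, axis=0)
    return [total.copy() for _ in shards]

def alltoall_then_allgather(shards):
    # Simplified FlashComm2-style decomposition: AlltoAll swaps sequence
    # slices of the partial results so rank r receives slice r from every
    # peer, each rank reduces its slice locally, and a deferred AllGather
    # restores the full sequence on all ranks.
    tp = len(shards)
    sliced = [np.array_split(s, tp) for s in shards]  # per-rank slices
    reduced = [
        np.sum([sliced[src][r] for src in range(tp)], axis=0)
        for r in range(tp)
    ]
    gathered = np.concatenate(reduced)                # AllGather
    return [gathered.copy() for _ in shards]

rng = np.random.default_rng(0)
partials = [rng.standard_normal(8) for _ in range(4)]  # 4 TP ranks
```

In the decomposed form, the work between the AlltoAll and the AllGather operates on sequence-sharded activations, which is what allows the RmsNorm computation to shrink and one AllGather to be saved when combined with FlashComm1.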
### Does this PR introduce any user-facing change?

### How was this patch tested?

- vLLM version: v0.11.0
- vLLM main: vllm-project/vllm@83f478b