[Feat] Flashcomm2 use o_shared linear #4188
Conversation
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request introduces support for shared o_proj linear layers for Flashcomm2, which involves changes across configuration, distributed state management, and the attention mechanism. The core logic for shared weights is implemented in vllm_ascend/torchair/ops/shared_weight_layer.py, which has been refactored for better usability.
My review focuses on ensuring the correctness and robustness of the new feature. I've identified a few critical issues:
- Incorrect validation logic for the new `flashcomm2_oproj_shared` configuration that could lead to silent failures.
- A potential crash in the shared weight layer logic when handling a series with a single layer.
I have provided suggestions to fix these issues. The rest of the changes look good and the refactoring of the shared weight layer API is a nice improvement.
```python
if self.flashcomm2_oproj_tensor_parallel_size is None:
    raise AssertionError(
        "flashcomm2_oproj_shared must be enabled simultaneously with flashcomm2_oproj_tensor_parallel_size"
    )
```
The validation if self.flashcomm2_oproj_tensor_parallel_size is None: is incorrect. The value of self.flashcomm2_oproj_tensor_parallel_size is an integer returned from get_flashcomm2_config_and_validate (which gets it from an environment variable with a default of 0), so it will never be None. The check should be against 0, as flashcomm2_oproj_shared requires flashcomm2_oproj_tensor_parallel_size to be greater than 0.
```diff
-if self.flashcomm2_oproj_tensor_parallel_size is None:
-    raise AssertionError(
-        "flashcomm2_oproj_shared must be enabled simultaneously with flashcomm2_oproj_tensor_parallel_size"
-    )
+if self.flashcomm2_oproj_tensor_parallel_size == 0:
+    raise AssertionError(
+        "flashcomm2_oproj_shared must be enabled with flashcomm2_oproj_tensor_parallel_size > 0"
+    )
```
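A minimal sketch of why the `is None` branch can never fire: the value comes from the `VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE` environment variable parsed as an integer with a default of 0 (the helper name below is hypothetical; only the parsing expression mirrors the registered lambda).

```python
import os

# Hypothetical helper mirroring how the config value is derived: an env var
# parsed as int with a default of 0, so the result is always an int, never None.
def get_oproj_tp_size() -> int:
    return int(os.getenv("VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE", 0))

size = get_oproj_tp_size()
assert size is not None       # the `is None` branch can never be taken
assert isinstance(size, int)  # so the guard must compare against 0 instead
```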
```python
self.layers.sort(key=lambda x: x.layer_idx)
self.num_layers = len(self.layers)
assert self.num_layers > 0, "No layers in the series"
assert self.prefetch_step >= 0 and self.prefetch_step <= self.num_layers - 2, "prefetch_step must be in [0, num_layers - 2]"
```
The assertion self.prefetch_step <= self.num_layers - 2 will cause a crash if a shared weight series contains only one layer (self.num_layers == 1), because self.num_layers - 2 would be -1. For a single-layer series, prefetching is not applicable, and prefetch_step should be 0. To prevent this crash, the assertion should be adjusted to handle this edge case.
```diff
-assert self.prefetch_step >= 0 and self.prefetch_step <= self.num_layers - 2, "prefetch_step must be in [0, num_layers - 2]"
+assert self.prefetch_step >= 0 and self.prefetch_step <= max(0, self.num_layers - 2), "prefetch_step must be in [0, num_layers - 2]"
```
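A quick check of the suggested clamp: for a single-layer series the upper bound becomes `max(0, 1 - 2) == 0`, so `prefetch_step == 0` passes instead of being compared against an upper bound of -1 (the standalone predicate below is for illustration only).

```python
# Extracted predicate for the suggested assertion bound.
def prefetch_step_ok(prefetch_step: int, num_layers: int) -> bool:
    return 0 <= prefetch_step <= max(0, num_layers - 2)

print(prefetch_step_ok(0, 1))  # True: single-layer series is valid with the clamp
print(0 <= 0 <= 1 - 2)         # False: the original bound rejects it (0 <= -1)
print(prefetch_step_ok(1, 3))  # True: three layers allow prefetch_step up to 1
```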
```python
if flashcomm2_oproj_shared:
    if flashcomm2_oproj_tp_size is None:
        raise AssertionError(
            "flashcomm2_oproj_shared must be enabled simultaneously with flashcomm2_oproj_tensor_parallel_size"
        )
    logger.info("Enable Flashcomm2 with flashcomm2_oproj_shared")
```
This validation logic for flashcomm2_oproj_shared is redundant with the logic in vllm_ascend/ascend_config.py. It's better to have validation in one place to avoid inconsistencies. Since ascend_config.py is the configuration entry point, it's a better place for this check. Additionally, the check if flashcomm2_oproj_tp_size is None: is incorrect, as flashcomm2_oproj_tp_size is an integer. I've suggested a fix in ascend_config.py and recommend removing this redundant block.
This pull request has conflicts, please resolve those before we can evaluate the pull request.
@wangxiyuan this PR is ready, please help merge it in.
wangxiyuan left a comment
please update the doc as well https://docs.vllm.ai/projects/ascend/en/latest/tutorials/DeepSeek-V3.1.html
```python
# between this feature and FLASHCOMM1, please refer to the feature guide in the documentation.
"VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE":
lambda: int(os.getenv("VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE", 0)),
"VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED":
```
Add a note describing how to use this env variable.
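A hedged sketch of the usage note being requested: how a user would enable the feature and how the flags are parsed. The boolean interpretation of `VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED` below is an assumption for illustration; only the int parsing of the parallel-size variable mirrors the registered lambda above.

```python
import os

# Simulate a user enabling the feature (values chosen for illustration).
os.environ["VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED"] = "1"
os.environ["VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE"] = "2"

# Assumed parsing: "1" -> enabled; parallel size parsed as int, default 0.
oproj_shared = bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED", "0")))
oproj_tp_size = int(os.getenv("VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE", 0))

# Per the review above, enabling o_shared requires a positive oproj TP size.
assert not oproj_shared or oproj_tp_size > 0
print(oproj_shared, oproj_tp_size)  # True 2
```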
…domain in sfa-cp, and fix the mtp weight load in pp>1 situation (#4913)

### What this PR does / why we need it?
In PR #4188, a small bug was introduced that caused sfa-cp to be unable to find the global_pp_size parameter during initialization, and this PR fixed the issue.
- vLLM version: v0.12.0
- vLLM main: vllm-project/vllm@ad32e3e

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: Jade Zheng <zheng.shoujian@outlook.com>
```python
group_ranks = []
for pp_idx in range(global_pp_size):
    group = []
    for dp_idx in range(global_dp_size):
```
How can this be adapted to PCP?
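A hypothetical completion of the truncated loop above, showing how group ranks might be enumerated across the pipeline-parallel (pp) and data-parallel (dp) dimensions. The rank layout formula is an assumption for illustration, not the PR's exact code.

```python
# Assumed layout: ranks ordered as (pp, dp, tp) with tp fastest-varying.
def build_group_ranks(global_pp_size: int, global_dp_size: int, tp_size: int):
    group_ranks = []
    for pp_idx in range(global_pp_size):
        group = []
        for dp_idx in range(global_dp_size):
            base = (pp_idx * global_dp_size + dp_idx) * tp_size
            group.extend(range(base, base + tp_size))
        group_ranks.append(group)
    return group_ranks

print(build_group_ranks(2, 2, 2))
# [[0, 1, 2, 3], [4, 5, 6, 7]] under the assumed layout
```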
What this PR does / why we need it?
The Flashcomm2 technical report mentions that FC2 introduces fully redundant storage of the o_proj matrix, which puts pressure on memory. The report proposed a compromise solution using otp2, but that introduces additional reduce-scatter communication.
We propose a shared linear feature (#2931) that distributes weights layer by layer across cards, avoiding the need for TP splitting and solving the memory issue.
This PR depends on #3232 and #2931.
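The layer-by-layer distribution described above can be sketched as follows: instead of TP-splitting o_proj, each rank stores the full weight for only a subset of layers and serves it to peers when needed. The round-robin assignment and the function name below are assumptions for illustration, not the PR's actual implementation.

```python
# Hypothetical owner map: each layer's full o_proj weight lives on one rank.
def assign_layers_to_ranks(num_layers: int, world_size: int) -> dict:
    """Map each layer index to the rank holding its full o_proj weight (round-robin)."""
    return {layer: layer % world_size for layer in range(num_layers)}

owners = assign_layers_to_ranks(num_layers=8, world_size=4)
print(owners)  # {0: 0, 1: 1, 2: 2, 3: 3, 4: 0, 5: 1, 6: 2, 7: 3}
# Per-rank storage drops from num_layers full copies to ~num_layers/world_size.
```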
Flashcomm2 flowchart
Does this PR introduce any user-facing change?
Yes: the feature is controlled via environment variables.