[Feat] Flashcomm2 use o_shared linear #4188

Merged
ApsarasX merged 15 commits into vllm-project:main from zzhx1:flashcomm_oshared on Dec 11, 2025
Conversation

@zzhx1 (Contributor) commented Nov 13, 2025:

What this PR does / why we need it?

The Flashcomm2 technical report notes that FC2 introduces fully redundant storage of the o_proj matrix, which puts pressure on device memory. The report therefore proposes a compromise using otp2, but that approach introduces additional reduce-scatter communication.

We propose a shared-linear feature (#2931) that distributes weights layer by layer across the cards, avoiding the need for TP splitting and resolving the memory issue.

This PR depends on #3232 and #2931
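To make the memory trade-off concrete, here is a minimal, hypothetical sketch (not the actual implementation) contrasting per-card o_proj storage under full replication with a round-robin, layer-by-layer distribution of full (unsplit) weights. The function names and the layer/size numbers are illustrative assumptions.

```python
# Hypothetical illustration of the memory argument, not vLLM Ascend code.
def replicated_bytes(num_layers: int, weight_bytes: int, world_size: int) -> int:
    # FC2 baseline: every card holds every layer's full o_proj.
    return num_layers * weight_bytes  # per card

def shared_bytes(num_layers: int, weight_bytes: int, world_size: int) -> int:
    # Shared linear: each layer's full weight lives on exactly one owner
    # card (round-robin), so a card stores ~num_layers / world_size weights.
    layers_owned = (num_layers + world_size - 1) // world_size
    return layers_owned * weight_bytes  # per card

# Illustrative numbers: 61 layers, 16 cards, 100 MB o_proj per layer.
print(replicated_bytes(61, 100, 16))  # 6100 MB per card
print(shared_bytes(61, 100, 16))      # 400 MB per card
```

The remaining cost is moving each owner's weight to the other cards just in time, which is what the prefetching logic discussed below handles.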

Flashcomm2 flowchart (image: PixPin_2025-11-14_13-37-39)

Does this PR introduce any user-facing change?

Use environment variables:

```
export VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE=1
export VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED=1
```

@github-actions (Bot) commented:
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling out the PR description to help reviewers and future developers understand the change.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

@gemini-code-assist (Bot) left a comment:

Code Review

This pull request introduces support for shared o_proj linear layers for Flashcomm2, which involves changes across configuration, distributed state management, and the attention mechanism. The core logic for shared weights is implemented in vllm_ascend/torchair/ops/shared_weight_layer.py, which has been refactored for better usability.

My review focuses on ensuring the correctness and robustness of the new feature. I've identified a few critical issues:

  • Incorrect validation logic for the new flashcomm2_oproj_shared configuration that could lead to silent failures.
  • A potential crash in the shared weight layer logic when handling a series with a single layer.

I have provided suggestions to fix these issues. The rest of the changes look good and the refactoring of the shared weight layer API is a nice improvement.

Comment thread: vllm_ascend/ascend_config.py (Outdated)
Comment on lines +137 to +140

    if self.flashcomm2_oproj_tensor_parallel_size is None:
        raise AssertionError(
            "flashcomm2_oproj_shared must be enabled simultaneously with flashcomm2_oproj_tensor_parallel_size"
        )
critical

The validation if self.flashcomm2_oproj_tensor_parallel_size is None: is incorrect. The value of self.flashcomm2_oproj_tensor_parallel_size is an integer returned from get_flashcomm2_config_and_validate (which gets it from an environment variable with a default of 0), so it will never be None. The check should be against 0, as flashcomm2_oproj_shared requires flashcomm2_oproj_tensor_parallel_size to be greater than 0.

Suggested change

    - if self.flashcomm2_oproj_tensor_parallel_size is None:
    -     raise AssertionError(
    -         "flashcomm2_oproj_shared must be enabled simultaneously with flashcomm2_oproj_tensor_parallel_size"
    -     )
    + if self.flashcomm2_oproj_tensor_parallel_size == 0:
    +     raise AssertionError(
    +         "flashcomm2_oproj_shared must be enabled with flashcomm2_oproj_tensor_parallel_size > 0"
    +     )
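The point of the review comment can be sketched in isolation: a value parsed with `int(os.getenv(..., 0))` is always an `int`, never `None`, so the `is None` branch is dead. The helper names below (`get_oproj_tp_size`, `validate`) are hypothetical stand-ins for the config code, not the real functions.

```python
import os

def get_oproj_tp_size() -> int:
    # Mirrors the envs.py pattern quoted later in this thread:
    # lambda: int(os.getenv("VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE", 0))
    return int(os.getenv("VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE", 0))

def validate(oproj_shared: bool, oproj_tp_size: int) -> None:
    # Corrected check: compare against 0, not None.
    if oproj_shared and oproj_tp_size == 0:
        raise AssertionError(
            "flashcomm2_oproj_shared must be enabled with "
            "flashcomm2_oproj_tensor_parallel_size > 0")

os.environ.pop("VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE", None)
size = get_oproj_tp_size()
print(size is None)  # False: the `is None` check can never fire
print(size == 0)     # True: an unset variable parses to 0
```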

self.layers.sort(key=lambda x: x.layer_idx)
self.num_layers = len(self.layers)
assert self.num_layers > 0, "No layers in the series"
assert self.prefetch_step >= 0 and self.prefetch_step <= self.num_layers - 2, "prefetch_step must be in [0, num_layers - 2]"
critical

The assertion self.prefetch_step <= self.num_layers - 2 will cause a crash if a shared weight series contains only one layer (self.num_layers == 1), because self.num_layers - 2 would be -1. For a single-layer series, prefetching is not applicable, and prefetch_step should be 0. To prevent this crash, the assertion should be adjusted to handle this edge case.

Suggested change

    - assert self.prefetch_step >= 0 and self.prefetch_step <= self.num_layers - 2, "prefetch_step must be in [0, num_layers - 2]"
    + assert self.prefetch_step >= 0 and self.prefetch_step <= max(0, self.num_layers - 2), "prefetch_step must be in [0, num_layers - 2]"
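The single-layer edge case can be checked standalone. This is a hypothetical reduction of the assertion, assuming a `check_prefetch_step` helper that does not exist in the PR:

```python
# With num_layers == 1, `num_layers - 2` is -1, so the original bound
# `prefetch_step <= -1` rejects even prefetch_step == 0. Clamping the
# upper bound with max(0, ...) keeps single-layer series valid.
def check_prefetch_step(prefetch_step: int, num_layers: int) -> bool:
    assert num_layers > 0, "No layers in the series"
    return 0 <= prefetch_step <= max(0, num_layers - 2)

print(check_prefetch_step(0, 1))  # True: single-layer series now passes
print(check_prefetch_step(0, 3))  # True
print(check_prefetch_step(2, 3))  # False: at most num_layers - 2 == 1
```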

Comment thread: vllm_ascend/utils.py (Outdated)
Comment on lines +868 to +873

    if flashcomm2_oproj_shared:
        if flashcomm2_oproj_tp_size is None:
            raise AssertionError(
                "flashcomm2_oproj_shared must be enabled simultaneously with flashcomm2_oproj_tensor_parallel_size"
            )
        logger.info("Enable Flashcomm2 with flashcomm2_oproj_shared")
high

This validation logic for flashcomm2_oproj_shared is redundant with the logic in vllm_ascend/ascend_config.py. It's better to have validation in one place to avoid inconsistencies. Since ascend_config.py is the configuration entry point, it's a better place for this check. Additionally, the check if flashcomm2_oproj_tp_size is None: is incorrect, as flashcomm2_oproj_tp_size is an integer. I've suggested a fix in ascend_config.py and recommend removing this redundant block.

@zzhx1 zzhx1 force-pushed the flashcomm_oshared branch from a951ad1 to da0d630 Compare November 14, 2025 05:05
@zzhx1 zzhx1 changed the title Flashcomm2 use o_shared linear [Feat] Flashcomm2 use o_shared linear Nov 14, 2025
@zzhx1 zzhx1 force-pushed the flashcomm_oshared branch 3 times, most recently from 47501e5 to 8bf8ed5 Compare November 14, 2025 09:14
@zzhx1 zzhx1 force-pushed the flashcomm_oshared branch 2 times, most recently from faaa68e to bfbda42 Compare November 17, 2025 05:08
@zzhx1 zzhx1 force-pushed the flashcomm_oshared branch from 60f1aac to 89c3923 Compare November 24, 2025 07:12
Comment thread vllm_ascend/distributed/parallel_state.py Outdated
@zzhx1 zzhx1 force-pushed the flashcomm_oshared branch 3 times, most recently from 18803f9 to a9fae57 Compare December 1, 2025 08:28
@github-actions (Bot) commented Dec 1, 2025:

This pull request has conflicts, please resolve those before we can evaluate the pull request.

Comment thread vllm_ascend/distributed/parallel_state.py Outdated
Comment thread vllm_ascend/attention/mla_v1.py
@zzhx1 zzhx1 force-pushed the flashcomm_oshared branch 4 times, most recently from fd7c9fa to ba1a760 Compare December 4, 2025 07:40
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
@zzhx1 zzhx1 force-pushed the flashcomm_oshared branch 2 times, most recently from 2427a23 to 0835a03 Compare December 6, 2025 17:39
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
@zzhx1 zzhx1 force-pushed the flashcomm_oshared branch from 0835a03 to 3a1bf01 Compare December 6, 2025 18:04
@zzhx1 (Contributor, Author) commented Dec 7, 2025:

@wangxiyuan this PR is ready, please help merge it in.

@wangxiyuan (Collaborator) left a comment:

Comment thread: vllm_ascend/envs.py

    # between this feature and FLASHCOMM1, please refer to the feature guide in the documentation.
    "VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE":
        lambda: int(os.getenv("VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE", 0)),
    "VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED":
wangxiyuan (Collaborator): Add a note describing how to use this env var.

@ApsarasX ApsarasX merged commit eac72f5 into vllm-project:main Dec 11, 2025
16 of 18 checks passed
jianzs added a commit that referenced this pull request Dec 15, 2025
…domain in sfa-cp, and fix the mtp weight load in pp>1 situation (#4913)

### What this PR does / why we need it?
In PR #4188, a small bug was introduced that caused sfa-cp to be unable
to find the global_pp_size parameter during initialization; this PR
fixes the issue.

- vLLM version: v0.12.0
- vLLM main:
vllm-project/vllm@ad32e3e

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: Jade Zheng <zheng.shoujian@outlook.com>
    group_ranks = []
    for pp_idx in range(global_pp_size):
        group = []
        for dp_idx in range(global_dp_size):
How can this be adapted to PCP?
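The quoted fragment above is truncated, but the usual pattern behind such nested loops is one communication group per pipeline stage, collecting that stage's data-parallel ranks. This is a hedged, self-contained sketch under an assumed pipeline-major rank layout, not the actual vLLM Ascend code; `build_group_ranks` is a hypothetical name.

```python
# Hypothetical sketch of grouping a pp x dp rank grid, assuming
# pipeline-major ordering: rank = pp_idx * global_dp_size + dp_idx.
def build_group_ranks(global_pp_size: int, global_dp_size: int) -> list:
    group_ranks = []
    for pp_idx in range(global_pp_size):
        group = []
        for dp_idx in range(global_dp_size):
            group.append(pp_idx * global_dp_size + dp_idx)
        group_ranks.append(group)
    return group_ranks

print(build_group_ranks(2, 3))  # [[0, 1, 2], [3, 4, 5]]
```

Each inner list would then be passed to a collective-group constructor; extending this to PCP would presumably add another loop level over the context-parallel dimension.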

chenaoxuan pushed a commit to chenaoxuan/vllm-ascend that referenced this pull request Dec 20, 2025 (same fix as vllm-project#4913 above)
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Feb 28, 2026 (same fix)
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Mar 4, 2026 (same fix)
9 participants