# [refactor] Refactor the interface for shard weight and remove the flashcomm2 o_shared interface. #5181
**Merged**
## Commits (20)
- `70c6ccb` Change the interface of flashcomm2 o shard (zzhx1)
- `9028e2b` rename the shared as shard (zzhx1)
- `851e87d` fix bug (zzhx1)
- `2003960` fix is_hidden_layer bug (zzhx1)
- `b5b95cf` Fix indentation issue (zzhx1)
- `08b3ed1` refactor shard_weight comm_group and support all TP (zzhx1)
- `45e147f` refactor dsa_cp interface (zzhx1)
- `b116e05` refactor create shard weight in parallel_state (zzhx1)
- `fde18dc` fix parallel_state (zzhx1)
- `58783e1` refactor sfa_v1 (zzhx1)
- `e4b4b50` refactor mla-cp (zzhx1)
- `4e31818` fix lint (zzhx1)
- `c80c199` fix UT bug (zzhx1)
- `c5d06a9` Add the relevant logger (zzhx1)
- `5977d07` add doc for additional-config (zzhx1)
- `693af10` Change get_num_hidden_layer (zzhx1)
- `ad82999` add doc for layer_sharding user_guide (zzhx1)
- `5ebab5a` docs fix (Kurumi5210)
- `cb18d41` fix e2e (zzhx1)
- `9bd6658` according to the revised comments, add some logger and error (zzhx1)
The new documentation file added by this PR (the Layer Sharding user guide) is reproduced below.

---
title: Layer Sharding Guide
---

# Overview
**Layer Shard Linear** is a memory-optimization feature designed for large language model (LLM) inference. It addresses the high memory pressure caused by **repeated linear operators across many layers** that share identical structure but have distinct weights.
Instead of replicating all weights on every device, **Layer Shard Linear shards the weights of a "series" of such operators across the NPU devices in a communication group**:

- The **i-th layer's linear weight** is stored **only on device `i % K`**, where `K` is the number of devices in the group.
- Other devices hold a lightweight **shared dummy tensor** during initialization and fetch the real weight **on-demand via asynchronous broadcast** during the forward pass (see the sketch after this list).
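As a rough illustration of the ownership rule, a minimal sketch of weight creation on each rank follows. This is not the actual vllm-ascend implementation; the function and buffer names are hypothetical, and a `torch.distributed` process group is assumed to be initialized.

```python
import torch
import torch.distributed as dist

# One scratch buffer per weight shape, reused by every layer a rank does
# not own, so placeholder memory does not grow with model depth.
_shared_dummy: dict = {}

def create_layer_weight(layer_idx: int, shape: tuple, group=None) -> torch.Tensor:
    """Return the real weight on the owning rank, a shared placeholder elsewhere."""
    rank = dist.get_rank(group)
    k = dist.get_world_size(group)
    owner = layer_idx % k  # sharding rule: layer i's weight lives on device i % K
    if rank == owner:
        # Owning device allocates and later loads the real checkpoint weight.
        return torch.empty(shape, dtype=torch.bfloat16)
    if shape not in _shared_dummy:
        _shared_dummy[shape] = torch.empty(shape, dtype=torch.bfloat16)
    return _shared_dummy[shape]  # lightweight stand-in until broadcast time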
As illustrated in the figure below, this design prefetches weights via broadcast: while the current layer (e.g., MLA or MoE) is being computed, the system **asynchronously broadcasts the next layer's weight** in the background. Because the attention computation in the MLA module is sufficiently latency-bound, the weight transfer for `o_proj` is **fully overlapped with computation**, making the communication **latency-free from the perspective of end-to-end inference**.
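A minimal sketch of this overlap pattern using `torch.distributed` async collectives is shown below. It assumes the communication group spans all ranks (so `src` is a global rank) and that each layer exposes a `.weight` buffer; a real implementation would also double-buffer the placeholder so the in-flight transfer never collides with the weight currently in use.

```python
import torch.distributed as dist

def forward_with_prefetch(layers, hidden, group=None):
    """Compute layer i while broadcasting layer i+1's weight in the background."""
    k = dist.get_world_size(group)
    # Kick off the transfer of the first layer's weight before any compute.
    work = dist.broadcast(layers[0].weight, src=0, group=group, async_op=True)
    for i, layer in enumerate(layers):
        work.wait()  # layer i's weight is now valid on every rank
        if i + 1 < len(layers):
            # Pre-fetch the next layer's weight; the attention compute
            # below hides this communication latency.
            work = dist.broadcast(
                layers[i + 1].weight, src=(i + 1) % k, group=group, async_op=True
            )
        hidden = layer(hidden)
    return hidden
```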
This approach **preserves exact computational semantics** while **significantly reducing the NPU memory footprint**, which is especially critical for:

- Extremely deep architectures (e.g., DeepSeek-V3/R1 with 61 layers);
- Models using **[DSA-CP](https://github.com/vllm-project/vllm-ascend/pull/4702)** or **[FlashComm2](https://github.com/vllm-project/vllm-ascend/pull/4188)**, where the full `O` (output) projection matrix must reside in memory for every layer;
- Scenarios where **attention computation latency fully overlaps** (hides) the communication cost of weight broadcasting.
---

## Flowchart

![layer_shard](./images/layer_shard.png)

> **Figure.** Layer Shard Linear workflow: weights are sharded by layer across devices (top); during forward execution (bottom), asynchronous broadcast pre-fetches the next layer's weight while the current layer computes, enabling zero-overhead weight loading.

---
# Getting Started

To enable **Layer Shard Linear**, specify the target linear layers using the `--additional-config` argument when launching your inference job. For example, to shard the `o_proj` and `q_b_proj` layers, use:

```bash
--additional-config '{
  "layer_sharding": ["o_proj", "q_b_proj"]
}'
```
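For offline inference, the same option can presumably be passed programmatically. A sketch, assuming a vllm-ascend build whose engine arguments accept `additional_config` (the model path is a placeholder):

```python
from vllm import LLM

# Offline equivalent of the --additional-config CLI flag above.
llm = LLM(
    model="/path/to/model",  # placeholder path, not a real checkpoint
    additional_config={"layer_sharding": ["o_proj", "q_b_proj"]},
)
```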
---

# Supported Scenarios

This feature can be enabled in any scenario, but it delivers the greatest benefit in the following cases:
## FlashComm2-enabled

When using [FlashComm2](https://github.com/vllm-project/vllm-ascend/pull/4188), the full output projection (`o_proj`) matrix must be resident in memory for each layer. Layer sharding significantly reduces memory pressure by distributing these weights across devices.

**Example configuration:**

```bash
export VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE=1
vllm serve \
  --model DeepSeek-V3/R1 \
  --additional-config '{
    "layer_sharding": ["o_proj"]
  }'
```
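For a rough sense of the savings, here is a back-of-the-envelope estimate assuming DeepSeek-V3-like dimensions (hidden size 7168, 128 heads with 128-dim value projections, 61 layers, bf16) and an 8-device group; the numbers are illustrative, not measured:

```python
hidden, heads, v_dim = 7168, 128, 128    # assumed DeepSeek-V3-like shapes
num_layers, group_size = 61, 8
bytes_per_param = 2                      # bf16

o_proj = hidden * heads * v_dim * bytes_per_param            # one layer's o_proj weight
replicated = o_proj * num_layers                             # every device holds all layers
per_device_layers = (num_layers + group_size - 1) // group_size  # ceil(layers / K)
sharded = o_proj * per_device_layers                         # layers actually owned per device

print(f"per layer:  {o_proj / 2**20:.0f} MiB")      # 224 MiB
print(f"replicated: {replicated / 2**30:.1f} GiB")  # ~13.3 GiB per device
print(f"sharded:    {sharded / 2**30:.2f} GiB")     # ~1.75 GiB per device
```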
## DSA-CP-enabled

With [DSA-CP](https://github.com/vllm-project/vllm-ascend/pull/4702), both the `q_b_proj` and `o_proj` layers require large weight matrices to be stored per layer. Sharding these layers across NPUs helps fit extremely deep models (e.g., 61-layer architectures) into limited device memory.

**Example configuration:**

```bash
export VLLM_ASCEND_ENABLE_FLASHCOMM1=1
vllm serve \
  --model DeepSeek-V3.2 \
  --additional-config '{
    "layer_sharding": ["q_b_proj", "o_proj"]
  }'
```
**Review comment:** Is this incompatible with GQA when using `kwargs` to get `layer_name`?