@@ -49,6 +49,7 @@ The following table lists additional configuration options available in vLLM Asc
| `expert_map_record_path` | str | `None` | Save the expert load calculation results to a new expert table in the specified directory. |
| `init_redundancy_expert` | int | `0` | Specify redundant experts during initialization. |
| `enable_kv_nz` | bool | `False` | Whether to enable kvcache NZ layout. This option only takes effects on models using MLA (e.g., DeepSeek). |
| `layer_sharding` | list | `None` | Linear layer names (e.g., `o_proj`, `q_b_proj`) whose weights are sharded by layer across the devices in a communication group. |

The details of each configuration option are as follows:

1 change: 1 addition & 0 deletions docs/source/user_guide/feature_guide/index.md
@@ -19,6 +19,7 @@ external_dp
large_scale_ep
ucm_deployment
Fine_grained_TP
layer_sharding
speculative_decoding
context_parallel
:::
73 changes: 73 additions & 0 deletions docs/source/user_guide/feature_guide/layer_sharding.md
@@ -0,0 +1,73 @@
---
title: Layer Sharding Guide
---

# Overview

**Layer Shard Linear** is a memory-optimization feature designed for large language model (LLM) inference. It addresses the high memory pressure caused by **repeated linear operators across many layers** that share identical structure but have distinct weights.

Instead of replicating all weights on every device, **Layer Shard Linear shards the weights of a "series" of such operators across the NPU devices in a communication group**:
- The **i-th layer's linear weight** is stored **only on device `i % K`**, where `K` is the number of devices in the group.
- Other devices hold a lightweight **shared dummy tensor** during initialization and fetch the real weight **on-demand via asynchronous broadcast** during the forward pass.
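
A minimal sketch of the two points above, using plain `torch.distributed` collectives and hypothetical helper names (the actual implementation lives in `vllm_ascend/ops/layer_shard_linear.py`), assuming the communication group spans ranks `0..K-1`:

```python
import torch
import torch.distributed as dist


def owner_rank(layer_idx: int, group_size: int) -> int:
    """Layer i's real weight is stored only on device i % K."""
    return layer_idx % group_size


def fetch_layer_weight(layer_idx: int, weight_buffer: torch.Tensor, group):
    """Start an asynchronous broadcast of layer `layer_idx`'s weight.

    On the owner rank, `weight_buffer` already holds the real values; on the
    other ranks it is a placeholder buffer (the "shared dummy tensor") that
    the broadcast fills in. Returns the async work handle to wait on later.
    """
    # Assumes the group's ranks coincide with global ranks 0..K-1.
    src = owner_rank(layer_idx, dist.get_world_size(group))
    return dist.broadcast(weight_buffer, src=src, group=group, async_op=True)
```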

As illustrated in the figure below, this design overlaps the weight broadcast with computation: while the current layer (e.g., MLA or MoE) is being computed, the system **asynchronously broadcasts the next layer's weight** in the background. Because the attention computation in the MLA module is sufficiently latency-bound, the weight transfer for `o_proj` is **fully overlapped with computation**, making the communication **latency-free from the perspective of end-to-end inference**.
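
The prefetch pattern can be sketched as below, reusing `fetch_layer_weight` from the previous snippet (again with hypothetical names; in the diff further down, the real hooks are `register_all_layers_to_shard_weight_series` and `reach_layer_for_shard_weight_series`):

```python
def forward_with_prefetch(hidden, layers, group):
    # Prefetch the first layer's weight before the loop starts.
    pending = fetch_layer_weight(0, layers[0].weight_buffer, group)
    for i, layer in enumerate(layers):
        # Layer i's weight must be resident before its linear op runs;
        # ideally the broadcast finished while layer i-1 was computing.
        pending.wait()
        if i + 1 < len(layers):
            # Kick off the next layer's broadcast; it is hidden behind
            # layer i's attention / MoE computation.
            pending = fetch_layer_weight(i + 1, layers[i + 1].weight_buffer, group)
        hidden = layer.compute(hidden)
    return hidden
```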

This approach **preserves exact computational semantics** while **significantly reducing the NPU memory footprint**, which is especially critical for:
- Extremely deep architectures (e.g., DeepSeek-V3/R1 with 61 layers);
- Models using **[DSA-CP](https://github.com/vllm-project/vllm-ascend/pull/4702)** or **[FlashComm2](https://github.com/vllm-project/vllm-ascend/pull/4188)**, where the full `O` (output) projection matrix must reside in memory per layer;
- Scenarios where **attention computation latency fully overlaps** (hides) the communication cost of weight broadcasting.

---

## Flowchart
![layer shard](./images/layer_sharding.png)

> **Figure.** Layer Shard Linear workflow: weights are sharded by layer across devices (top), and during forward execution (bottom), asynchronous broadcast pre-fetches the next layer's weight while the current layer computes—enabling zero-overhead weight loading.

---

# Getting Started

To enable **Layer Shard Linear**, specify the target linear layers using the `--additional-config` argument when launching your inference job. For example, to shard the `o_proj` and `q_b_proj` layers, use:

```bash
--additional-config '{
"layer_sharding": ["o_proj", "q_b_proj"]
}'
```
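
For offline inference, the same option can be passed programmatically. A sketch assuming the standard `additional_config` engine argument and a placeholder model path:

```python
from vllm import LLM

# Offline-inference equivalent of the CLI flag above (model path is a placeholder).
llm = LLM(
    model="/path/to/DeepSeek-V3",
    additional_config={"layer_sharding": ["o_proj", "q_b_proj"]},
)
```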

---

# Supported Scenarios

This feature can be enabled in any scenario, but delivers the greatest benefit in the following cases:

## FlashComm2-enabled

When using [FlashComm2](https://github.com/vllm-project/vllm-ascend/pull/4188), the full output projection (`o_proj`) matrix must be resident in memory for each layer. Layer sharding significantly reduces memory pressure by distributing these weights across devices.

**Example configuration:**

```bash
export VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE=1
vllm serve \
--model DeepSeek-V3/R1 \
--additional-config '{
"layer_sharding": ["o_proj"]
}'
```

## DSA-CP-enabled

With [DSA-CP](https://github.com/vllm-project/vllm-ascend/pull/4702), both `q_b_proj` and `o_proj` layers require large weight matrices to be stored per layer. Sharding these layers across NPUs helps fit extremely deep models (e.g., 61-layer architectures) into limited device memory.

**Example configuration:**

```bash
export VLLM_ASCEND_ENABLE_FLASHCOMM1=1
vllm serve \
--model DeepSeek-V3.2 \
--additional-config '{
"layer_sharding": ["q_b_proj", "o_proj"]
}'
```
7 changes: 1 addition & 6 deletions tests/ut/distributed/test_parallel_state.py
@@ -25,14 +25,10 @@ def mock_distributed():
patch('torch.distributed.get_world_size', return_value=16), \
patch('torch.distributed.get_backend', return_value='nccl'), \
patch('vllm_ascend.distributed.parallel_state.get_world_group') as mock_group, \
patch('vllm_ascend.distributed.parallel_state.get_tp_group') as mock_tp_group, \
patch('vllm_ascend.distributed.parallel_state.get_dp_group') as mock_dp_group, \
patch('vllm_ascend.distributed.parallel_state.get_pp_group') as mock_pp_group:
patch('vllm_ascend.distributed.parallel_state.get_tp_group') as mock_tp_group:
mock_group.return_value.local_rank = 0
mock_group.return_value.device_group = MagicMock()
mock_tp_group.return_value.world_size = 4
mock_dp_group.return_value.world_size = 2
mock_pp_group.return_value.world_size = 2
yield


@@ -50,7 +46,6 @@ def test_init_ascend_model_parallel(mock_distributed, parallel_config):
mock_vllm_config.kv_transfer_config.is_kv_producer = True
mock_envs_ascend = MagicMock()
mock_envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE = 2
mock_envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED = 0
mock_envs_ascend.VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL = 0
with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \
patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \
8 changes: 7 additions & 1 deletion vllm_ascend/ascend_config.py
@@ -51,6 +51,12 @@ def __init__(self, vllm_config: "VllmConfig"):
"weight_prefetch_config", {})
self.weight_prefetch_config = WeightPrefetchConfig(
weight_prefetch_config)
self.layer_sharding = additional_config.get("layer_sharding", None)
if self.layer_sharding:
    logger.info_once(
        f"Linear layer sharding enabled with config: {self.layer_sharding}. "
        "Note: This feature works optimally with FLASHCOMM2 and DSA-CP enabled; "
        "using it without these features may result in significant performance degradation."
    )

# Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. Remove this config
self.expert_map_path = additional_config.get("expert_map_path", None)
@@ -111,7 +117,7 @@ def __init__(self, vllm_config: "VllmConfig"):
self.SLO_limits_for_dynamic_batch = additional_config.get(
"SLO_limits_for_dynamic_batch", -1)
from vllm_ascend.utils import get_flashcomm2_config_and_validate
self.flashcomm2_oproj_tensor_parallel_size, self.flashcomm2_oproj_shared = get_flashcomm2_config_and_validate(
self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_config_and_validate(
self, vllm_config)
self.enable_npugraph_ex = additional_config.get(
"enable_npugraph_ex", False)
50 changes: 23 additions & 27 deletions vllm_ascend/attention/mla_v1.py
@@ -31,15 +31,14 @@
from vllm_ascend.compilation.acl_graph import (
get_draft_graph_params, get_graph_params,
update_draft_graph_params_workspaces, update_graph_params_workspaces)
from vllm_ascend.ops.layer_shard_linear import (
is_hidden_layer, post_process_after_loading_for_shard_weight_series,
reach_layer_for_shard_weight_series,
register_all_layers_to_shard_weight_series)
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
from vllm_ascend.ops.shared_weight_layer import (
is_hidden_layer, post_process_after_loading_for_shared_weight_series,
reach_layer_for_shared_weight_series,
register_layer_to_shared_weight_series)
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND,
flashcomm2_o_shared_enabled, maybe_trans_nz,
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, maybe_trans_nz,
weak_ref_tensors)
from vllm_ascend.worker.npu_input_batch import NPUInputBatch

@@ -734,18 +733,6 @@ def __init__(
self.kv_b_proj = kwargs['kv_b_proj']
self.o_proj = kwargs['o_proj']
self.vllm_config = get_current_vllm_config()
self.fc2_o_shared_enable = flashcomm2_o_shared_enabled()

if self.fc2_o_shared_enable and is_hidden_layer(
self.vllm_config, self.o_proj):
from vllm_ascend.distributed.parallel_state import \
get_shared_weight_group
register_layer_to_shared_weight_series(
series_name="o_proj",
group=get_shared_weight_group(),
layer=self.o_proj,
prefetch_step=1)

self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
self.q_a_layernorm = kwargs.get('q_a_layernorm', None)
@@ -762,6 +749,15 @@
self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO

self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
self.layer_sharding_kwargs = []
for layer_name in (get_ascend_config().layer_sharding or []):
if layer_name in kwargs:
[Review comment] @Levi-JQ (Contributor), Dec 23, 2025: not compatible for GQA when using kwargs to get layer_name?
self.layer_sharding_kwargs.append(kwargs[layer_name])
else:
logger.warning_once(
f"Layer '{layer_name}' not found in kwargs for layer sharding, skipping sharding configuration"
)
register_all_layers_to_shard_weight_series(self.layer_sharding_kwargs)

def _v_up_proj(self, x):
# Convert from (N, B, L)/(N, B, 1, L) to (N, B, L)
@@ -833,9 +829,9 @@ def process_weights_after_loading(self, act_dtype: torch.dtype):
# if mlapo, W_UK_T can't trans nz
self.W_UK_T = maybe_trans_nz(self.W_UK_T)

if self.fc2_o_shared_enable and is_hidden_layer(
self.vllm_config, self.o_proj):
post_process_after_loading_for_shared_weight_series(self.o_proj)
for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer):
post_process_after_loading_for_shard_weight_series(layer)

def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[
@@ -1445,9 +1441,9 @@ def _mla_preprocess(self, layer_name, hidden_states, kv_cache,
kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
kv_no_split.contiguous(), need_gather_q_kv)

if self.fc2_o_shared_enable and is_hidden_layer(
self.vllm_config, self.o_proj):
reach_layer_for_shared_weight_series(self.o_proj)
for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer):
reach_layer_for_shard_weight_series(layer)

decode_preprocess_res = None
prefill_preprocess_res = None
@@ -1478,9 +1474,9 @@ def forward(
assert output is not None, "Output tensor must be provided."
if attn_metadata is None:
# Profiling run.
if self.fc2_o_shared_enable and is_hidden_layer(
self.vllm_config, self.o_proj):
reach_layer_for_shared_weight_series(self.o_proj)
for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer):
reach_layer_for_shard_weight_series(layer)
return output.fill_(0)

forward_context = get_forward_context()
57 changes: 23 additions & 34 deletions vllm_ascend/attention/sfa_v1.py
@@ -25,11 +25,11 @@
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
trans_rope_weight, transdata,
wait_for_kv_layer_from_connector)
from vllm_ascend.ops.layer_shard_linear import (
is_hidden_layer, post_process_after_loading_for_shard_weight_series,
reach_layer_for_shard_weight_series,
register_all_layers_to_shard_weight_series)
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
from vllm_ascend.ops.shared_weight_layer import (
is_hidden_layer, post_process_after_loading_for_shared_weight_series,
reach_layer_for_shared_weight_series,
register_layer_to_shared_weight_series)
from vllm_ascend.ops.triton.rope import rope_forward_triton
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
@@ -374,22 +374,17 @@ def __init__(
if self.enable_sfa_cp:
self.local_num_heads = self.num_heads * self.tp_size

# TODO: Temporarily adapt sfa-cp, remove after adapting near PCP. --clrs97
self._replace_linear_class_for_sfa_cp()
from vllm_ascend.distributed.parallel_state import \
get_shared_weight_group
if is_hidden_layer(self.vllm_config, self.q_proj):
register_layer_to_shared_weight_series(
series_name="q_proj",
group=get_shared_weight_group(),
layer=self.q_proj,
prefetch_step=1)
if is_hidden_layer(self.vllm_config, self.o_proj):
register_layer_to_shared_weight_series(
series_name="o_proj",
group=get_shared_weight_group(),
layer=self.o_proj,
prefetch_step=1)
self.layer_sharding_kwargs = []
for layer_name in (get_ascend_config().layer_sharding or []):
if layer_name in kwargs:
self.layer_sharding_kwargs.append(kwargs[layer_name])
else:
logger.warning_once(
f"Layer '{layer_name}' not found in kwargs for layer sharding, skipping sharding configuration"
)
register_all_layers_to_shard_weight_series(
self.layer_sharding_kwargs)

# indexer param
self.n_head: int = self.indexer.n_head # 64
@@ -434,14 +429,10 @@ def process_weights_after_loading(self, act_dtype: torch.dtype):

# Dispose kv_b_proj since it is replaced by W_UV and W_UK_T to save memory
dispose_layer(self.kv_b_proj)

if self.enable_sfa_cp:
if is_hidden_layer(self.vllm_config, self.q_proj):
post_process_after_loading_for_shared_weight_series(
self.q_proj)
if is_hidden_layer(self.vllm_config, self.o_proj):
post_process_after_loading_for_shared_weight_series(
self.o_proj)
for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer):
post_process_after_loading_for_shard_weight_series(layer)

if self.enable_mlapo:
quant_method = getattr(
@@ -751,10 +742,9 @@ def forward(
if attn_metadata is None:
# Profiling run.
if self.enable_sfa_cp and not forward_context.in_profile_run:
if is_hidden_layer(self.vllm_config, self.q_proj):
reach_layer_for_shared_weight_series(self.q_proj)
if is_hidden_layer(self.vllm_config, self.o_proj):
reach_layer_for_shared_weight_series(self.o_proj)
for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer):
reach_layer_for_shard_weight_series(layer)
return output.fill_(0)
has_prefill = attn_metadata.has_prefill
cos = attn_metadata.cos
@@ -809,10 +799,9 @@
slot_mapping_cp)

if self.enable_sfa_cp and attn_metadata.sfa_cp_context is not None:
if is_hidden_layer(self.vllm_config, self.q_proj):
reach_layer_for_shared_weight_series(self.q_proj)
if is_hidden_layer(self.vllm_config, self.o_proj):
reach_layer_for_shared_weight_series(self.o_proj)
for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer):
reach_layer_for_shard_weight_series(layer)

ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
q_pe = self.rope_single(q_pe, cos, sin)