@@ -157,6 +157,29 @@ def test_qwen3_moe_fc2_tp2() -> None:
vllm_model.generate(example_prompts, sampling_params)


@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
@patch.dict(os.environ, {"VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE": "1"})
def test_qwen3_moe_fc2_oshard_tp2() -> None:
example_prompts = [
"Hello, my name is",
]
sampling_params = SamplingParams(max_tokens=5,
temperature=0.0,
top_k=50,
top_p=0.9)

with VllmRunner(
snapshot_download("Qwen/Qwen3-30B-A3B"),
dtype="auto",
tensor_parallel_size=2,
distributed_executor_backend="mp",
enable_expert_parallel=True,
enforce_eager=True,  # TODO(Levi-JQ): support graph mode for fc2 in Qwen
additional_config={"layer_sharding": ["o_proj"]}) as vllm_model:
vllm_model.generate(example_prompts, sampling_params)


@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
def test_deepseek_v2_lite_fc1_tp2() -> None:
example_prompts = [
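For context, the new test drives the O-Shard path end to end. A minimal standalone sketch of the same configuration (assuming vLLM's `LLM` entry point forwards `additional_config` the same way `VllmRunner` does; model, prompt, and settings taken from the test above):

import os

# FlashComm2 is enabled whenever the parallel size is > 0
# (see flashcomm2_enable() in vllm_ascend/utils.py further down).
os.environ["VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen3-30B-A3B",
    tensor_parallel_size=2,
    enable_expert_parallel=True,
    enforce_eager=True,  # graph mode for fc2 in Qwen is not supported yet
    # "o_proj" here is exactly what o_shard_enable() looks for.
    additional_config={"layer_sharding": ["o_proj"]},
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=5, temperature=0.0))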
6 changes: 6 additions & 0 deletions vllm_ascend/attention/attention_v1.py
@@ -43,6 +43,7 @@
from vllm_ascend.compilation.acl_graph import (
get_draft_graph_params, get_graph_params,
update_draft_graph_params_workspaces, update_graph_params_workspaces)
from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager
from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
weak_ref_tensors)

@@ -349,6 +350,11 @@ def __init__(
self.value_cache = None
self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer

def process_weights_after_loading(self, act_dtype: torch.dtype):
super().process_weights_after_loading(act_dtype)
if flashcomm2_oshard_manager.flashcomm2_oshard_enable():
flashcomm2_oshard_manager.post_process_after_loading()

def full_graph_fia(self, query: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, attn_metadata: AscendMetadata,
output: torch.Tensor) -> torch.Tensor:
100 changes: 100 additions & 0 deletions vllm_ascend/ops/flashcomm2_oshard_manager.py
@@ -0,0 +1,100 @@
from typing import Any, Dict, Optional
Collaborator review comment: rename this file

from vllm.model_executor.models.utils import extract_layer_index

from vllm_ascend.distributed.parallel_state import get_shard_weight_group
from vllm_ascend.ops.layer_shard_linear import (
is_hidden_layer, post_process_after_loading_for_shard_weight_series,
reach_layer_for_shard_weight_series, register_layer_to_shard_weight_series)
from vllm_ascend.utils import flashcomm2_enable, o_shard_enable


class Flashcomm2OShardManager:
"""Manages sharded layers for the FlashComm2 O-Shard feature.

This class centralizes all logic related to FlashComm2 O-Shard layers.
Its main responsibilities are:
1. Registering Attention `o_proj` layers that require O-Sharding.
2. Storing and managing these layers in a dictionary mapping layer indices
to layer objects (`layer_index -> layer`).
3. Providing a high-level API for external callers to use at key stages
like model initialization, computation, and weight loading.

Attributes:
_shard_layers: A dictionary to store the registered sharded layers,
mapping a layer index (int) to its corresponding layer object.
"""

def __init__(self):
self._shard_layers: Dict[int, Any] = {}

def flashcomm2_oshard_enable(self):
return flashcomm2_enable() and o_shard_enable()

def register_layer(self, layer: Any, prefetch_step: int = 1):
"""Registers a layer for O-Sharding.

This method first checks if the O-Shard feature is enabled and if the
provided layer qualifies as a target (e.g., a hidden layer). If so,
it performs two actions:
1. Caches the layer internally in the `_shard_layers` dictionary.
2. Calls the underlying `register_layer_to_shard_weight_series`
function to register it for communication.

Args:
layer: The layer object to be registered.
prefetch_step: The prefetch step to be used when registering the
layer to the shard weight series.
"""
# Check if the layer is a target for sharding.
if is_hidden_layer(layer):
layer_idx = extract_layer_index(layer.prefix)
self._shard_layers[layer_idx] = layer

register_layer_to_shard_weight_series(
series_name="o_proj",
group=get_shard_weight_group(),
layer=layer,
prefetch_step=prefetch_step)

def get_layer(self, layer_idx: int) -> Optional[Any]:
"""Safely retrieves a registered layer by its index.

Args:
layer_idx: The index of the layer to retrieve.

Returns:
The layer object if found, otherwise None.
"""
return self._shard_layers.get(layer_idx)

def trigger_broadcast_for_layer(self, layer_prefix: str):
"""Triggers a broadcast for a specific layer during model computation.

This method is intended to be called within a layer's forward pass.
It extracts the layer index from the prefix, retrieves the corresponding
registered layer object, and then triggers the broadcast operation
if all conditions are met.

Args:
layer_prefix: The name prefix of the current layer being computed.
"""
layer_idx = extract_layer_index(layer_prefix)
target_layer = self.get_layer(layer_idx)

# Ensure the layer exists and meets the sharding criteria.
if target_layer and is_hidden_layer(target_layer):
reach_layer_for_shard_weight_series(target_layer)

def post_process_after_loading(self):
"""Performs post-processing on all registered layers after weight loading.

This should be called once after the model weights have been fully loaded.
"""
if self._shard_layers:
# Pick any layer (e.g., the first one) to trigger the shard post-processing
any_layer = next(iter(self._shard_layers.values()))
post_process_after_loading_for_shard_weight_series(any_layer)


flashcomm2_oshard_manager = Flashcomm2OShardManager()
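
Taken together, the manager spans three phases of the model lifecycle. A rough sketch of the call order as driven by the other diffs in this PR (not runnable as one block; `o_proj_layer` and `layer_prefix` are stand-ins for the real objects at each call site):

# 1. Layer construction: update_attrs in the o_proj row-parallel op
#    (linear_op.py below) registers each hidden-layer o_proj for sharding.
if flashcomm2_oshard_manager.flashcomm2_oshard_enable():
    flashcomm2_oshard_manager.register_layer(o_proj_layer, prefetch_step=1)

# 2. Forward pass: Flashcomm2OshardQKVParallelOp.apply_impl (linear_op.py
#    below) kicks off the async weight broadcast before the QKV matmul so
#    the communication overlaps with computation.
flashcomm2_oshard_manager.trigger_broadcast_for_layer(layer_prefix)

# 3. Weight loading: process_weights_after_loading (attention_v1.py above)
#    runs the one-time post-processing over the registered series.
flashcomm2_oshard_manager.post_process_after_loading()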
44 changes: 43 additions & 1 deletion vllm_ascend/ops/linear_op.py
@@ -21,6 +21,7 @@
├── CustomColumnParallelOp
│ ├── MLPColumnParallelOp
│ ├── SequenceColumnParallelOp
│ ├── Flashcomm2OshardQKVParallelOp
└── CustomRowParallelOp
│ ├── MLPRowParallelOp
│ ├── OProjRowParallelOp
@@ -60,6 +61,7 @@
get_flashcomm2_otp_group,
get_mlp_tp_group,
get_otp_group)
from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager
from vllm_ascend.utils import (enable_dsa_cp, enable_sp, flashcomm2_enable,
get_flashcomm2_reorgnized_batch_ids,
matmul_allreduce_enable, mlp_tp_enable,
@@ -400,6 +402,9 @@ def update_attrs(self):
super().update_attrs()
self.input_is_parallel = self.layer.input_is_parallel
self.input_size_per_partition = self.layer.input_size_per_partition
if flashcomm2_oshard_manager.flashcomm2_oshard_enable():
flashcomm2_oshard_manager.register_layer(self.layer,
prefetch_step=1)


class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
@@ -479,6 +484,39 @@ def apply_impl(
return output, output_bias


class Flashcomm2OshardQKVParallelOp(CustomColumnParallelOp):

def __init__(self, layer):
super().__init__(layer)

def apply_impl(
self, input_: torch.Tensor
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
"""Column-parallel linear with FlashComm2 OShard optimization."""

bias = self.bias if not self.skip_bias_add else None

# Matrix multiply.
assert self.quant_method is not None

if enable_sp():
input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
input_, True)

# Trigger async broadcast before matmul to overlap communication.
flashcomm2_oshard_manager.trigger_broadcast_for_layer(
self.layer.prefix)

output_parallel = self.quant_method.apply(self.layer, input_, bias)
if self.gather_output and self.tp_size > 1:
# All-gather across the partitions.
output = self.comm_group.all_gather(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias


class SequenceRowParallelOp(CustomRowParallelOp):

def __init__(self, layer):
@@ -657,12 +695,15 @@ def apply_impl(
def _get_column_parallel_op(
prefix, layer
) -> Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp,
ShardedCPColumnParallelOp]]:
ShardedCPColumnParallelOp, Flashcomm2OshardQKVParallelOp]]:
if enable_dsa_cp() and ("q_b_proj" in prefix or "kv_b_proj" in prefix):
return ShardedCPColumnParallelOp(layer)
if "gate_up_proj" in prefix and mlp_tp_enable(
) and not is_moe_layer(prefix):
return MLPColumnParallelOp(layer)
if flashcomm2_oshard_manager.flashcomm2_oshard_enable():
if any(p in prefix for p in ("qkv_proj", "conv1d", "query_key_value")):
return Flashcomm2OshardQKVParallelOp(layer)
if enable_sp():
if "shared_expert" in prefix:
return None
@@ -719,6 +760,7 @@ def get_parallel_op(disable_tp, prefix, layer, direct):
custom_op: Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp,
MLPRowParallelOp, OProjRowParallelOp,
Flashcomm2OProjRowParallelOp,
Flashcomm2OshardQKVParallelOp,
MatmulAllreduceRowParallelOp,
SequenceRowParallelOp, ShardedCPRowParallelOp,
ShardedCPColumnParallelOp]] = None
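The new dispatch branch keys purely on the layer-name prefix. An illustrative sketch of what `_get_column_parallel_op` resolves to when the O-Shard gate is on (the prefixes are hypothetical examples and `layer` is a stand-in object):

# With VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE=1 and layer_sharding=["o_proj"]:
for prefix in ("model.layers.0.self_attn.qkv_proj",     # matches "qkv_proj"
               "transformer.h.0.attn.query_key_value",  # matches "query_key_value"
               "model.layers.0.mlp.gate_up_proj"):      # no match
    op = _get_column_parallel_op(prefix, layer)
    # The first two prefixes resolve to Flashcomm2OshardQKVParallelOp;
    # gate_up_proj is caught earlier by the mlp_tp branch when enabled,
    # otherwise it falls through to the sequence-parallel handling.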
7 changes: 7 additions & 0 deletions vllm_ascend/utils.py
@@ -1017,6 +1017,13 @@ def flashcomm2_enable() -> bool:
return envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE > 0


def o_shard_enable() -> bool:
layer_sharding = get_ascend_config().layer_sharding
if layer_sharding is None:
return False
return "o_proj" in layer_sharding


def get_flashcomm2_config_and_validate(ascend_config, vllm_config):
flashcomm2_oproj_tp_size = envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE
global_tp_size = vllm_config.parallel_config.tensor_parallel_size
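The gate used by the manager is simply the conjunction of the two utility switches above. A self-contained sketch of the combined predicate (`_oshard_gate` is a hypothetical helper, not part of this diff):

def _oshard_gate(fc2_parallel_size: int, layer_sharding) -> bool:
    """Mirrors flashcomm2_enable() and o_shard_enable() combined."""
    fc2_on = fc2_parallel_size > 0
    oshard_on = layer_sharding is not None and "o_proj" in layer_sharding
    return fc2_on and oshard_on

assert _oshard_gate(1, ["o_proj"]) is True
assert _oshard_gate(0, ["o_proj"]) is False      # FlashComm2 disabled
assert _oshard_gate(1, None) is False            # no layer_sharding configured
assert _oshard_gate(1, ["gate_up_proj"]) is False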