diff --git a/tests/e2e/multicard/2-cards/test_offline_inference_distributed.py b/tests/e2e/multicard/2-cards/test_offline_inference_distributed.py
index c5194c63fec..d06dece3446 100644
--- a/tests/e2e/multicard/2-cards/test_offline_inference_distributed.py
+++ b/tests/e2e/multicard/2-cards/test_offline_inference_distributed.py
@@ -157,6 +157,29 @@ def test_qwen3_moe_fc2_tp2() -> None:
         vllm_model.generate(example_prompts, sampling_params)
 
 
+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
+@patch.dict(os.environ, {"VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE": "1"})
+def test_qwen3_moe_fc2_oshard_tp2() -> None:
+    example_prompts = [
+        "Hello, my name is",
+    ]
+    sampling_params = SamplingParams(max_tokens=5,
+                                     temperature=0.0,
+                                     top_k=50,
+                                     top_p=0.9)
+
+    with VllmRunner(
+            snapshot_download("Qwen/Qwen3-30B-A3B"),
+            dtype="auto",
+            tensor_parallel_size=2,
+            distributed_executor_backend="mp",
+            enable_expert_parallel=True,
+            enforce_eager=
+            True,  # TODO(Levi-JQ): support graph mode for fc2 in Qwen
+            additional_config={"layer_sharding": ["o_proj"]}) as vllm_model:
+        vllm_model.generate(example_prompts, sampling_params)
+
+
 @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
 def test_deepseek_v2_lite_fc1_tp2() -> None:
     example_prompts = [
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index d19d3369b99..80ebf17dfb4 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -43,6 +43,7 @@
 from vllm_ascend.compilation.acl_graph import (
     get_draft_graph_params, get_graph_params,
     update_draft_graph_params_workspaces, update_graph_params_workspaces)
+from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager
 from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
                                weak_ref_tensors)
 
@@ -349,6 +350,11 @@ def __init__(
         self.value_cache = None
         self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
 
+    def process_weights_after_loading(self, act_dtype: torch.dtype):
+        super().process_weights_after_loading(act_dtype)
+        if flashcomm2_oshard_manager.flashcomm2_oshard_enable():
+            flashcomm2_oshard_manager.post_process_after_loading()
+
     def full_graph_fia(self, query: torch.Tensor, key: torch.Tensor,
                        value: torch.Tensor, attn_metadata: AscendMetadata,
                        output: torch.Tensor) -> torch.Tensor:
diff --git a/vllm_ascend/ops/flashcomm2_oshard_manager.py b/vllm_ascend/ops/flashcomm2_oshard_manager.py
new file mode 100644
index 00000000000..d95748e6554
--- /dev/null
+++ b/vllm_ascend/ops/flashcomm2_oshard_manager.py
@@ -0,0 +1,100 @@
+from typing import Any, Dict, Optional
+
+from vllm.model_executor.models.utils import extract_layer_index
+
+from vllm_ascend.distributed.parallel_state import get_shard_weight_group
+from vllm_ascend.ops.layer_shard_linear import (
+    is_hidden_layer, post_process_after_loading_for_shard_weight_series,
+    reach_layer_for_shard_weight_series, register_layer_to_shard_weight_series)
+from vllm_ascend.utils import flashcomm2_enable, o_shard_enable
+
+
+class Flashcomm2OShardManager:
+    """Manages sharded layers for the FlashComm2 O-Shard feature.
+
+    This class centralizes all logic related to FlashComm2 O-Shard layers.
+    Its main responsibilities are:
+    1. Registering Attention `o_proj` layers that require O-Sharding.
+    2. Storing and managing these layers in a dictionary mapping layer indices
+       to layer objects (`layer_index -> layer`).
+    3. Providing a high-level API for external callers to use at key stages
+       like model initialization, computation, and weight loading.
+
+    Attributes:
+        _shard_layers: A dictionary to store the registered sharded layers,
+            mapping a layer index (int) to its corresponding layer object.
+    """
+
+    def __init__(self):
+        self._shard_layers: Dict[int, Any] = {}
+
+    def flashcomm2_oshard_enable(self) -> bool:
+        return flashcomm2_enable() and o_shard_enable()
+
+    def register_layer(self, layer: Any, prefetch_step: int = 1):
+        """Registers a layer for O-Sharding.
+
+        This method first checks whether the provided layer qualifies as a
+        sharding target (e.g., a hidden layer); the feature-enable check is
+        done by the caller. If the layer qualifies, it performs two actions:
+        1. Caches the layer internally in the `_shard_layers` dictionary.
+        2. Calls the underlying `register_layer_to_shard_weight_series`
+           function to register it for communication.
+
+        Args:
+            layer: The layer object to be registered.
+            prefetch_step: The prefetch step to be used when registering the
+                layer to the shard weight series.
+        """
+        # Check if the layer is a target for sharding.
+        if is_hidden_layer(layer):
+            layer_idx = extract_layer_index(layer.prefix)
+            self._shard_layers[layer_idx] = layer
+
+            register_layer_to_shard_weight_series(
+                series_name="o_proj",
+                group=get_shard_weight_group(),
+                layer=layer,
+                prefetch_step=prefetch_step)
+
+    def get_layer(self, layer_idx: int) -> Optional[Any]:
+        """Safely retrieves a registered layer by its index.
+
+        Args:
+            layer_idx: The index of the layer to retrieve.
+
+        Returns:
+            The layer object if found, otherwise None.
+        """
+        return self._shard_layers.get(layer_idx)
+
+    def trigger_broadcast_for_layer(self, layer_prefix: str):
+        """Triggers a broadcast for a specific layer during model computation.
+
+        This method is intended to be called within a layer's forward pass.
+        It extracts the layer index from the prefix, retrieves the corresponding
+        registered layer object, and then triggers the broadcast operation
+        if all conditions are met.
+
+        Args:
+            layer_prefix: The name prefix of the current layer being computed.
+        """
+        layer_idx = extract_layer_index(layer_prefix)
+        target_layer = self.get_layer(layer_idx)
+
+        # Ensure the layer exists and meets the sharding criteria.
+        if target_layer and is_hidden_layer(target_layer):
+            reach_layer_for_shard_weight_series(target_layer)
+
+    def post_process_after_loading(self):
+        """Performs post-processing on all registered layers after weight loading.
+
+        This should be called once after the model weights have been fully loaded.
+        """
+        if self._shard_layers:
+            # Pick any layer (e.g., the first one) to trigger post-processing.
+            any_layer = next(iter(self._shard_layers.values()))
+            post_process_after_loading_for_shard_weight_series(any_layer)
+
+
+flashcomm2_oshard_manager = Flashcomm2OShardManager()
diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py
index ef612ce9282..4a90ea90498 100644
--- a/vllm_ascend/ops/linear_op.py
+++ b/vllm_ascend/ops/linear_op.py
@@ -21,6 +21,7 @@
 ├── CustomColumnParallelOp
 │   ├── MLPColumnParallelOp
 │   ├── SequenceColumnParallelOp
+│   ├── Flashcomm2OshardQKVParallelOp
 └── CustomRowParallelOp
 │   ├── MLPRowParallelOp
 │   ├── OProjRowParallelOp
@@ -60,6 +61,7 @@
                                                     get_flashcomm2_otp_group,
                                                     get_mlp_tp_group,
                                                     get_otp_group)
+from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager
 from vllm_ascend.utils import (enable_dsa_cp, enable_sp, flashcomm2_enable,
                                get_flashcomm2_reorgnized_batch_ids,
                                matmul_allreduce_enable, mlp_tp_enable,
@@ -400,6 +402,9 @@ def update_attrs(self):
         super().update_attrs()
         self.input_is_parallel = self.layer.input_is_parallel
         self.input_size_per_partition = self.layer.input_size_per_partition
+        if flashcomm2_oshard_manager.flashcomm2_oshard_enable():
+            flashcomm2_oshard_manager.register_layer(self.layer,
+                                                     prefetch_step=1)
 
 
 class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
@@ -479,6 +484,39 @@ def apply_impl(
         return output, output_bias
 
 
+class Flashcomm2OshardQKVParallelOp(CustomColumnParallelOp):
+
+    def __init__(self, layer):
+        super().__init__(layer)
+
+    def apply_impl(
+        self, input_: torch.Tensor
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        """Column-parallel linear with FlashComm2 O-Shard optimization."""
+
+        bias = self.bias if not self.skip_bias_add else None
+
+        assert self.quant_method is not None
+
+        if enable_sp():
+            input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+                input_, True)
+
+        # Trigger async broadcast before matmul to overlap communication.
+        flashcomm2_oshard_manager.trigger_broadcast_for_layer(
+            self.layer.prefix)
+
+        # Matrix multiply.
+        output_parallel = self.quant_method.apply(self.layer, input_, bias)
+        if self.gather_output and self.tp_size > 1:
+            # All-gather across the partitions.
+            output = self.comm_group.all_gather(output_parallel)
+        else:
+            output = output_parallel
+        output_bias = self.bias if self.skip_bias_add else None
+        return output, output_bias
+
+
 class SequenceRowParallelOp(CustomRowParallelOp):
 
     def __init__(self, layer):
@@ -657,12 +695,15 @@ def apply_impl(
 def _get_column_parallel_op(
     prefix, layer
 ) -> Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp,
-                    ShardedCPColumnParallelOp]]:
+                    ShardedCPColumnParallelOp, Flashcomm2OshardQKVParallelOp]]:
     if enable_dsa_cp() and ("q_b_proj" in prefix or "kv_b_proj" in prefix):
         return ShardedCPColumnParallelOp(layer)
     if "gate_up_proj" in prefix and mlp_tp_enable(
     ) and not is_moe_layer(prefix):
         return MLPColumnParallelOp(layer)
+    if flashcomm2_oshard_manager.flashcomm2_oshard_enable():
+        if any(p in prefix for p in ("qkv_proj", "conv1d", "query_key_value")):
+            return Flashcomm2OshardQKVParallelOp(layer)
     if enable_sp():
         if "shared_expert" in prefix:
             return None
@@ -719,6 +760,7 @@ def get_parallel_op(disable_tp, prefix, layer, direct):
     custom_op: Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp,
                               MLPRowParallelOp, OProjRowParallelOp,
                               Flashcomm2OProjRowParallelOp,
+                              Flashcomm2OshardQKVParallelOp,
                               MatmulAllreduceRowParallelOp,
                               SequenceRowParallelOp, ShardedCPRowParallelOp,
                               ShardedCPColumnParallelOp]] = None
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 9874661ecc6..0000d696498 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -1017,6 +1017,13 @@ def flashcomm2_enable() -> bool:
     return envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE > 0
 
 
+def o_shard_enable() -> bool:
+    layer_sharding = get_ascend_config().layer_sharding
+    if layer_sharding is None:
+        return False
+    return "o_proj" in layer_sharding
+
+
 def get_flashcomm2_config_and_validate(ascend_config, vllm_config):
     flashcomm2_oproj_tp_size = envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE
     global_tp_size = vllm_config.parallel_config.tensor_parallel_size
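
For reviewers: below is a minimal end-to-end sketch of how the feature in this
patch is enabled, mirroring test_qwen3_moe_fc2_oshard_tp2 above. It assumes two
Ascend NPUs and that vllm.LLM forwards the same engine arguments the test
passes through VllmRunner; the model and sampling values are illustrative, not
requirements.

    import os

    # Mirrors the @patch.dict decorators on the new test: a
    # VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE > 0 makes flashcomm2_enable()
    # return True.
    os.environ["VLLM_ASCEND_ENABLE_FLASHCOMM1"] = "1"
    os.environ["VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE"] = "1"

    from vllm import LLM, SamplingParams

    # "o_proj" in additional_config["layer_sharding"] makes o_shard_enable()
    # return True; together with the env var above this activates
    # flashcomm2_oshard_manager (see flashcomm2_oshard_enable()).
    llm = LLM(
        model="Qwen/Qwen3-30B-A3B",
        tensor_parallel_size=2,
        enable_expert_parallel=True,
        enforce_eager=True,  # graph mode for fc2 is still a TODO (see test)
        additional_config={"layer_sharding": ["o_proj"]},
    )

    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=5, temperature=0.0))
    print(outputs[0].outputs[0].text)

With both switches on, OProjRowParallelOp.update_attrs() registers each o_proj
layer with the manager, Flashcomm2OshardQKVParallelOp kicks off the async
weight broadcast just before the QKV matmul so communication overlaps compute,
and process_weights_after_loading() runs the one-time shard post-processing.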