From 56c288dd3b94c7a730e13b84eff1b228ea5411a9 Mon Sep 17 00:00:00 2001
From: Levi-JQ
Date: Mon, 15 Dec 2025 16:13:39 +0800
Subject: [PATCH 1/9] Generalize flashcomm2 + oshard

Signed-off-by: Levi-JQ
---
 vllm_ascend/attention/attention_v1.py |  8 +++-
 vllm_ascend/ops/linear_op.py          | 55 ++++++++++++++++++++++++++-
 vllm_ascend/utils.py                  | 16 ++++++++
 3 files changed, 76 insertions(+), 3 deletions(-)

diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index d19d3369b99..2e2491837a5 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -44,7 +44,8 @@
     get_draft_graph_params, get_graph_params,
     update_draft_graph_params_workspaces, update_graph_params_workspaces)
 from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
-                               weak_ref_tensors)
+                               weak_ref_tensors, flashcomm2_o_shared_enabled, get_flashcomm2_o_shard_layer)
+from vllm_ascend.ops.shared_weight_layer import post_process_after_loading_for_shared_weight_series
 
 # default max value of sliding window size
 SWA_INT_MAX = 2147483647
@@ -349,6 +350,11 @@ def __init__(
         self.value_cache = None
         self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
 
+    def process_weights_after_loading(self, act_dtype: torch.dtype):
+        super().process_weights_after_loading(act_dtype)
+        if flashcomm2_o_shared_enabled():
+            post_process_after_loading_for_shared_weight_series(get_flashcomm2_o_shard_layer())
+
     def full_graph_fia(self, query: torch.Tensor, key: torch.Tensor,
                        value: torch.Tensor, attn_metadata: AscendMetadata,
                        output: torch.Tensor) -> torch.Tensor:
diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py
index ef612ce9282..19816ac69fe 100644
--- a/vllm_ascend/ops/linear_op.py
+++ b/vllm_ascend/ops/linear_op.py
@@ -55,6 +55,7 @@
 from vllm.forward_context import get_forward_context
 
 from vllm_ascend import envs as envs_ascend
+from vllm.model_executor.models.utils import extract_layer_index
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import (get_flashcomm2_odp_group,
                                                     get_flashcomm2_otp_group,
@@ -63,8 +64,12 @@
 from vllm_ascend.utils import (enable_dsa_cp, enable_sp, flashcomm2_enable,
                                get_flashcomm2_reorgnized_batch_ids,
                                matmul_allreduce_enable, mlp_tp_enable,
-                               oproj_tp_enable, shared_expert_dp_enabled)
-
+                               oproj_tp_enable, register_flashcomm2_o_shard_layer, shared_expert_dp_enabled)
+from vllm_ascend.ops.shared_weight_layer import (
+    is_hidden_layer,
+    reach_layer_for_shared_weight_series,
+    register_layer_to_shared_weight_series)
+from vllm.config import get_current_vllm_config
 
 class CustomLinearOp:
 
@@ -400,6 +405,14 @@ def update_attrs(self):
         super().update_attrs()
         self.input_is_parallel = self.layer.input_is_parallel
         self.input_size_per_partition = self.layer.input_size_per_partition
+        if flashcomm2_o_shared_enabled() and is_hidden_layer(get_current_vllm_config(), register_flashcomm2_o_shard_layer(self.layer)):
+            from vllm_ascend.distributed.parallel_state import \
+                get_shared_weight_group
+            register_layer_to_shared_weight_series(
+                series_name="o_proj",
+                group=get_shared_weight_group(),
+                layer=self.layer,
+                prefetch_step=1)
 
 
 class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
@@ -477,6 +490,41 @@ def apply_impl(
         output = output_parallel
         output_bias = self.bias if self.skip_bias_add else None
         return output, output_bias
+
+# When flashcomm2 enables oshard, this op triggers an async broadcast before the QKV matmul so the communication overlaps with computation
+class Flashcomm2OshardQKVParallelOp(CustomColumnParallelOp):
+
+    def apply_impl(
+        self, input_: torch.Tensor
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        """Linear layer with column parallelism.
+
+        Implemented multiple optimization projects for dense models, such as FlashComm and
+        communication-computation fusion.
+        """
+
+        bias = self.bias if not self.skip_bias_add else None
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+
+        input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
+        layer_idx = extract_layer_index(self.layer.prefix)
+        if flashcomm2_o_shared_enabled() and is_hidden_layer(get_current_vllm_config(), get_flashcomm2_o_shard_layer(layer_idx)):
+            reach_layer_for_shared_weight_series(get_flashcomm2_o_shard_layer(layer_idx))
+        # if flashcomm2_o_shared_enable():
+        #     from vllm_ascend.multistream.context import get_multistream_microbatch_context
+        #     if get_multistream_microbatch_context() != 0:
+        #         #TODO: build an Oshard adapter layer with a few key hooks: post-init processing; post-weight-loading processing; the async broadcast call
+        #         reach_layer_for_shared_weight_series(self.o_proj)
+        output_parallel = self.quant_method.apply(self.layer, input_, bias)
+        if self.gather_output:
+            # All-gather across the partitions.
+            output = self.comm_group.all_gather(output_parallel)
+        else:
+            output = output_parallel
+        output_bias = self.bias if self.skip_bias_add else None
+        return output, output_bias
 
 
 class SequenceRowParallelOp(CustomRowParallelOp):
@@ -673,6 +721,9 @@ def _get_column_parallel_op(
         "conv1d",  # gated deltanet of Qwen3 Next
         "query_key_value",  # qkv linear of Bailing
     ]
+    if flashcomm2_enable() and flashcomm2_o_shared_enabled():
+        if "qkv_proj" in prefix:
+            return Flashcomm2OshardQKVParallelOp(layer)
     for a_prefix in sp_column_prefix:
         if a_prefix in prefix:
             return SequenceColumnParallelOp(layer)
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 9874661ecc6..2451cdb654d 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -1017,6 +1017,22 @@ def flashcomm2_enable() -> bool:
     return envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE > 0
 
 
+def flashcomm2_o_shared_enabled() -> bool:
+    return envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED
+
+def register_flashcomm2_o_shard_layer(layer = None):
+    global _FLASHCOMM2_OSHARD_LAYER
+    # layer_idx = extract_layer_index(layer.prefix)
+    _FLASHCOMM2_OSHARD_LAYER.append(layer)
+    # if _FLASHCOMM2_OSHARD_LAYER is None:
+    #     _FLASHCOMM2_OSHARD_LAYER = layer
+    # assert _FLASHCOMM2_OSHARD_LAYER is not None, f"_FLASHCOMM2_OSHARD_LAYER is not init, please make sure that you input a valid layer parameter"
+    return layer
+
+def get_flashcomm2_o_shard_layer(layer_idx = 0):
+    global _FLASHCOMM2_OSHARD_LAYER
+    return _FLASHCOMM2_OSHARD_LAYER[layer_idx]
+
 def get_flashcomm2_config_and_validate(ascend_config, vllm_config):
     flashcomm2_oproj_tp_size = envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE
     global_tp_size = vllm_config.parallel_config.tensor_parallel_size

From 90b2bf465cc4edabb3c231e90b107b6e843b967b Mon Sep 17 00:00:00 2001
From: Levi-JQ
Date: Mon, 15 Dec 2025 16:53:18 +0800
Subject: [PATCH 2/9] fix get vllm_config

Signed-off-by: Levi-JQ
---
 vllm_ascend/ops/linear_op.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py
index 19816ac69fe..09efb935ed3 100644
--- a/vllm_ascend/ops/linear_op.py
+++ b/vllm_ascend/ops/linear_op.py
@@ -288,6 +288,7 @@ def __init__(self, layer):
                                  get_tp_group().world_size)
         self.group_indices = torch.tensor(self.reorgnized_batch_ids).npu()
         self.layer._quant_comm_config = {}
+        self.vllm_config = get_current_vllm_config()
 
     @property
     def comm_group(self):
@@ -405,7 +406,7 @@ def update_attrs(self):
         super().update_attrs()
         self.input_is_parallel = self.layer.input_is_parallel
         self.input_size_per_partition = self.layer.input_size_per_partition
-        if flashcomm2_o_shared_enabled() and is_hidden_layer(get_current_vllm_config(), register_flashcomm2_o_shard_layer(self.layer)):
+        if flashcomm2_o_shared_enabled() and is_hidden_layer(self.vllm_config, register_flashcomm2_o_shard_layer(self.layer)):
             from vllm_ascend.distributed.parallel_state import \
                 get_shared_weight_group
             register_layer_to_shared_weight_series(
@@ -494,6 +495,10 @@ def apply_impl(
 
 # When flashcomm2 enables oshard, this op triggers an async broadcast before the QKV matmul so the communication overlaps with computation
 class Flashcomm2OshardQKVParallelOp(CustomColumnParallelOp):
 
+    def __init__(self, layer):
+        super().__init__(layer)
+        self.vllm_config = get_current_vllm_config()
+
     def apply_impl(
         self, input_: torch.Tensor
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
@@ -510,7 +515,7 @@ def apply_impl(
 
         input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
         layer_idx = extract_layer_index(self.layer.prefix)
-        if flashcomm2_o_shared_enabled() and is_hidden_layer(get_current_vllm_config(), get_flashcomm2_o_shard_layer(layer_idx)):
+        if flashcomm2_o_shared_enabled() and is_hidden_layer(self.vllm_config, get_flashcomm2_o_shard_layer(layer_idx)):
             reach_layer_for_shared_weight_series(get_flashcomm2_o_shard_layer(layer_idx))
         # if flashcomm2_o_shared_enable():
         #     from vllm_ascend.multistream.context import get_multistream_microbatch_context

From d83c0e4e3b1d787587db8ffc8f000b4b6e67031c Mon Sep 17 00:00:00 2001
From: Levi-JQ
Date: Mon, 15 Dec 2025 18:04:40 +0800
Subject: [PATCH 3/9] fix ci

Signed-off-by: Levi-JQ
---
 vllm_ascend/attention/attention_v1.py |  5 +++--
 vllm_ascend/ops/linear_op.py          | 29 ++++++++++++++++-----------
 vllm_ascend/utils.py                  |  7 +++++--
 3 files changed, 25 insertions(+), 16 deletions(-)

diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 2e2491837a5..b34d45539f4 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -45,7 +45,7 @@
     update_draft_graph_params_workspaces, update_graph_params_workspaces)
 from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
                                weak_ref_tensors, flashcomm2_o_shared_enabled, get_flashcomm2_o_shard_layer)
-from vllm_ascend.ops.shared_weight_layer import post_process_after_loading_for_shared_weight_series
+from vllm_ascend.ops.layer_shard_linear import post_process_after_loading_for_shard_weight_series
 
 # default max value of sliding window size
 SWA_INT_MAX = 2147483647
@@ -353,7 +353,8 @@ def __init__(
     def process_weights_after_loading(self, act_dtype: torch.dtype):
         super().process_weights_after_loading(act_dtype)
         if flashcomm2_o_shared_enabled():
-            post_process_after_loading_for_shared_weight_series(get_flashcomm2_o_shard_layer())
+            post_process_after_loading_for_shard_weight_series(
+                get_flashcomm2_o_shard_layer())
 
     def full_graph_fia(self, query: torch.Tensor, key: torch.Tensor,
                        value: torch.Tensor, attn_metadata: AscendMetadata,
diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py
index 09efb935ed3..981e50ce2d6 100644
--- a/vllm_ascend/ops/linear_op.py
+++ b/vllm_ascend/ops/linear_op.py
@@ -48,6 +48,7 @@
 from torch import nn
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
+from vllm.config import get_current_vllm_config
 from vllm.distributed import (split_tensor_along_last_dim,
                               tensor_model_parallel_all_reduce,
                               tensor_model_parallel_reduce_scatter)
@@ -56,6 +57,7 @@
 
 from vllm_ascend import envs as envs_ascend
 from vllm.model_executor.models.utils import extract_layer_index
+
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import (get_flashcomm2_odp_group,
                                                     get_flashcomm2_otp_group,
@@ -64,12 +66,10 @@
 from vllm_ascend.utils import (enable_dsa_cp, enable_sp, flashcomm2_enable,
                                get_flashcomm2_reorgnized_batch_ids,
                                matmul_allreduce_enable, mlp_tp_enable,
-                               oproj_tp_enable, register_flashcomm2_o_shard_layer, shared_expert_dp_enabled)
-from vllm_ascend.ops.shared_weight_layer import (
-    is_hidden_layer,
-    reach_layer_for_shared_weight_series,
-    register_layer_to_shared_weight_series)
-from vllm.config import get_current_vllm_config
+                               oproj_tp_enable,
+                               register_flashcomm2_o_shard_layer,
+                               shared_expert_dp_enabled)
+
 
 class CustomLinearOp:
 
@@ -406,7 +406,7 @@ def update_attrs(self):
         super().update_attrs()
         self.input_is_parallel = self.layer.input_is_parallel
         self.input_size_per_partition = self.layer.input_size_per_partition
-        if flashcomm2_o_shared_enabled() and is_hidden_layer(self.vllm_config, register_flashcomm2_o_shard_layer(self.layer)):
+        if flashcomm2_o_shared_enabled() and is_hidden_layer(
+                self.vllm_config, register_flashcomm2_o_shard_layer(
+                    self.layer)):
             from vllm_ascend.distributed.parallel_state import \
                 get_shared_weight_group
-            register_layer_to_shared_weight_series(
+            register_layer_to_shard_weight_series(
                 series_name="o_proj",
                 group=get_shared_weight_group(),
                 layer=self.layer,
                 prefetch_step=1)
@@ -491,7 +493,8 @@ def apply_impl(
         output = output_parallel
         output_bias = self.bias if self.skip_bias_add else None
         return output, output_bias
-    
+
+
 # When flashcomm2 enables oshard, this op triggers an async broadcast before the QKV matmul so the communication overlaps with computation
 class Flashcomm2OshardQKVParallelOp(CustomColumnParallelOp):
 
     def __init__(self, layer):
@@ -515,13 +518,15 @@ def apply_impl(
 
         input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
         layer_idx = extract_layer_index(self.layer.prefix)
-        if flashcomm2_o_shared_enabled() and is_hidden_layer(self.vllm_config, get_flashcomm2_o_shard_layer(layer_idx)):
-            reach_layer_for_shared_weight_series(get_flashcomm2_o_shard_layer(layer_idx))
+        if flashcomm2_o_shared_enabled() and is_hidden_layer(
+                self.vllm_config, get_flashcomm2_o_shard_layer(layer_idx)):
+            reach_layer_for_shard_weight_series(
+                get_flashcomm2_o_shard_layer(layer_idx))
         # if flashcomm2_o_shared_enable():
         #     from vllm_ascend.multistream.context import get_multistream_microbatch_context
         #     if get_multistream_microbatch_context() != 0:
         #         #TODO: build an Oshard adapter layer with a few key hooks: post-init processing; post-weight-loading processing; the async broadcast call
-        #         reach_layer_for_shared_weight_series(self.o_proj)
+        #         reach_layer_for_shard_weight_series(self.o_proj)
         output_parallel = self.quant_method.apply(self.layer, input_, bias)
         if self.gather_output:
             # All-gather across the partitions.
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 2451cdb654d..73d993958d9 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -1020,7 +1020,8 @@ def flashcomm2_enable() -> bool: def flashcomm2_o_shared_enabled() -> bool: return envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED -def register_flashcomm2_o_shard_layer(layer = None): + +def register_flashcomm2_o_shard_layer(layer=None): global _FLASHCOMM2_OSHARD_LAYER # layer_idx = extract_layer_index(layer.prefix) _FLASHCOMM2_OSHARD_LAYER.append(layer) @@ -1029,10 +1030,12 @@ def register_flashcomm2_o_shard_layer(layer = None): # assert _FLASHCOMM2_OSHARD_LAYER is not None, f"_FLASHCOMM2_OSHARD_LAYER is not init, please make sure that you input a valid layer parameter" return layer -def get_flashcomm2_o_shard_layer(layer_idx = 0): + +def get_flashcomm2_o_shard_layer(layer_idx=0): global _FLASHCOMM2_OSHARD_LAYER return _FLASHCOMM2_OSHARD_LAYER[layer_idx] + def get_flashcomm2_config_and_validate(ascend_config, vllm_config): flashcomm2_oproj_tp_size = envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE global_tp_size = vllm_config.parallel_config.tensor_parallel_size From 352788adf596ab9612391da63ca568f2b1f94935 Mon Sep 17 00:00:00 2001 From: Levi-JQ Date: Tue, 16 Dec 2025 20:50:10 +0800 Subject: [PATCH 4/9] extract flashcomm2_oshard_manager Signed-off-by: Levi-JQ --- vllm_ascend/attention/attention_v1.py | 1 + vllm_ascend/ops/linear_op.py | 47 +++++------- vllm_ascend/ops/utils.py | 101 ++++++++++++++++++++++++++ vllm_ascend/utils.py | 15 ---- 4 files changed, 119 insertions(+), 45 deletions(-) create mode 100644 vllm_ascend/ops/utils.py diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index b34d45539f4..897019ed9d0 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -49,6 +49,7 @@ # default max value of sliding window size SWA_INT_MAX = 2147483647 +from vllm_ascend.ops.utils import flashcomm2_oshard_manager @register_backend(AttentionBackendEnum.CUSTOM, "ASCEND") diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py index 981e50ce2d6..b16f4440875 100644 --- a/vllm_ascend/ops/linear_op.py +++ b/vllm_ascend/ops/linear_op.py @@ -21,6 +21,7 @@ ├── CustomColumnParallelOp │ ├── MLPColumnParallelOp │ ├── SequenceColumnParallelOp +│ ├── Flashcomm2OshardQKVParallelOp └── CustomRowParallelOp │ ├── MLPRowParallelOp │ ├── OProjRowParallelOp @@ -66,9 +67,8 @@ from vllm_ascend.utils import (enable_dsa_cp, enable_sp, flashcomm2_enable, get_flashcomm2_reorgnized_batch_ids, matmul_allreduce_enable, mlp_tp_enable, - oproj_tp_enable, - register_flashcomm2_o_shard_layer, - shared_expert_dp_enabled) + oproj_tp_enable, shared_expert_dp_enabled) +from vllm_ascend.ops.utils import flashcomm2_oshard_manager class CustomLinearOp: @@ -406,16 +406,10 @@ def update_attrs(self): super().update_attrs() self.input_is_parallel = self.layer.input_is_parallel self.input_size_per_partition = self.layer.input_size_per_partition - if flashcomm2_o_shared_enabled() and is_hidden_layer( - self.vllm_config, register_flashcomm2_o_shard_layer( - self.layer)): - from vllm_ascend.distributed.parallel_state import \ - get_shared_weight_group - register_layer_to_shard_weight_series( - series_name="o_proj", - group=get_shared_weight_group(), - layer=self.layer, - prefetch_step=1) + if flashcomm2_o_shared_enabled(): + flashcomm2_oshard_manager.register_layer(self.layer, + self.vllm_config, + prefetch_step=1) class MatmulAllreduceRowParallelOp(CustomRowParallelOp): 
@@ -495,7 +489,6 @@ def apply_impl(
         return output, output_bias
 
 
-# When flashcomm2 enables oshard, this op triggers an async broadcast before the QKV matmul so the communication overlaps with computation
 class Flashcomm2OshardQKVParallelOp(CustomColumnParallelOp):
 
     def __init__(self, layer):
@@ -505,11 +498,7 @@ def __init__(self, layer):
     def apply_impl(
         self, input_: torch.Tensor
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
-        """Linear layer with column parallelism.
-
-        Implemented multiple optimization projects for dense models, such as FlashComm and
-        communication-computation fusion.
-        """
+        """Column-parallel linear with FlashComm2 OShard optimization."""
 
         bias = self.bias if not self.skip_bias_add else None
 
@@ -517,16 +506,12 @@ def apply_impl(
         assert self.quant_method is not None
 
         input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
-        layer_idx = extract_layer_index(self.layer.prefix)
-        if flashcomm2_o_shared_enabled() and is_hidden_layer(
-                self.vllm_config, get_flashcomm2_o_shard_layer(layer_idx)):
-            reach_layer_for_shard_weight_series(
-                get_flashcomm2_o_shard_layer(layer_idx))
-        # if flashcomm2_o_shared_enable():
-        #     from vllm_ascend.multistream.context import get_multistream_microbatch_context
-        #     if get_multistream_microbatch_context() != 0:
-        #         #TODO: build an Oshard adapter layer with a few key hooks: post-init processing; post-weight-loading processing; the async broadcast call
-        #         reach_layer_for_shard_weight_series(self.o_proj)
+
+        # Trigger async broadcast before matmul to overlap communication.
+        if flashcomm2_o_shared_enabled():
+            flashcomm2_oshard_manager.trigger_broadcast_for_layer(
+                self.layer.prefix, self.vllm_config)
+
         output_parallel = self.quant_method.apply(self.layer, input_, bias)
         if self.gather_output:
             # All-gather across the partitions.
@@ -732,7 +717,8 @@ def apply_impl(
         "query_key_value",  # qkv linear of Bailing
     ]
     if flashcomm2_enable() and flashcomm2_o_shared_enabled():
-        if "qkv_proj" in prefix:
+        if any(p in prefix
+               for p in ("qkv_proj", "conv1d", "query_key_value")):
             return Flashcomm2OshardQKVParallelOp(layer)
     for a_prefix in sp_column_prefix:
         if a_prefix in prefix:
             return SequenceColumnParallelOp(layer)
@@ -780,6 +766,7 @@ def get_parallel_op(disable_tp, prefix, layer, direct):
     custom_op: Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp,
                               MLPRowParallelOp, OProjRowParallelOp,
                               Flashcomm2OProjRowParallelOp,
+                              Flashcomm2OshardQKVParallelOp,
                               MatmulAllreduceRowParallelOp,
                               SequenceRowParallelOp, ShardedCPRowParallelOp,
                               ShardedCPColumnParallelOp]] = None
diff --git a/vllm_ascend/ops/utils.py b/vllm_ascend/ops/utils.py
new file mode 100644
index 00000000000..efdebde87e2
--- /dev/null
+++ b/vllm_ascend/ops/utils.py
@@ -0,0 +1,101 @@
+from vllm_ascend.ops.shared_weight_layer import (
+    is_hidden_layer, post_process_after_loading_for_shared_weight_series,
+    reach_layer_for_shared_weight_series,
+    register_layer_to_shared_weight_series)
+from vllm.model_executor.models.utils import extract_layer_index
+from vllm_ascend.distributed.parallel_state import get_shared_weight_group
+from typing import Dict, Any, Optional
+
+
+class Flashcomm2OShardManager:
+    """Manages sharded layers for the FlashComm2 O-Shard feature.
+
+    This class centralizes all logic related to Flashcomm2 O-Shard layers.
+    Its main responsibilities are:
+    1. Registering Attention `o_proj` layers that require O-Sharding.
+    2. Storing and managing these layers in a dictionary mapping layer indices
+       to layer objects (`layer_index -> layer`).
+    3. Providing a high-level API for external callers to use at key stages
+       like model initialization, computation, and weight loading.
+ + Attributes: + _shard_layers: A dictionary to store the registered sharded layers, + mapping a layer index (int) to its corresponding layer object. + """ + + def __init__(self): + self._shard_layers: Dict[int, Any] = {} + + def register_layer(self, + layer: Any, + vllm_config: Any, + prefetch_step: int = 1): + """Registers a layer for O-Sharding. + + This method first checks if the O-Shard feature is enabled and if the + provided layer qualifies as a target (e.g., a hidden layer). If so, + it performs two actions: + 1. Caches the layer internally in the `_shard_layers` dictionary. + 2. Calls the underlying `register_layer_to_shared_weight_series` + function to register it for communication. + + Args: + layer: The layer object to be registered. + vllm_config: The vLLM model configuration object, used to determine + if the layer is a target for sharding. + prefetch_step: The prefetch step to be used when registering the + layer to the shared weight series. + """ + # Check if the layer is a target for sharding. + if is_hidden_layer(vllm_config, layer): + layer_idx = extract_layer_index(layer.prefix) + self._shard_layers[layer_idx] = layer + + register_layer_to_shared_weight_series( + series_name="o_proj", + group=get_shared_weight_group(), + layer=layer, + prefetch_step=prefetch_step) + + def get_layer(self, layer_idx: int) -> Optional[Any]: + """Safely retrieves a registered layer by its index. + + Args: + layer_idx: The index of the layer to retrieve. + + Returns: + The layer object if found, otherwise None. + """ + return self._shard_layers.get(layer_idx) + + def trigger_broadcast_for_layer(self, layer_prefix: str, vllm_config: Any): + """Triggers a broadcast for a specific layer during model computation. + + This method is intended to be called within a layer's forward pass. + It extracts the layer index from the prefix, retrieves the corresponding + registered layer object, and then triggers the broadcast operation + if all conditions are met. + + Args: + layer_prefix: The name prefix of the current layer being computed. + vllm_config: The vLLM model configuration object. + """ + layer_idx = extract_layer_index(layer_prefix) + target_layer = self.get_layer(layer_idx) + + # Ensure the layer exists and meets the sharding criteria. + if target_layer and is_hidden_layer(vllm_config, target_layer): + reach_layer_for_shared_weight_series(target_layer) + + def post_process_after_loading(self): + """Performs post-processing on all registered layers after weight loading. + + This should be called once after the model weights have been fully loaded. 
+ """ + # Iterate through all registered layers to preform post_process_after_loading + for layer_idx in sorted(self._shard_layers.keys()): + layer = self._shard_layers[layer_idx] + post_process_after_loading_for_shared_weight_series(layer) + + +flashcomm2_oshard_manager = Flashcomm2OShardManager() diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 73d993958d9..27235b867f6 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -1021,21 +1021,6 @@ def flashcomm2_o_shared_enabled() -> bool: return envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED -def register_flashcomm2_o_shard_layer(layer=None): - global _FLASHCOMM2_OSHARD_LAYER - # layer_idx = extract_layer_index(layer.prefix) - _FLASHCOMM2_OSHARD_LAYER.append(layer) - # if _FLASHCOMM2_OSHARD_LAYER is None: - # _FLASHCOMM2_OSHARD_LAYER = layer - # assert _FLASHCOMM2_OSHARD_LAYER is not None, f"_FLASHCOMM2_OSHARD_LAYER is not init, please make sure that you input a valid layer parameter" - return layer - - -def get_flashcomm2_o_shard_layer(layer_idx=0): - global _FLASHCOMM2_OSHARD_LAYER - return _FLASHCOMM2_OSHARD_LAYER[layer_idx] - - def get_flashcomm2_config_and_validate(ascend_config, vllm_config): flashcomm2_oproj_tp_size = envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE global_tp_size = vllm_config.parallel_config.tensor_parallel_size From 851afb3c337727abd9eaf1b693955df8f54c49b2 Mon Sep 17 00:00:00 2001 From: Levi-JQ Date: Tue, 16 Dec 2025 21:11:21 +0800 Subject: [PATCH 5/9] fix ci Signed-off-by: Levi-JQ --- vllm_ascend/ops/linear_op.py | 1 - vllm_ascend/ops/utils.py | 10 ++++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py index b16f4440875..9d714553ac1 100644 --- a/vllm_ascend/ops/linear_op.py +++ b/vllm_ascend/ops/linear_op.py @@ -68,7 +68,6 @@ get_flashcomm2_reorgnized_batch_ids, matmul_allreduce_enable, mlp_tp_enable, oproj_tp_enable, shared_expert_dp_enabled) -from vllm_ascend.ops.utils import flashcomm2_oshard_manager class CustomLinearOp: diff --git a/vllm_ascend/ops/utils.py b/vllm_ascend/ops/utils.py index efdebde87e2..93ae671c420 100644 --- a/vllm_ascend/ops/utils.py +++ b/vllm_ascend/ops/utils.py @@ -1,10 +1,12 @@ +from typing import Any, Dict, Optional + +from vllm.model_executor.models.utils import extract_layer_index + +from vllm_ascend.distributed.parallel_state import get_shared_weight_group from vllm_ascend.ops.shared_weight_layer import ( is_hidden_layer, post_process_after_loading_for_shared_weight_series, reach_layer_for_shared_weight_series, register_layer_to_shared_weight_series) -from vllm.model_executor.models.utils import extract_layer_index -from vllm_ascend.distributed.parallel_state import get_shared_weight_group -from typing import Dict, Any, Optional class Flashcomm2OShardManager: @@ -92,7 +94,7 @@ def post_process_after_loading(self): This should be called once after the model weights have been fully loaded. 
""" - # Iterate through all registered layers to preform post_process_after_loading + # Iterate through all registered layers to perform post_process_after_loading for layer_idx in sorted(self._shard_layers.keys()): layer = self._shard_layers[layer_idx] post_process_after_loading_for_shared_weight_series(layer) From 4f8816fd0130fa9111d8561f50ecde2f51b40cfd Mon Sep 17 00:00:00 2001 From: Levi-JQ Date: Wed, 17 Dec 2025 12:14:00 +0800 Subject: [PATCH 6/9] fix review && fix Flashcomm2OshardQKVParallelOp to adapt not enable_sp() Signed-off-by: Levi-JQ --- vllm_ascend/attention/attention_v1.py | 5 ++--- vllm_ascend/ops/linear_op.py | 20 ++++++++++---------- vllm_ascend/ops/utils.py | 14 +++++++++----- 3 files changed, 21 insertions(+), 18 deletions(-) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 897019ed9d0..2e6b297c07a 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -353,9 +353,8 @@ def __init__( def process_weights_after_loading(self, act_dtype: torch.dtype): super().process_weights_after_loading(act_dtype) - if flashcomm2_o_shared_enabled(): - post_process_after_loading_for_shard_weight_series( - get_flashcomm2_o_shard_layer()) + if flashcomm2_oshard_manager.flashcomm2_oshard_enable(): + flashcomm2_oshard_manager.post_process_after_loading() def full_graph_fia(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_metadata: AscendMetadata, diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py index 9d714553ac1..e820535f747 100644 --- a/vllm_ascend/ops/linear_op.py +++ b/vllm_ascend/ops/linear_op.py @@ -405,7 +405,7 @@ def update_attrs(self): super().update_attrs() self.input_is_parallel = self.layer.input_is_parallel self.input_size_per_partition = self.layer.input_size_per_partition - if flashcomm2_o_shared_enabled(): + if flashcomm2_oshard_manager.flashcomm2_oshard_enable(): flashcomm2_oshard_manager.register_layer(self.layer, self.vllm_config, prefetch_step=1) @@ -504,15 +504,16 @@ def apply_impl( # Matrix multiply. assert self.quant_method is not None - input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True) + if enable_sp(): + input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( + input_, True) # Trigger async broadcast before matmul to overlap communication. - if flashcomm2_o_shared_enabled(): - flashcomm2_oshard_manager.trigger_broadcast_for_layer( - self.layer.prefix, self.vllm_config) + flashcomm2_oshard_manager.trigger_broadcast_for_layer( + self.layer.prefix, self.vllm_config) output_parallel = self.quant_method.apply(self.layer, input_, bias) - if self.gather_output: + if self.gather_output and self.tp_size > 1: # All-gather across the partitions. 
output = self.comm_group.all_gather(output_parallel) else: @@ -705,6 +706,9 @@ def _get_column_parallel_op( if "gate_up_proj" in prefix and mlp_tp_enable( ) and not is_moe_layer(prefix): return MLPColumnParallelOp(layer) + if flashcomm2_oshard_manager.flashcomm2_oshard_enable(): + if any(p in prefix for p in ("qkv_proj", "conv1d", "query_key_value")): + return Flashcomm2OshardQKVParallelOp(layer) if enable_sp(): if "shared_expert" in prefix: return None @@ -715,10 +719,6 @@ def _get_column_parallel_op( "conv1d", # gated deltanet of Qwen3 Next "query_key_value", # qkv linear of Bailing ] - if flashcomm2_enable() and flashcomm2_o_shared_enabled(): - if any(p in prefix - for p in ("qkv_proj", "conv1d", "query_key_value")): - return Flashcomm2OshardQKVParallelOp(layer) for a_prefix in sp_column_prefix: if a_prefix in prefix: return SequenceColumnParallelOp(layer) diff --git a/vllm_ascend/ops/utils.py b/vllm_ascend/ops/utils.py index 93ae671c420..70c2e00b9cb 100644 --- a/vllm_ascend/ops/utils.py +++ b/vllm_ascend/ops/utils.py @@ -2,7 +2,8 @@ from vllm.model_executor.models.utils import extract_layer_index -from vllm_ascend.distributed.parallel_state import get_shared_weight_group +from vllm_ascend.distributed.parallel_state import ( + flashcomm2_enable, flashcomm2_o_shared_enabled, get_shared_weight_group) from vllm_ascend.ops.shared_weight_layer import ( is_hidden_layer, post_process_after_loading_for_shared_weight_series, reach_layer_for_shared_weight_series, @@ -28,6 +29,9 @@ class Flashcomm2OShardManager: def __init__(self): self._shard_layers: Dict[int, Any] = {} + def flashcomm2_oshard_enable(self): + return flashcomm2_enable() and flashcomm2_o_shared_enabled() + def register_layer(self, layer: Any, vllm_config: Any, @@ -94,10 +98,10 @@ def post_process_after_loading(self): This should be called once after the model weights have been fully loaded. 
""" - # Iterate through all registered layers to perform post_process_after_loading - for layer_idx in sorted(self._shard_layers.keys()): - layer = self._shard_layers[layer_idx] - post_process_after_loading_for_shared_weight_series(layer) + if self._shard_layers: + # Pick any layer (e.g., the first one) to trigger the shard post-processing + any_layer = next(iter(self._shard_layers.values())) + post_process_after_loading_for_shared_weight_series(any_layer) flashcomm2_oshard_manager = Flashcomm2OShardManager() From 0021b11c3ad4f59ae8a26d1a810de4a18fc5a65d Mon Sep 17 00:00:00 2001 From: Levi-JQ Date: Mon, 22 Dec 2025 12:04:46 +0800 Subject: [PATCH 7/9] add ut Signed-off-by: Levi-JQ --- .../test_offline_inference_distributed.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/e2e/multicard/2-cards/test_offline_inference_distributed.py b/tests/e2e/multicard/2-cards/test_offline_inference_distributed.py index c5194c63fec..f46f88b4b8e 100644 --- a/tests/e2e/multicard/2-cards/test_offline_inference_distributed.py +++ b/tests/e2e/multicard/2-cards/test_offline_inference_distributed.py @@ -157,6 +157,29 @@ def test_qwen3_moe_fc2_tp2() -> None: vllm_model.generate(example_prompts, sampling_params) +@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"}) +@patch.dict(os.environ, {"VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE": "1"}) +def test_fc2_oshard_for_qwen3_moe() -> None: + example_prompts = [ + "Hello, my name is", + ] + sampling_params = SamplingParams(max_tokens=5, + temperature=0.0, + top_k=50, + top_p=0.9) + + with VllmRunner(snapshot_download("Qwen/Qwen3-30B-A3B"), + dtype="auto", + tensor_parallel_size=2, + distributed_executor_backend="mp", + enable_expert_parallel=True, + enforce_eager=True, + additional_config={ + "layer_sharding": ["o_proj"] + }) as vllm_model: + vllm_model.generate(example_prompts, sampling_params) + + @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"}) def test_deepseek_v2_lite_fc1_tp2() -> None: example_prompts = [ From f9fedf1856d0f06908cb7c4eb27a57ea8a5bece5 Mon Sep 17 00:00:00 2001 From: Levi-JQ Date: Thu, 8 Jan 2026 16:49:30 +0800 Subject: [PATCH 8/9] adapt to refactor of oshard Signed-off-by: Levi-JQ --- .../test_offline_inference_distributed.py | 5 ++- vllm_ascend/attention/attention_v1.py | 5 ++- vllm_ascend/ops/linear_op.py | 8 +---- vllm_ascend/ops/utils.py | 35 ++++++++----------- vllm_ascend/utils.py | 7 ++-- 5 files changed, 24 insertions(+), 36 deletions(-) diff --git a/tests/e2e/multicard/2-cards/test_offline_inference_distributed.py b/tests/e2e/multicard/2-cards/test_offline_inference_distributed.py index f46f88b4b8e..890211d13fd 100644 --- a/tests/e2e/multicard/2-cards/test_offline_inference_distributed.py +++ b/tests/e2e/multicard/2-cards/test_offline_inference_distributed.py @@ -174,9 +174,8 @@ def test_fc2_oshard_for_qwen3_moe() -> None: distributed_executor_backend="mp", enable_expert_parallel=True, enforce_eager=True, - additional_config={ - "layer_sharding": ["o_proj"] - }) as vllm_model: + additional_config={"layer_sharding": + ["o_proj"]}) as vllm_model: vllm_model.generate(example_prompts, sampling_params) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 2e6b297c07a..fd4797ca4e4 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -43,13 +43,12 @@ from vllm_ascend.compilation.acl_graph import ( get_draft_graph_params, get_graph_params, update_draft_graph_params_workspaces, update_graph_params_workspaces) 
+from vllm_ascend.ops.utils import flashcomm2_oshard_manager from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type, - weak_ref_tensors, flashcomm2_o_shared_enabled, get_flashcomm2_o_shard_layer) -from vllm_ascend.ops.layer_shard_linear import post_process_after_loading_for_shard_weight_series + weak_ref_tensors) # default max value of sliding window size SWA_INT_MAX = 2147483647 -from vllm_ascend.ops.utils import flashcomm2_oshard_manager @register_backend(AttentionBackendEnum.CUSTOM, "ASCEND") diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py index e820535f747..c57a264e94f 100644 --- a/vllm_ascend/ops/linear_op.py +++ b/vllm_ascend/ops/linear_op.py @@ -49,7 +49,6 @@ from torch import nn from torch.distributed import ProcessGroup from torch.nn.parameter import Parameter -from vllm.config import get_current_vllm_config from vllm.distributed import (split_tensor_along_last_dim, tensor_model_parallel_all_reduce, tensor_model_parallel_reduce_scatter) @@ -57,8 +56,6 @@ from vllm.forward_context import get_forward_context from vllm_ascend import envs as envs_ascend -from vllm.model_executor.models.utils import extract_layer_index - from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.distributed.parallel_state import (get_flashcomm2_odp_group, get_flashcomm2_otp_group, @@ -287,7 +284,6 @@ def __init__(self, layer): get_tp_group().world_size) self.group_indices = torch.tensor(self.reorgnized_batch_ids).npu() self.layer._quant_comm_config = {} - self.vllm_config = get_current_vllm_config() @property def comm_group(self): @@ -407,7 +403,6 @@ def update_attrs(self): self.input_size_per_partition = self.layer.input_size_per_partition if flashcomm2_oshard_manager.flashcomm2_oshard_enable(): flashcomm2_oshard_manager.register_layer(self.layer, - self.vllm_config, prefetch_step=1) @@ -492,7 +487,6 @@ class Flashcomm2OshardQKVParallelOp(CustomColumnParallelOp): def __init__(self, layer): super().__init__(layer) - self.vllm_config = get_current_vllm_config() def apply_impl( self, input_: torch.Tensor @@ -510,7 +504,7 @@ def apply_impl( # Trigger async broadcast before matmul to overlap communication. 
flashcomm2_oshard_manager.trigger_broadcast_for_layer( - self.layer.prefix, self.vllm_config) + self.layer.prefix) output_parallel = self.quant_method.apply(self.layer, input_, bias) if self.gather_output and self.tp_size > 1: diff --git a/vllm_ascend/ops/utils.py b/vllm_ascend/ops/utils.py index 70c2e00b9cb..d95748e6554 100644 --- a/vllm_ascend/ops/utils.py +++ b/vllm_ascend/ops/utils.py @@ -2,12 +2,11 @@ from vllm.model_executor.models.utils import extract_layer_index -from vllm_ascend.distributed.parallel_state import ( - flashcomm2_enable, flashcomm2_o_shared_enabled, get_shared_weight_group) -from vllm_ascend.ops.shared_weight_layer import ( - is_hidden_layer, post_process_after_loading_for_shared_weight_series, - reach_layer_for_shared_weight_series, - register_layer_to_shared_weight_series) +from vllm_ascend.distributed.parallel_state import get_shard_weight_group +from vllm_ascend.ops.layer_shard_linear import ( + is_hidden_layer, post_process_after_loading_for_shard_weight_series, + reach_layer_for_shard_weight_series, register_layer_to_shard_weight_series) +from vllm_ascend.utils import flashcomm2_enable, o_shard_enable class Flashcomm2OShardManager: @@ -30,12 +29,9 @@ def __init__(self): self._shard_layers: Dict[int, Any] = {} def flashcomm2_oshard_enable(self): - return flashcomm2_enable() and flashcomm2_o_shared_enabled() + return flashcomm2_enable() and o_shard_enable() - def register_layer(self, - layer: Any, - vllm_config: Any, - prefetch_step: int = 1): + def register_layer(self, layer: Any, prefetch_step: int = 1): """Registers a layer for O-Sharding. This method first checks if the O-Shard feature is enabled and if the @@ -47,19 +43,17 @@ def register_layer(self, Args: layer: The layer object to be registered. - vllm_config: The vLLM model configuration object, used to determine - if the layer is a target for sharding. prefetch_step: The prefetch step to be used when registering the layer to the shared weight series. """ # Check if the layer is a target for sharding. - if is_hidden_layer(vllm_config, layer): + if is_hidden_layer(layer): layer_idx = extract_layer_index(layer.prefix) self._shard_layers[layer_idx] = layer - register_layer_to_shared_weight_series( + register_layer_to_shard_weight_series( series_name="o_proj", - group=get_shared_weight_group(), + group=get_shard_weight_group(), layer=layer, prefetch_step=prefetch_step) @@ -74,7 +68,7 @@ def get_layer(self, layer_idx: int) -> Optional[Any]: """ return self._shard_layers.get(layer_idx) - def trigger_broadcast_for_layer(self, layer_prefix: str, vllm_config: Any): + def trigger_broadcast_for_layer(self, layer_prefix: str): """Triggers a broadcast for a specific layer during model computation. This method is intended to be called within a layer's forward pass. @@ -84,14 +78,13 @@ def trigger_broadcast_for_layer(self, layer_prefix: str, vllm_config: Any): Args: layer_prefix: The name prefix of the current layer being computed. - vllm_config: The vLLM model configuration object. """ layer_idx = extract_layer_index(layer_prefix) target_layer = self.get_layer(layer_idx) # Ensure the layer exists and meets the sharding criteria. - if target_layer and is_hidden_layer(vllm_config, target_layer): - reach_layer_for_shared_weight_series(target_layer) + if target_layer and is_hidden_layer(target_layer): + reach_layer_for_shard_weight_series(target_layer) def post_process_after_loading(self): """Performs post-processing on all registered layers after weight loading. 
@@ -101,7 +94,7 @@ def post_process_after_loading(self): if self._shard_layers: # Pick any layer (e.g., the first one) to trigger the shard post-processing any_layer = next(iter(self._shard_layers.values())) - post_process_after_loading_for_shared_weight_series(any_layer) + post_process_after_loading_for_shard_weight_series(any_layer) flashcomm2_oshard_manager = Flashcomm2OShardManager() diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 27235b867f6..0000d696498 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -1017,8 +1017,11 @@ def flashcomm2_enable() -> bool: return envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE > 0 -def flashcomm2_o_shared_enabled() -> bool: - return envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED +def o_shard_enable() -> bool: + layer_sharding = get_ascend_config().layer_sharding + if layer_sharding is None: + return False + return "o_proj" in layer_sharding def get_flashcomm2_config_and_validate(ascend_config, vllm_config): From 6c20df94c2d072f8a7fc66ec5b784308b67ee04e Mon Sep 17 00:00:00 2001 From: Levi-JQ Date: Fri, 9 Jan 2026 17:18:17 +0800 Subject: [PATCH 9/9] fix comment Signed-off-by: Levi-JQ --- .../test_offline_inference_distributed.py | 19 ++++++++++--------- vllm_ascend/attention/attention_v1.py | 2 +- ...{utils.py => flashcomm2_oshard_manager.py} | 0 vllm_ascend/ops/linear_op.py | 3 ++- 4 files changed, 13 insertions(+), 11 deletions(-) rename vllm_ascend/ops/{utils.py => flashcomm2_oshard_manager.py} (100%) diff --git a/tests/e2e/multicard/2-cards/test_offline_inference_distributed.py b/tests/e2e/multicard/2-cards/test_offline_inference_distributed.py index 890211d13fd..d06dece3446 100644 --- a/tests/e2e/multicard/2-cards/test_offline_inference_distributed.py +++ b/tests/e2e/multicard/2-cards/test_offline_inference_distributed.py @@ -159,7 +159,7 @@ def test_qwen3_moe_fc2_tp2() -> None: @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"}) @patch.dict(os.environ, {"VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE": "1"}) -def test_fc2_oshard_for_qwen3_moe() -> None: +def test_qwen3_moe_fc2_oshard_tp2() -> None: example_prompts = [ "Hello, my name is", ] @@ -168,14 +168,15 @@ def test_fc2_oshard_for_qwen3_moe() -> None: top_k=50, top_p=0.9) - with VllmRunner(snapshot_download("Qwen/Qwen3-30B-A3B"), - dtype="auto", - tensor_parallel_size=2, - distributed_executor_backend="mp", - enable_expert_parallel=True, - enforce_eager=True, - additional_config={"layer_sharding": - ["o_proj"]}) as vllm_model: + with VllmRunner( + snapshot_download("Qwen/Qwen3-30B-A3B"), + dtype="auto", + tensor_parallel_size=2, + distributed_executor_backend="mp", + enable_expert_parallel=True, + enforce_eager= + True, # TODO(Levi-JQ): support graph mode for fc2 in Qwen + additional_config={"layer_sharding": ["o_proj"]}) as vllm_model: vllm_model.generate(example_prompts, sampling_params) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index fd4797ca4e4..80ebf17dfb4 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -43,7 +43,7 @@ from vllm_ascend.compilation.acl_graph import ( get_draft_graph_params, get_graph_params, update_draft_graph_params_workspaces, update_graph_params_workspaces) -from vllm_ascend.ops.utils import flashcomm2_oshard_manager +from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type, weak_ref_tensors) diff --git a/vllm_ascend/ops/utils.py 
b/vllm_ascend/ops/flashcomm2_oshard_manager.py similarity index 100% rename from vllm_ascend/ops/utils.py rename to vllm_ascend/ops/flashcomm2_oshard_manager.py diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py index c57a264e94f..4a90ea90498 100644 --- a/vllm_ascend/ops/linear_op.py +++ b/vllm_ascend/ops/linear_op.py @@ -61,6 +61,7 @@ get_flashcomm2_otp_group, get_mlp_tp_group, get_otp_group) +from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager from vllm_ascend.utils import (enable_dsa_cp, enable_sp, flashcomm2_enable, get_flashcomm2_reorgnized_batch_ids, matmul_allreduce_enable, mlp_tp_enable, @@ -694,7 +695,7 @@ def apply_impl( def _get_column_parallel_op( prefix, layer ) -> Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp, - ShardedCPColumnParallelOp]]: + ShardedCPColumnParallelOp, Flashcomm2OshardQKVParallelOp]]: if enable_dsa_cp() and ("q_b_proj" in prefix or "kv_b_proj" in prefix): return ShardedCPColumnParallelOp(layer) if "gate_up_proj" in prefix and mlp_tp_enable(