diff --git a/tests/ut/distributed/test_parallel_state.py b/tests/ut/distributed/test_parallel_state.py
index 7fd22629943..1e7a9312a8f 100644
--- a/tests/ut/distributed/test_parallel_state.py
+++ b/tests/ut/distributed/test_parallel_state.py
@@ -41,6 +41,7 @@ def test_init_ascend_model_parallel(mock_distributed, parallel_config):
     mock_ascend_config.finegrained_tp_config.lmhead_tensor_parallel_size = 2
     mock_ascend_config.finegrained_tp_config.oproj_tensor_parallel_size = 2
     mock_ascend_config.finegrained_tp_config.embedding_tensor_parallel_size = 2
+    mock_ascend_config.finegrained_tp_config.mlp_tensor_parallel_size = 2
     mock_ascend_config.flashcomm2_oproj_tensor_parallel_size = 2
     mock_ascend_config.pd_tp_ratio = 2
     mock_ascend_config.num_head_replica = 0
diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py
index 1709d150bde..e886a31113e 100644
--- a/vllm_ascend/distributed/parallel_state.py
+++ b/vllm_ascend/distributed/parallel_state.py
@@ -96,25 +96,11 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
         backend,
         group_name="mc2")
 
-    # Initialize specialized tensor parallel (TP) process groups for fine-grained model parallelism
-    # on Ascend hardware. This enables independent TP configurations for three critical components:
-
-    # 1. ** LM Head **:
-    #    The final linear layer that maps hidden states to vocabulary logits.
-    #    Controlled by `lmhead_tensor_parallel_size`.
-
-    # 2. ** o_proj **:
-    #    The output projection in attention blocks (e.g., in Multi-Head Attention).
-    #    Controlled by `oproj_tensor_parallel_size`.
-
-    # 3. ** Embedding **:
-    #    The token embedding table at the input and/or output of the model.
-    #    Controlled by `embedding_tensor_parallel_size`.
-
-    # 4. ** MLP **:
-    #    The feed-forward network layers within transformer blocks.
-    #    Controlled by `mlp_tensor_parallel_size`.
-
+    # Initialize fine-grained TP process groups on Ascend for four components:
+    # 1. LM Head: output logits projection (`lmhead_tensor_parallel_size`)
+    # 2. O Proj: attention output projection (`oproj_tensor_parallel_size`)
+    # 3. Embedding: input token embedding table (`embedding_tensor_parallel_size`)
+    # 4. MLP: feed-forward network in transformer blocks (`mlp_tensor_parallel_size`)
     _group_cache = {}
 
     def _create_or_get_group(group_size: int,
@@ -149,9 +135,9 @@ def _create_or_get_group(group_size: int,
     embedding_tp_size = get_ascend_config(
     ).finegrained_tp_config.embedding_tensor_parallel_size
     mlp_tp_size = get_ascend_config(
-    ).finegrained_tp_config.embedding_tensor_parallel_size
+    ).finegrained_tp_config.mlp_tensor_parallel_size
 
-    global _OTP, _LMTP, _EMBED_TP
+    global _OTP, _LMTP, _EMBED_TP, _MLP_TP
 
     if otp_size > 0:
         _OTP = _create_or_get_group(otp_size, "otp")
diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py
index 65affe8bcef..27310ffd605 100644
--- a/vllm_ascend/ops/linear_op.py
+++ b/vllm_ascend/ops/linear_op.py
@@ -36,6 +36,8 @@
 Row parallel op follows a similar approach - inherit from RowColumnParallelOp
 and register the new class in get_row_parallel_op.
""" +import re +from functools import lru_cache from typing import Optional, Union import torch @@ -605,7 +607,8 @@ def update_attrs(self): def _get_column_parallel_op( prefix, layer ) -> Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp]]: - if mlp_tp_enable() and "gate_up_proj" in prefix: + if "gate_up_proj" in prefix and mlp_tp_enable( + ) and not is_moe_layer(prefix): return MLPColumnParallelOp(layer) if enable_sp(): if "shared_expert" in prefix: @@ -629,7 +632,7 @@ def _get_row_parallel_op( ) -> Optional[Union[MLPRowParallelOp, OProjRowParallelOp, Flashcomm2OProjRowParallelOp, MatmulAllreduceRowParallelOp, SequenceRowParallelOp]]: - if "down_proj" in prefix and mlp_tp_enable(): + if "down_proj" in prefix and mlp_tp_enable() and not is_moe_layer(prefix): return MLPRowParallelOp(layer) if "o_proj" in prefix and oproj_tp_enable(): return OProjRowParallelOp(layer) @@ -681,3 +684,27 @@ def get_replicated_op(disable_tp, prefix, return None return CustomReplicatedOp(layer) + + +def is_moe_layer(prefix: str) -> bool: + + @lru_cache(maxsize=1) + def get_moe_params(): + from vllm.config import get_current_vllm_config + vllm_config = get_current_vllm_config() + config = vllm_config.model_config.hf_config + n_routed_experts = getattr(config, 'n_routed_experts', 0) + first_k_dense_replace = getattr(config, 'first_k_dense_replace', + float('inf')) + moe_layer_freq = getattr(config, 'moe_layer_freq', 1) + return n_routed_experts, first_k_dense_replace, moe_layer_freq + + match = re.search(r'layers\.(\d+)\.', prefix) + if match is None: + return False + layer_idx = int(match.group(1)) + + n_routed_experts, first_k_dense_replace, moe_layer_freq = get_moe_params() + + return (n_routed_experts is not None and layer_idx >= first_k_dense_replace + and layer_idx % moe_layer_freq == 0)