1 change: 1 addition & 0 deletions tests/ut/distributed/test_parallel_state.py
@@ -41,6 +41,7 @@ def test_init_ascend_model_parallel(mock_distributed, parallel_config):
mock_ascend_config.finegrained_tp_config.lmhead_tensor_parallel_size = 2
mock_ascend_config.finegrained_tp_config.oproj_tensor_parallel_size = 2
mock_ascend_config.finegrained_tp_config.embedding_tensor_parallel_size = 2
mock_ascend_config.finegrained_tp_config.mlp_tensor_parallel_size = 2
mock_ascend_config.flashcomm2_oproj_tensor_parallel_size = 2
mock_ascend_config.pd_tp_ratio = 2
mock_ascend_config.num_head_replica = 0
28 changes: 7 additions & 21 deletions vllm_ascend/distributed/parallel_state.py
@@ -96,25 +96,11 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
backend,
group_name="mc2")

# Initialize specialized tensor parallel (TP) process groups for fine-grained model parallelism
# on Ascend hardware. This enables independent TP configurations for three critical components:

# 1. ** LM Head **:
# The final linear layer that maps hidden states to vocabulary logits.
# Controlled by `lmhead_tensor_parallel_size`.

# 2. ** o_proj **:
# The output projection in attention blocks (e.g., in Multi-Head Attention).
# Controlled by `oproj_tensor_parallel_size`.

# 3. ** Embedding **:
# The token embedding table at the input and/or output of the model.
# Controlled by `embedding_tensor_parallel_size`.

# 4. ** MLP **:
# The feed-forward network layers within transformer blocks.
# Controlled by `mlp_tensor_parallel_size`.

# Initialize fine-grained TP process groups on Ascend for four components:
# 1. LM Head: output logits projection (`lmhead_tensor_parallel_size`)
# 2. O Proj: attention output projection (`oproj_tensor_parallel_size`)
# 3. Embedding: The token embedding table at the input of the model (`embedding_tensor_parallel_size`)
# 4. MLP: feed-forward network in transformer blocks (`mlp_tensor_parallel_size`)
_group_cache = {}

def _create_or_get_group(group_size: int,
@@ -149,9 +135,9 @@ def _create_or_get_group(group_size: int,
embedding_tp_size = get_ascend_config(
).finegrained_tp_config.embedding_tensor_parallel_size
mlp_tp_size = get_ascend_config(
).finegrained_tp_config.embedding_tensor_parallel_size
).finegrained_tp_config.mlp_tensor_parallel_size

global _OTP, _LMTP, _EMBED_TP
global _OTP, _LMTP, _EMBED_TP, _MLP_TP

if otp_size > 0:
_OTP = _create_or_get_group(otp_size, "otp")
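The pattern the new comment describes: each component's TP size is read from `get_ascend_config().finegrained_tp_config`, and a dedicated process group is created (and cached) only when that size is positive, as in the `if otp_size > 0` branch above. A minimal, self-contained sketch of that selection logic; the config dataclass, the group names other than "otp", and the `create_group` callable are placeholders for the real `_create_or_get_group`:

# Sketch only (not the merged code): maps the four fine-grained TP sizes to
# process groups, mirroring the "size > 0 means enabled" checks above.
from dataclasses import dataclass

@dataclass
class FinegrainedTPConfig:  # placeholder for get_ascend_config().finegrained_tp_config
    oproj_tensor_parallel_size: int = 0
    lmhead_tensor_parallel_size: int = 0
    embedding_tensor_parallel_size: int = 0
    mlp_tensor_parallel_size: int = 0

def init_finegrained_tp_groups(cfg: FinegrainedTPConfig, create_group):
    groups = {}
    for name, size in (("otp", cfg.oproj_tensor_parallel_size),
                       ("lmheadtp", cfg.lmhead_tensor_parallel_size),
                       ("embedtp", cfg.embedding_tensor_parallel_size),
                       ("mlptp", cfg.mlp_tensor_parallel_size)):
        if size > 0:  # a size of 0 leaves that component on the default TP group
            groups[name] = create_group(size, name)
    return groups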
31 changes: 29 additions & 2 deletions vllm_ascend/ops/linear_op.py
@@ -36,6 +36,8 @@
Row parallel op follows a similar approach - inherit from RowColumnParallelOp and register the new class in get_row_parallel_op.
"""

import re
from functools import lru_cache
from typing import Optional, Union

import torch
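
As a rough, self-contained illustration of the extension recipe the docstring above describes (subclass the base op, then register it by prefix in the dispatcher), here is a toy version of that dispatch pattern; `ToyParallelOp`, `ToyDownProjOp`, and `custom_tp_enable` are made-up stand-ins, not the real `RowColumnParallelOp` interfaces in this file:

from typing import Optional

class ToyParallelOp:
    """Stand-in for RowColumnParallelOp: wraps the linear layer it customizes."""

    def __init__(self, layer):
        self.layer = layer

class ToyDownProjOp(ToyParallelOp):
    """Stand-in for a specialized op such as MLPRowParallelOp."""

def custom_tp_enable() -> bool:
    # Placeholder feature switch, analogous to mlp_tp_enable() / oproj_tp_enable().
    return True

def get_toy_row_parallel_op(prefix: str, layer) -> Optional[ToyParallelOp]:
    # Dispatch on a substring of the weight prefix, first match wins,
    # the same shape as _get_row_parallel_op below.
    if "down_proj" in prefix and custom_tp_enable():
        return ToyDownProjOp(layer)
    return None

# get_toy_row_parallel_op("model.layers.0.mlp.down_proj", layer=None) -> ToyDownProjOp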
@@ -605,7 +607,8 @@ def update_attrs(self):
def _get_column_parallel_op(
prefix, layer
) -> Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp]]:
if mlp_tp_enable() and "gate_up_proj" in prefix:
if "gate_up_proj" in prefix and mlp_tp_enable(
) and not is_moe_layer(prefix):
return MLPColumnParallelOp(layer)
if enable_sp():
if "shared_expert" in prefix:
@@ -629,7 +632,7 @@ def _get_row_parallel_op(
) -> Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
Flashcomm2OProjRowParallelOp, MatmulAllreduceRowParallelOp,
SequenceRowParallelOp]]:
if "down_proj" in prefix and mlp_tp_enable():
if "down_proj" in prefix and mlp_tp_enable() and not is_moe_layer(prefix):
return MLPRowParallelOp(layer)
if "o_proj" in prefix and oproj_tp_enable():
return OProjRowParallelOp(layer)
@@ -681,3 +684,27 @@ def get_replicated_op(disable_tp, prefix,
return None

return CustomReplicatedOp(layer)


def is_moe_layer(prefix: str) -> bool:

    @lru_cache(maxsize=1)
    def get_moe_params():
        from vllm.config import get_current_vllm_config
        vllm_config = get_current_vllm_config()
        config = vllm_config.model_config.hf_config
        n_routed_experts = getattr(config, 'n_routed_experts', 0)
        first_k_dense_replace = getattr(config, 'first_k_dense_replace',
                                        float('inf'))
        moe_layer_freq = getattr(config, 'moe_layer_freq', 1)
        return n_routed_experts, first_k_dense_replace, moe_layer_freq

    match = re.search(r'layers\.(\d+)\.', prefix)
    if match is None:
        return False
    layer_idx = int(match.group(1))

    n_routed_experts, first_k_dense_replace, moe_layer_freq = get_moe_params()

    return (n_routed_experts is not None and layer_idx >= first_k_dense_replace
            and layer_idx % moe_layer_freq == 0)
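
For a sense of how `is_moe_layer` feeds the new `not is_moe_layer(prefix)` guards in `_get_column_parallel_op` / `_get_row_parallel_op`, here is a self-contained re-run of the same rule with hypothetical config numbers (a DeepSeek-style model with 3 leading dense layers and experts in every layer after that); the prefixes are illustrative:

import re

# Hypothetical HF-config values: n_routed_experts, first_k_dense_replace, moe_layer_freq.
n_routed_experts, first_k_dense_replace, moe_layer_freq = 64, 3, 1

def classify(prefix: str) -> bool:
    # Same rule as is_moe_layer: pull the layer index out of the prefix and
    # check it against the MoE schedule.
    match = re.search(r'layers\.(\d+)\.', prefix)
    if match is None:
        return False
    layer_idx = int(match.group(1))
    return (n_routed_experts is not None and layer_idx >= first_k_dense_replace
            and layer_idx % moe_layer_freq == 0)

print(classify("model.layers.1.mlp.down_proj"))  # False: dense layer, the MLP TP op still applies
print(classify("model.layers.7.mlp.down_proj"))  # True: MoE layer, MLPRow/ColumnParallelOp is skipped
print(classify("lm_head"))                       # False: no "layers.<idx>." in the prefix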