From fbb52ffb1aca30d24e641ae91c825291e70d5614 Mon Sep 17 00:00:00 2001
From: zzhx1
Date: Sun, 30 Nov 2025 03:21:53 +0800
Subject: [PATCH 1/4] [Feat] Add custom embedding tensor parallel

Signed-off-by: zzhx1
Co-authored-by: chenxiao
---
 .../configuration/additional_config.md        |  18 +-
 tests/ut/distributed/test_parallel_state.py   |  15 +-
 tests/ut/ops/test_linear.py                   |  19 +-
 tests/ut/ops/test_vocab_parallel_embedding.py |  29 ++-
 vllm_ascend/ascend_config.py                  |  86 ++++--
 vllm_ascend/distributed/parallel_state.py     | 245 ++++++++++--------
 vllm_ascend/envs.py                           |   4 -
 vllm_ascend/ops/vocab_parallel_embedding.py   |  34 ++-
 vllm_ascend/utils.py                          |  11 +-
 9 files changed, 297 insertions(+), 164 deletions(-)

diff --git a/docs/source/user_guide/configuration/additional_config.md b/docs/source/user_guide/configuration/additional_config.md
index c975c6b1e48..7d17df8b432 100644
--- a/docs/source/user_guide/configuration/additional_config.md
+++ b/docs/source/user_guide/configuration/additional_config.md
@@ -27,14 +27,13 @@ The following table lists additional configuration options available in vLLM Asc
 | Name | Type | Default | Description |
 |-------------------------------------|------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------|
 | `xlite_graph_config` | dict | `{}` | Configuration options for xlite graph mode |
+| `module_tp_config` | dict | `{}` | Configuration options for module tensor parallelism |
 | `weight_prefetch_config` | dict | `{}` | Configuration options for weight prefetch |
 | `refresh` | bool | `false` | Whether to refresh global Ascend configuration content. This is usually used by rlhf or ut/e2e test case. |
 | `expert_map_path` | str | `None` | When using expert load balancing for an MoE model, an expert map path needs to be passed in. |
 | `kv_cache_dtype` | str | `None` | When using the KV cache quantization method, KV cache dtype needs to be set, currently only int8 is supported. |
 | `enable_shared_expert_dp` | bool | `False` | When the expert is shared in DP, it delivers better performance but consumes more memory. Currently only DeepSeek series models are supported. |
-| `lmhead_tensor_parallel_size` | int | `None` | The custom tensor parallel size of lmhead. |
-| `oproj_tensor_parallel_size` | int | `None` | The custom tensor parallel size of oproj. |
 | `multistream_overlap_shared_expert` | bool | `False` | Whether to enable multistream shared expert. This option only takes effect on MoE models with shared experts. |
 | `dynamic_eplb` | bool | `False` | Whether to enable dynamic EPLB. |
 | `num_iterations_eplb_update` | int | `400` | Forward iterations when EPLB begins. |
 | `gate_eplb` | bool | `False` | Whether to enable EPLB only once. |
@@ -58,6 +57,15 @@ The details of each configuration option are as follows:
 | `enabled` | bool | `False` | Whether to enable weight prefetch. |
 | `prefetch_ratio` | dict | `{"attn": {"qkv": 1.0, "o": 1.0}, "moe": {"gate_up": 0.8}}` | Prefetch ratio of each weight. |
 
+**module_tp_config**
+
+| Name | Type | Default | Description |
+| ---- | ---- | ------- | ----------- |
+| `lmhead_tensor_parallel_size` | int | `0` | The custom tensor parallel size of lmhead. |
+| `oproj_tensor_parallel_size` | int | `0` | The custom tensor parallel size of oproj. 
| +| `embedding_tensor_parallel_size` | int | `0` | The custom tensor parallel size of embedding. | +| `mlp_tensor_parallel_size` | int | `0` | The custom tensor parallel size of mlp. | + ### Example An example of additional configuration is as follows: @@ -76,6 +84,12 @@ An example of additional configuration is as follows: } }, }, + "module_tp_config": { + "lmhead_tensor_parallel_size": 8, + "oproj_tensor_parallel_size": 8, + "embedding_tensor_parallel_size": 8, + "mlp_tensor_parallel_size": 8, + }, "multistream_overlap_shared_expert": True, "refresh": False, } diff --git a/tests/ut/distributed/test_parallel_state.py b/tests/ut/distributed/test_parallel_state.py index c69f44490b7..4c0300a8550 100644 --- a/tests/ut/distributed/test_parallel_state.py +++ b/tests/ut/distributed/test_parallel_state.py @@ -12,15 +12,17 @@ @pytest.fixture def parallel_config(): - return ParallelConfig(data_parallel_size=2, - tensor_parallel_size=2, - pipeline_parallel_size=2) + return ParallelConfig( + data_parallel_size=2, + tensor_parallel_size=4, + pipeline_parallel_size=2, + ) @pytest.fixture def mock_distributed(): with patch('torch.distributed.is_initialized', return_value=True), \ - patch('torch.distributed.get_world_size', return_value=8), \ + patch('torch.distributed.get_world_size', return_value=16), \ patch('torch.distributed.get_backend', return_value='nccl'), \ patch('vllm_ascend.distributed.parallel_state.get_world_group') as mock_group, \ patch('vllm_ascend.distributed.parallel_state.get_tp_group') as mock_tp_group, \ @@ -36,8 +38,9 @@ def mock_distributed(): def test_init_ascend_model_parallel(mock_distributed, parallel_config): mock_ascend_config = MagicMock() - mock_ascend_config.lmhead_tensor_parallel_size = 2 - mock_ascend_config.oproj_tensor_parallel_size = 2 + mock_ascend_config.module_tp_config.lmhead_tensor_parallel_size = 2 + mock_ascend_config.module_tp_config.oproj_tensor_parallel_size = 2 + mock_ascend_config.module_tp_config.embedding_tensor_parallel_size = 2 mock_ascend_config.flashcomm2_oproj_tensor_parallel_size = 2 mock_ascend_config.pd_tp_ratio = 2 mock_ascend_config.num_head_replica = 0 diff --git a/tests/ut/ops/test_linear.py b/tests/ut/ops/test_linear.py index c31033e67dd..8ac2e4719ad 100644 --- a/tests/ut/ops/test_linear.py +++ b/tests/ut/ops/test_linear.py @@ -1,4 +1,3 @@ -import os import unittest from unittest import mock from unittest.mock import MagicMock, patch @@ -26,7 +25,8 @@ def setUp(self): parallel_state._OTP = self.mock_group self.mock_ascend_config = MagicMock() - self.mock_ascend_config.oproj_tensor_parallel_size = 2 + self.mock_ascend_config.module_tp_config.oproj_tensor_parallel_size = 2 + self.mock_ascend_config.module_tp_config.mlp_tensor_parallel_size = 2 self.patches = [ patch("vllm_ascend.ascend_config.get_ascend_config", @@ -81,7 +81,11 @@ def test_process_weights_after_loading_disable_nz(self, mock_format_cast, class TestAscendRowParallelLinear(BaseLinearTest): def test_mlp_optimize(self): - os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1" + + ascend_config._ASCEND_CONFIG = MagicMock() + ascend_config._ASCEND_CONFIG.recompute_scheduler_enable = False + ascend_config._ASCEND_CONFIG.module_tp_config.mlp_tensor_parallel_size = 2 + ascend_config._ASCEND_CONFIG.ascend_scheduler_config.enabled = False linear = AscendRowParallelLinear( input_size=16, @@ -98,8 +102,9 @@ def test_oproj_tp(self): config._current_vllm_config = MagicMock() ascend_config._ASCEND_CONFIG = MagicMock() - ascend_config._ASCEND_CONFIG.oproj_tensor_parallel_size = 2 
ascend_config._ASCEND_CONFIG.recompute_scheduler_enable = False + ascend_config._ASCEND_CONFIG.module_tp_config.oproj_tensor_parallel_size = 2 + ascend_config._ASCEND_CONFIG.ascend_scheduler_config.enabled = False linear = AscendRowParallelLinear( input_size=16, @@ -115,7 +120,11 @@ def test_oproj_tp(self): class TestAscendMergedColumnParallelLinear(BaseLinearTest): def test_merged_mlp_tp_init(self): - os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1" + + ascend_config._ASCEND_CONFIG = MagicMock() + ascend_config._ASCEND_CONFIG.recompute_scheduler_enable = False + ascend_config._ASCEND_CONFIG.module_tp_config.mlp_tensor_parallel_size = 2 + ascend_config._ASCEND_CONFIG.ascend_scheduler_config.enabled = False linear = AscendMergedColumnParallelLinear( input_size=16, diff --git a/tests/ut/ops/test_vocab_parallel_embedding.py b/tests/ut/ops/test_vocab_parallel_embedding.py index 531df28140f..15963ba391f 100644 --- a/tests/ut/ops/test_vocab_parallel_embedding.py +++ b/tests/ut/ops/test_vocab_parallel_embedding.py @@ -14,11 +14,12 @@ # Adapted from vllm/tests/lora/test_layers.py import unittest +from unittest import mock from unittest.mock import MagicMock, patch import torch -from vllm_ascend.ascend_config import init_ascend_config +from vllm_ascend.distributed import parallel_state from vllm_ascend.ops.vocab_parallel_embedding import ( AscendLogitsProcessor, AscendParallelLMHead, AscendVocabParallelEmbedding) @@ -32,9 +33,33 @@ def setUp(self): self.embedding_dim = 10 self.org_num_embeddings = 40 self.padding_size = 8 + + self.mock_group = mock.MagicMock() + self.mock_group.world_size = 2 + self.mock_group.rank_in_group = 0 + + parallel_state._MLP_TP = self.mock_group + parallel_state._OTP = self.mock_group + mock_vllm_config = MagicMock() mock_vllm_config.additional_config = {} - init_ascend_config(mock_vllm_config) + self.mock_ascend_config = MagicMock() + self.mock_ascend_config.module_tp_config.lmhead_tensor_parallel_size = 2 + self.mock_ascend_config.module_tp_config.embedding_tensor_parallel_size = 2 + + self.patches = [ + patch("vllm_ascend.utils.get_ascend_config", + return_value=self.mock_ascend_config), + patch("vllm_ascend.distributed.parallel_state.get_lmhead_tp_group", + return_value=self.mock_group), + patch( + "vllm.distributed.parallel_state.get_tp_group", + return_value=self.mock_group, + ), + ] + + for p in self.patches: + p.start() def _create_layer(self): # Patch methods and dependencies for VocabParallelEmbedding diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index e1eaad1eda4..4436e13681a 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -67,6 +67,9 @@ def __init__(self, vllm_config): self.ascend_compilation_config = AscendCompilationConfig( **ascend_compilation_config) + module_tp_config = additional_config.get("module_tp_config", {}) + self.module_tp_config = ModuleTPConfig(module_tp_config, vllm_config) + # Dump / PrecisionDebugger configuration dump_config_path = additional_config.get("dump_config", None) self.dump_config = DumpConfig(dump_config_path) @@ -103,34 +106,6 @@ def __init__(self, vllm_config): "multistream_overlap_shared_expert", False) self.recompute_scheduler_enable = additional_config.get( "recompute_scheduler_enable", False) - self.lmhead_tensor_parallel_size = additional_config.get( - "lmhead_tensor_parallel_size", None) - if self.lmhead_tensor_parallel_size is not None: - logger.info( - f"Enable lmhead_tensor_parallel_size={self.lmhead_tensor_parallel_size} in pure DP scenario" - ) - if 
vllm_config.parallel_config.tensor_parallel_size != 1:
-                raise AssertionError(
-                    "lmhead_tensor_parallel_size is only supported in the pure DP scenario"
-                )
-        self.oproj_tensor_parallel_size = additional_config.get(
-            "oproj_tensor_parallel_size", None)
-        if self.oproj_tensor_parallel_size is not None:
-            logger.info(
-                f"Enable oproj_tensor_parallel_size={self.oproj_tensor_parallel_size} in pure DP scenario"
-            )
-            if vllm_config.parallel_config.tensor_parallel_size != 1:
-                raise AssertionError(
-                    "oproj_tensor_parallel_size is only supported in the pure DP scenario"
-                )
-            if vllm_config.model_config.enforce_eager is True:
-                raise AssertionError(
-                    "oproj_tensor_parallel_size is only supported in graph mode"
-                )
-            if vllm_config.kv_transfer_config is None or not vllm_config.kv_transfer_config.is_kv_consumer:
-                raise AssertionError(
-                    "oproj_tensor_parallel_size is only supported in pd scenario and can only be used in D node."
-                )
         self.enable_cpu_binding = additional_config.get(
             "enable_cpu_binding", False)
@@ -181,6 +156,61 @@ def __init__(self, vllm_config):
             kv_cfg._engine_id_patched = True
 
 
+class ModuleTPConfig:
+    """
+    Configuration Object for module_tp_config from additional_config
+    """
+
+    def __init__(self, module_tp_config: dict, vllm_config):
+        self.oproj_tensor_parallel_size = module_tp_config.get(
+            "oproj_tensor_parallel_size", 0)
+        self.lmhead_tensor_parallel_size = module_tp_config.get(
+            "lmhead_tensor_parallel_size", 0)
+        self.embedding_tensor_parallel_size = module_tp_config.get(
+            "embedding_tensor_parallel_size", 0)
+        self.mlp_tensor_parallel_size = module_tp_config.get(
+            "mlp_tensor_parallel_size", 0)
+
+        enabled_configs = []
+        if self.oproj_tensor_parallel_size > 0:
+            enabled_configs.append(
+                f"oproj_tensor_parallel_size={self.oproj_tensor_parallel_size}"
+            )
+            # dummy_run does not run the entire attention module in eager mode, so the o_proj tp split can only be used in graph mode.
+            if vllm_config.model_config.enforce_eager is True:
+                raise AssertionError(
+                    "oproj_tensor_parallel_size is only supported in graph mode"
+                )
+            if vllm_config.kv_transfer_config is None or not vllm_config.kv_transfer_config.is_kv_consumer:
+                raise AssertionError(
+                    "oproj_tensor_parallel_size is only supported in pd scenario and can only be used in D node."
+                )
+        if self.lmhead_tensor_parallel_size > 0:
+            enabled_configs.append(
+                f"lmhead_tensor_parallel_size={self.lmhead_tensor_parallel_size}"
+            )
+        if self.embedding_tensor_parallel_size > 0:
+            enabled_configs.append(
+                f"embedding_tensor_parallel_size={self.embedding_tensor_parallel_size}"
+            )
+        if self.mlp_tensor_parallel_size > 0:
+            enabled_configs.append(
+                f"mlp_tensor_parallel_size={self.mlp_tensor_parallel_size}")
+        module_tp_sizes = [
+            self.oproj_tensor_parallel_size,
+            self.lmhead_tensor_parallel_size,
+            self.embedding_tensor_parallel_size,
+            self.mlp_tensor_parallel_size,
+        ]
+        for module_tp_size in module_tp_sizes:
+            if module_tp_size > 0 and vllm_config.parallel_config.data_parallel_size % module_tp_size != 0:
+                raise AssertionError(
+                    "module tp sizes must divide data_parallel_size")
+        if any(size > 0 for size in module_tp_sizes) and enabled_configs:
+            logger.info(
+                f"module_tp_config enabled: {', '.join(enabled_configs)}")
+
+
 class AscendCompilationConfig:
     """
     Configuration for controlling the behavior of Ascend graph optimization. 
diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py index c9bff649156..3eff12b25a1 100644 --- a/vllm_ascend/distributed/parallel_state.py +++ b/vllm_ascend/distributed/parallel_state.py @@ -7,69 +7,27 @@ get_world_group, init_model_parallel_group) -import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.utils import (enable_sp, flashcomm2_enable, flashcomm2_o_shared_enabled) # Currently, mc2 op need their own group coordinator. _MC2: Optional[GroupCoordinator] = None + +# Module specific tensor parallel groups _MLP_TP: Optional[GroupCoordinator] = None _OTP: Optional[GroupCoordinator] = None _LMTP: Optional[GroupCoordinator] = None -_P_TP: Optional[GroupCoordinator] = None +_EMBED_TP: Optional[GroupCoordinator] = None + +# flashcomm2 specific groups _FLASHCOMM2_OTP: Optional[GroupCoordinator] = None _FLASHCOMM2_ODP: Optional[GroupCoordinator] = None -_SHARED_WEIGHT: Optional[GroupCoordinator] = None - - -def get_mc2_group() -> GroupCoordinator: - assert _MC2 is not None, ("mc2 group is not initialized") - return _MC2 - - -def get_otp_group() -> GroupCoordinator: - assert _OTP is not None, ( - "output tensor parallel group is not initialized") - return _OTP - - -def get_lmhead_tp_group() -> GroupCoordinator: - assert _LMTP is not None, ( - "lm head tensor parallel group is not initialized") - return _LMTP - - -def get_flashcomm2_otp_group() -> GroupCoordinator: - return _FLASHCOMM2_OTP - - -def get_flashcomm2_odp_group() -> GroupCoordinator: - assert _FLASHCOMM2_ODP is not None, ( - "output data parallel group for flashcomm2 is not initialized") - return _FLASHCOMM2_ODP - - -def get_shared_weight_group() -> GroupCoordinator: - assert _SHARED_WEIGHT is not None, ( - "output shared weight parallel group for flashcomm2 is not initialized" - ) - return _SHARED_WEIGHT - - -def get_mlp_tp_group() -> GroupCoordinator: - assert _MLP_TP is not None, ("mlp group is not initialized") - return _MLP_TP - - -def get_p_tp_group() -> GroupCoordinator: - assert _P_TP is not None, ( - "distributed prefill tensor parallel group is not initialized") - return _P_TP +# shared_weight across rank groups +_SHARED_WEIGHT: Optional[GroupCoordinator] = None -def model_parallel_initialized(): - return (_MC2 is not None) +_P_TP: Optional[GroupCoordinator] = None def init_ascend_model_parallel(parallel_config: ParallelConfig, ): @@ -79,14 +37,16 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ): world_size = torch.distributed.get_world_size() backend = torch.distributed.get_backend(get_world_group().device_group) vllm_config = get_current_vllm_config() + global_tp_size = parallel_config.tensor_parallel_size + global_dp_size = parallel_config.data_parallel_size + global_pp_size = parallel_config.pipeline_parallel_size # The layout of all ranks: ExternalDP * EP # ExternalDP is the data parallel group that is not part of the model, # every dp rank can generate independently (in verl integration). 
all_ranks = torch.arange(world_size).reshape( - -1, parallel_config.data_parallel_size * - parallel_config.prefill_context_parallel_size * - parallel_config.tensor_parallel_size) + -1, global_dp_size * parallel_config.prefill_context_parallel_size * + global_tp_size) pd_tp_ratio = get_ascend_config().pd_tp_ratio pd_head_ratio = get_ascend_config().pd_head_ratio @@ -98,13 +58,13 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ): if pd_head_ratio > 1 and get_current_vllm_config( ).kv_transfer_config.is_kv_producer: num_head_replica = get_ascend_config().num_head_replica - remote_tp_size = parallel_config.tensor_parallel_size // pd_tp_ratio + remote_tp_size = global_tp_size // pd_tp_ratio if num_head_replica <= 1: group_ranks = all_ranks.view( -1, prefill_tensor_model_parallel_size).unbind(0) else: group_ranks = all_ranks.clone().view( - parallel_config.data_parallel_size, -1, + global_dp_size, -1, num_head_replica) # [DP_size, num_head, num_head_replica] group_ranks = group_ranks.permute(0, 2, 1) group_ranks = group_ranks.reshape( @@ -112,8 +72,7 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ): group_ranks.size(-1)) # [DP_size * num_head_replica, num_head] alltoall_group_size = group_ranks.size(-1) // remote_tp_size group_ranks = group_ranks.unsqueeze(-1).view( - parallel_config.data_parallel_size, num_head_replica, -1, - alltoall_group_size + global_dp_size, num_head_replica, -1, alltoall_group_size ) # [DP_size, num_head_replica, num_alltoall_group, alltoall_group_size] group_ranks = group_ranks.reshape(-1, alltoall_group_size).unbind(0) @@ -135,54 +94,71 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ): get_world_group().local_rank, backend, group_name="mc2") - if envs_ascend.VLLM_ASCEND_ENABLE_MLP_OPTIMIZE: - global _MLP_TP - assert _MLP_TP is None, ( - "mlp tensor model parallel group is already initialized") - - mlp_tp = parallel_config.data_parallel_size - - all_ranks_mlp_head = torch.arange(world_size).reshape( - -1, mlp_tp, parallel_config.pipeline_parallel_size, 1) # noqa - group_ranks = all_ranks_mlp_head.view(-1, mlp_tp).unbind(0) - group_ranks = [x.tolist() for x in group_ranks] - # message queue broadcaster is only used in tensor model parallel group - _MLP_TP = init_model_parallel_group(group_ranks, - get_world_group().local_rank, - backend, - group_name="mlp_tp") - - # If oproj tensor parallel size is set, we will create a group for it. 
-    otp_size = get_ascend_config().oproj_tensor_parallel_size
-    if otp_size is not None:
-        group_ranks = []
-        global _OTP
-        num_oproj_tensor_parallel_groups: int = (world_size // otp_size)
-        for i in range(num_oproj_tensor_parallel_groups):
-            ranks = list(range(i * otp_size, (i + 1) * otp_size))
-            group_ranks.append(ranks)
-        _OTP = init_model_parallel_group(group_ranks,
-                                         get_world_group().local_rank,
-                                         backend,
-                                         group_name="otp")
-
-    lmhead_tensor_parallel_size = get_ascend_config(
-    ).lmhead_tensor_parallel_size
-    if lmhead_tensor_parallel_size is not None:
-        group_ranks = []
-        global _LMTP
-        num_lmhead_tensor_parallel_groups: int = (world_size //
-                                                  lmhead_tensor_parallel_size)
-        for i in range(num_lmhead_tensor_parallel_groups):
-            ranks = list(
-                range(i * lmhead_tensor_parallel_size,
-                      (i + 1) * lmhead_tensor_parallel_size))
-            group_ranks.append(ranks)
-        _LMTP = init_model_parallel_group(group_ranks,
-                                          get_world_group().local_rank,
-                                          backend,
-                                          group_name="lmheadtp")
+    # Initialize specialized tensor parallel (TP) process groups for fine-grained model parallelism
+    # on Ascend hardware. This enables independent TP configurations for four critical components:
+
+    # 1. ** LM Head **:
+    #    The final linear layer that maps hidden states to vocabulary logits.
+    #    Controlled by `lmhead_tensor_parallel_size`.
+
+    # 2. ** o_proj **:
+    #    The output projection in attention blocks (e.g., in Multi-Head Attention).
+    #    Controlled by `oproj_tensor_parallel_size`.
+
+    # 3. ** Embedding **:
+    #    The token embedding table at the input and/or output of the model.
+    #    Controlled by `embedding_tensor_parallel_size`.
+
+    # 4. ** MLP **:
+    #    The feed-forward network layers within transformer blocks.
+    #    Controlled by `mlp_tensor_parallel_size`.
+
+    _group_cache = {}
+
+    def _create_or_get_group(group_size: int,
+                             group_name: str) -> Optional[GroupCoordinator]:
+        if group_size <= 0:
+            return None
+        if group_size not in _group_cache:
+
+            rank_grid = torch.arange(world_size).reshape(
+                global_pp_size, global_dp_size, global_tp_size)
+            num_chunks = global_dp_size // group_size
+            group_ranks = []
+            for pp_idx in range(global_pp_size):
+                stage_ranks = rank_grid[pp_idx]  # (dp, tp)
+                for chunk in range(num_chunks):
+                    for tp_idx in range(global_tp_size):
+                        group = stage_ranks[chunk * group_size:(chunk + 1) *
+                                            group_size, tp_idx].tolist()
+                        group_ranks.append(group)
+            pg = init_model_parallel_group(group_ranks,
+                                           get_world_group().local_rank,
+                                           backend,
+                                           group_name=group_name)
+            _group_cache[group_size] = pg
+
+        return _group_cache[group_size]
+
+    otp_size = get_ascend_config().module_tp_config.oproj_tensor_parallel_size
+    lmhead_tp_size = get_ascend_config(
+    ).module_tp_config.lmhead_tensor_parallel_size
+    embedding_tp_size = get_ascend_config(
+    ).module_tp_config.embedding_tensor_parallel_size
+    mlp_tp_size = get_ascend_config(
+    ).module_tp_config.embedding_tensor_parallel_size
+
+    global _OTP, _LMTP, _EMBED_TP
+
+    if otp_size > 0:
+        _OTP = _create_or_get_group(otp_size, "otp")
+    if lmhead_tp_size > 0:
+        _LMTP = _create_or_get_group(lmhead_tp_size, "lmheadtp")
+    if embedding_tp_size > 0:
+        _EMBED_TP = _create_or_get_group(embedding_tp_size, "emtp")
+    if mlp_tp_size > 0:
+        _MLP_TP = _create_or_get_group(mlp_tp_size, "mlptp")
 
     def _create_shared_weight_group(group_name: str) -> GroupCoordinator:
         #This communication domain is used for asynchronous broadcasting, so we will create a new communication group to avoid interference

@@ -265,14 +241,58 @@ 
_SHARED_WEIGHT = _create_shared_weight_group("flashcomm2_o_shared") -def get_mlp_tensor_model_parallel_world_size(): - """Return world size for the tensor model parallel group.""" - return get_mlp_tp_group().world_size +def model_parallel_initialized(): + return (_MC2 is not None) + + +def get_mc2_group() -> GroupCoordinator: + assert _MC2 is not None, ("mc2 group is not initialized") + return _MC2 + + +def get_mlp_tp_group() -> GroupCoordinator: + assert _MLP_TP is not None, ("mlp group is not initialized") + return _MLP_TP + + +def get_otp_group() -> GroupCoordinator: + assert _OTP is not None, ( + "output tensor parallel group is not initialized") + return _OTP + + +def get_lmhead_tp_group() -> GroupCoordinator: + assert _LMTP is not None, ( + "lm head tensor parallel group is not initialized") + return _LMTP + + +def get_embed_tp_group() -> GroupCoordinator: + assert _EMBED_TP is not None, ("emtp group is not initialized") + return _EMBED_TP + + +def get_flashcomm2_otp_group() -> GroupCoordinator: + return _FLASHCOMM2_OTP + + +def get_flashcomm2_odp_group() -> GroupCoordinator: + assert _FLASHCOMM2_ODP is not None, ( + "output data parallel group for flashcomm2 is not initialized") + return _FLASHCOMM2_ODP + + +def get_shared_weight_group() -> GroupCoordinator: + assert _SHARED_WEIGHT is not None, ( + "output shared weight parallel group for flashcomm2 is not initialized" + ) + return _SHARED_WEIGHT -def get_mlp_tensor_model_parallel_rank(): - """Return world size for the tensor model parallel group.""" - return get_mlp_tp_group().rank_in_group +def get_p_tp_group() -> GroupCoordinator: + assert _P_TP is not None, ( + "distributed prefill tensor parallel group is not initialized") + return _P_TP def destroy_ascend_model_parallel(): @@ -291,6 +311,11 @@ def destroy_ascend_model_parallel(): _LMTP.destroy() _LMTP = None + global _EMBED_TP + if _EMBED_TP: + _EMBED_TP.destroy() + _EMBED_TP = None + global _OTP if _OTP: _OTP.destroy() diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index 7e0480c9b75..5e926b11c7f 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -118,10 +118,6 @@ # However, there might be hidden issues, and it is currently recommended to prioritize its use with dense models. "VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE", '0'))), - # Whether to enable mlp optimize when tensor parallel is enabled. - # this feature in eager mode will get better performance. - "VLLM_ASCEND_ENABLE_MLP_OPTIMIZE": - lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MLP_OPTIMIZE", '0'))), # Whether to enable msMonitor tool to monitor the performance of vllm-ascend. 
"MSMONITOR_USE_DAEMON": lambda: bool(int(os.getenv("MSMONITOR_USE_DAEMON", '0'))), diff --git a/vllm_ascend/ops/vocab_parallel_embedding.py b/vllm_ascend/ops/vocab_parallel_embedding.py index a89c228dd51..8fb117249b2 100644 --- a/vllm_ascend/ops/vocab_parallel_embedding.py +++ b/vllm_ascend/ops/vocab_parallel_embedding.py @@ -30,8 +30,9 @@ VocabParallelEmbedding, pad_vocab_size) from vllm.model_executor.utils import set_weight_attrs -from vllm_ascend.distributed.parallel_state import get_lmhead_tp_group -from vllm_ascend.utils import lmhead_tp_enable +from vllm_ascend.distributed.parallel_state import (get_embed_tp_group, + get_lmhead_tp_group) +from vllm_ascend.utils import embedding_tp_enable, lmhead_tp_enable class AscendVocabParallelEmbedding(VocabParallelEmbedding): @@ -50,9 +51,12 @@ def __init__(self, quant_config: Optional[QuantizationConfig] = None, prefix: str = ""): nn.Module.__init__(self) - - if lmhead_tp_enable() and prefix.find("head") != -1: + self.forward_type = None + if lmhead_tp_enable() and "head" in prefix: self.comm_group = get_lmhead_tp_group() + elif embedding_tp_enable() and "embed_tokens" in prefix: + self.comm_group = get_embed_tp_group() + self.forward_type = "embed_tp" else: self.comm_group = get_tp_group() @@ -146,6 +150,28 @@ def _get_masked_input_and_mask( return input_, ~vocab_mask def forward(self, input_): + if self.forward_type == "embed_tp": + return self._forward_embed_tp(input_) + else: + return self._forward_origin(input_) + + def _forward_embed_tp(self, input_): + complete_input = self.comm_group.all_gather(input_, dim=0) + masked_input, input_mask = self._get_masked_input_and_mask( + complete_input, self.shard_indices.org_vocab_start_index, + self.shard_indices.org_vocab_end_index, + self.shard_indices.num_org_vocab_padding, + self.shard_indices.added_vocab_start_index, + self.shard_indices.added_vocab_end_index) + # Get the embeddings. + output_parallel = self.quant_method.embedding(self, + masked_input.long()) + output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) + output = self.comm_group.reduce_scatter(output_parallel, dim=0) + output = output.view(input_.shape[0], -1) + return output + + def _forward_origin(self, input_): if self.tp_size > 1: # Build the mask. 
masked_input, input_mask = self._get_masked_input_and_mask( diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 405fd68567a..2435d87eb0c 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -718,15 +718,20 @@ def get_ascend_device_type(): def lmhead_tp_enable() -> bool: - return get_ascend_config().lmhead_tensor_parallel_size is not None + return get_ascend_config().module_tp_config.lmhead_tensor_parallel_size > 0 + + +def embedding_tp_enable() -> bool: + return get_ascend_config( + ).module_tp_config.embedding_tensor_parallel_size > 0 def oproj_tp_enable() -> bool: - return get_ascend_config().oproj_tensor_parallel_size is not None + return get_ascend_config().module_tp_config.oproj_tensor_parallel_size > 0 def mlp_tp_enable() -> bool: - return envs_ascend.VLLM_ASCEND_ENABLE_MLP_OPTIMIZE + return get_ascend_config().module_tp_config.mlp_tensor_parallel_size > 0 def matmul_allreduce_enable() -> bool: From c8968019cb3bbc3237247d74eddb71d6fae29275 Mon Sep 17 00:00:00 2001 From: zzhxx Date: Thu, 11 Dec 2025 14:37:25 +0800 Subject: [PATCH 2/4] fix bug Signed-off-by: zzhxx --- vllm_ascend/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 2435d87eb0c..ec74b909a92 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -979,7 +979,7 @@ def get_flashcomm2_config_and_validate(ascend_config, vllm_config): logger.warning_once( "It is recommended to enable FLASHCOMM1 simultaneously when starting FLASHCOMM2 for optimal performance." ) - if ascend_config.oproj_tensor_parallel_size is not None: + if ascend_config.module_tp_config.oproj_tensor_parallel_size is not None: raise AssertionError( "flashcomm2_oproj_tensor_parallel_size cannot be enabled simultaneously with oproj_tensor_parallel_size" ) From 88c084d439055e5195029d0fd3e3f16d38bf19d6 Mon Sep 17 00:00:00 2001 From: zzhx1 Date: Thu, 11 Dec 2025 17:55:21 +0800 Subject: [PATCH 3/4] fix bug Signed-off-by: zzhx1 --- vllm_ascend/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index ec74b909a92..f49dc191a13 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -979,7 +979,7 @@ def get_flashcomm2_config_and_validate(ascend_config, vllm_config): logger.warning_once( "It is recommended to enable FLASHCOMM1 simultaneously when starting FLASHCOMM2 for optimal performance." 
)
-    if ascend_config.module_tp_config.oproj_tensor_parallel_size is not None:
+    if ascend_config.module_tp_config.oproj_tensor_parallel_size > 0:
         raise AssertionError(
             "flashcomm2_oproj_tensor_parallel_size cannot be enabled simultaneously with oproj_tensor_parallel_size"
         )

From a33092778d2bc9fe8070957270ab639e89a7ee77 Mon Sep 17 00:00:00 2001
From: zzhx1
Date: Thu, 11 Dec 2025 20:23:11 +0800
Subject: [PATCH 4/4] Rename ModuleTPConfig to FinegrainedTPConfig

Also pick the mlp group size from mlp_tensor_parallel_size (it previously
read embedding_tensor_parallel_size by mistake) and declare _MLP_TP as
global in init_ascend_model_parallel.

Signed-off-by: zzhx1
---
 .../configuration/additional_config.md        |  6 ++---
 tests/ut/distributed/test_parallel_state.py   |  6 ++---
 tests/ut/ops/test_linear.py                   | 10 ++++-----
 tests/ut/ops/test_vocab_parallel_embedding.py |  4 ++--
 vllm_ascend/ascend_config.py                  | 22 ++++++++++---------
 vllm_ascend/distributed/parallel_state.py     | 11 ++++++-----
 vllm_ascend/utils.py                          | 13 ++++++-----
 7 files changed, 39 insertions(+), 33 deletions(-)

diff --git a/docs/source/user_guide/configuration/additional_config.md b/docs/source/user_guide/configuration/additional_config.md
index 7d17df8b432..5163c102bec 100644
--- a/docs/source/user_guide/configuration/additional_config.md
+++ b/docs/source/user_guide/configuration/additional_config.md
@@ -27,7 +27,7 @@ The following table lists additional configuration options available in vLLM Asc
 | Name | Type | Default | Description |
 |-------------------------------------|------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------|
 | `xlite_graph_config` | dict | `{}` | Configuration options for xlite graph mode |
-| `module_tp_config` | dict | `{}` | Configuration options for module tensor parallelism |
+| `finegrained_tp_config` | dict | `{}` | Configuration options for fine-grained, per-module tensor parallelism |
 | `weight_prefetch_config` | dict | `{}` | Configuration options for weight prefetch |
 | `refresh` | bool | `false` | Whether to refresh global Ascend configuration content. This is usually used by rlhf or ut/e2e test case. |
 | `expert_map_path` | str | `None` | When using expert load balancing for an MoE model, an expert map path needs to be passed in. |
@@ -57,7 +57,7 @@ The details of each configuration option are as follows:
 | `enabled` | bool | `False` | Whether to enable weight prefetch. |
 | `prefetch_ratio` | dict | `{"attn": {"qkv": 1.0, "o": 1.0}, "moe": {"gate_up": 0.8}}` | Prefetch ratio of each weight. 
| -**module_tp_config** +**finegrained_tp_config** | Name | Type | Default | Description | | ---- | ---- | ------- | ----------- | @@ -84,7 +84,7 @@ An example of additional configuration is as follows: } }, }, - "module_tp_config": { + "finegrained_tp_config": { "lmhead_tensor_parallel_size": 8, "oproj_tensor_parallel_size": 8, "embedding_tensor_parallel_size": 8, diff --git a/tests/ut/distributed/test_parallel_state.py b/tests/ut/distributed/test_parallel_state.py index 4c0300a8550..7fd22629943 100644 --- a/tests/ut/distributed/test_parallel_state.py +++ b/tests/ut/distributed/test_parallel_state.py @@ -38,9 +38,9 @@ def mock_distributed(): def test_init_ascend_model_parallel(mock_distributed, parallel_config): mock_ascend_config = MagicMock() - mock_ascend_config.module_tp_config.lmhead_tensor_parallel_size = 2 - mock_ascend_config.module_tp_config.oproj_tensor_parallel_size = 2 - mock_ascend_config.module_tp_config.embedding_tensor_parallel_size = 2 + mock_ascend_config.finegrained_tp_config.lmhead_tensor_parallel_size = 2 + mock_ascend_config.finegrained_tp_config.oproj_tensor_parallel_size = 2 + mock_ascend_config.finegrained_tp_config.embedding_tensor_parallel_size = 2 mock_ascend_config.flashcomm2_oproj_tensor_parallel_size = 2 mock_ascend_config.pd_tp_ratio = 2 mock_ascend_config.num_head_replica = 0 diff --git a/tests/ut/ops/test_linear.py b/tests/ut/ops/test_linear.py index 8ac2e4719ad..995a69f4751 100644 --- a/tests/ut/ops/test_linear.py +++ b/tests/ut/ops/test_linear.py @@ -25,8 +25,8 @@ def setUp(self): parallel_state._OTP = self.mock_group self.mock_ascend_config = MagicMock() - self.mock_ascend_config.module_tp_config.oproj_tensor_parallel_size = 2 - self.mock_ascend_config.module_tp_config.mlp_tensor_parallel_size = 2 + self.mock_ascend_config.finegrained_tp_config.oproj_tensor_parallel_size = 2 + self.mock_ascend_config.finegrained_tp_config.mlp_tensor_parallel_size = 2 self.patches = [ patch("vllm_ascend.ascend_config.get_ascend_config", @@ -84,7 +84,7 @@ def test_mlp_optimize(self): ascend_config._ASCEND_CONFIG = MagicMock() ascend_config._ASCEND_CONFIG.recompute_scheduler_enable = False - ascend_config._ASCEND_CONFIG.module_tp_config.mlp_tensor_parallel_size = 2 + ascend_config._ASCEND_CONFIG.finegrained_tp_config.mlp_tensor_parallel_size = 2 ascend_config._ASCEND_CONFIG.ascend_scheduler_config.enabled = False linear = AscendRowParallelLinear( @@ -103,7 +103,7 @@ def test_oproj_tp(self): ascend_config._ASCEND_CONFIG = MagicMock() ascend_config._ASCEND_CONFIG.recompute_scheduler_enable = False - ascend_config._ASCEND_CONFIG.module_tp_config.oproj_tensor_parallel_size = 2 + ascend_config._ASCEND_CONFIG.finegrained_tp_config.oproj_tensor_parallel_size = 2 ascend_config._ASCEND_CONFIG.ascend_scheduler_config.enabled = False linear = AscendRowParallelLinear( @@ -123,7 +123,7 @@ def test_merged_mlp_tp_init(self): ascend_config._ASCEND_CONFIG = MagicMock() ascend_config._ASCEND_CONFIG.recompute_scheduler_enable = False - ascend_config._ASCEND_CONFIG.module_tp_config.mlp_tensor_parallel_size = 2 + ascend_config._ASCEND_CONFIG.finegrained_tp_config.mlp_tensor_parallel_size = 2 ascend_config._ASCEND_CONFIG.ascend_scheduler_config.enabled = False linear = AscendMergedColumnParallelLinear( diff --git a/tests/ut/ops/test_vocab_parallel_embedding.py b/tests/ut/ops/test_vocab_parallel_embedding.py index 15963ba391f..700da540f32 100644 --- a/tests/ut/ops/test_vocab_parallel_embedding.py +++ b/tests/ut/ops/test_vocab_parallel_embedding.py @@ -44,8 +44,8 @@ def setUp(self): 
mock_vllm_config = MagicMock() mock_vllm_config.additional_config = {} self.mock_ascend_config = MagicMock() - self.mock_ascend_config.module_tp_config.lmhead_tensor_parallel_size = 2 - self.mock_ascend_config.module_tp_config.embedding_tensor_parallel_size = 2 + self.mock_ascend_config.finegrained_tp_config.lmhead_tensor_parallel_size = 2 + self.mock_ascend_config.finegrained_tp_config.embedding_tensor_parallel_size = 2 self.patches = [ patch("vllm_ascend.utils.get_ascend_config", diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 4194433cca2..1f87f4807be 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -67,8 +67,10 @@ def __init__(self, vllm_config): self.ascend_compilation_config = AscendCompilationConfig( **ascend_compilation_config) - module_tp_config = additional_config.get("module_tp_config", {}) - self.module_tp_config = ModuleTPConfig(module_tp_config, vllm_config) + finegrained_tp_config = additional_config.get("finegrained_tp_config", + {}) + self.finegrained_tp_config = FinegrainedTPConfig( + finegrained_tp_config, vllm_config) # Dump / PrecisionDebugger configuration dump_config_path = additional_config.get("dump_config", None) @@ -156,19 +158,19 @@ def __init__(self, vllm_config): kv_cfg._engine_id_patched = True -class ModuleTPConfig: +class FinegrainedTPConfig: """ - Configuration Object for module_tp_config from additional_config + Configuration Object for finegrained_tp_config from additional_config """ - def __init__(self, module_tp_config: dict, vllm_config): - self.oproj_tensor_parallel_size = module_tp_config.get( + def __init__(self, finegrained_tp_config: dict, vllm_config): + self.oproj_tensor_parallel_size = finegrained_tp_config.get( "oproj_tensor_parallel_size", 0) - self.lmhead_tensor_parallel_size = module_tp_config.get( + self.lmhead_tensor_parallel_size = finegrained_tp_config.get( "lmhead_tensor_parallel_size", 0) - self.embedding_tensor_parallel_size = module_tp_config.get( + self.embedding_tensor_parallel_size = finegrained_tp_config.get( "embedding_tensor_parallel_size", 0) - self.mlp_tensor_parallel_size = module_tp_config.get( + self.mlp_tensor_parallel_size = finegrained_tp_config.get( "mlp_tensor_parallel_size", 0) enabled_configs = [] @@ -208,7 +210,7 @@ def __init__(self, module_tp_config: dict, vllm_config): "module tp sizes must divide data_parallel_size") if any(size > 0 for size in module_tp_sizes) and enabled_configs: logger.info( - f"module_tp_config enabled: {', '.join(enabled_configs)}") + f"finegrained_tp_config enabled: {', '.join(enabled_configs)}") class AscendCompilationConfig: diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py index 3eff12b25a1..7af091b21fc 100644 --- a/vllm_ascend/distributed/parallel_state.py +++ b/vllm_ascend/distributed/parallel_state.py @@ -141,13 +141,14 @@ def _create_or_get_group(group_size: int, return _group_cache[group_size] - otp_size = get_ascend_config().module_tp_config.oproj_tensor_parallel_size + otp_size = get_ascend_config( + ).finegrained_tp_config.oproj_tensor_parallel_size lmhead_tp_size = get_ascend_config( - ).module_tp_config.lmhead_tensor_parallel_size + ).finegrained_tp_config.lmhead_tensor_parallel_size embedding_tp_size = get_ascend_config( - ).module_tp_config.embedding_tensor_parallel_size + ).finegrained_tp_config.embedding_tensor_parallel_size mlp_tp_size = get_ascend_config( - ).module_tp_config.embedding_tensor_parallel_size + 
).finegrained_tp_config.mlp_tensor_parallel_size
 
-    global _OTP, _LMTP, _EMBED_TP
+    global _OTP, _LMTP, _EMBED_TP, _MLP_TP

diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 93fd8ebf1ab..3c5d5b5531d 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -705,20 +705,23 @@ def get_ascend_device_type():
 
 
 def lmhead_tp_enable() -> bool:
-    return get_ascend_config().module_tp_config.lmhead_tensor_parallel_size > 0
+    return get_ascend_config(
+    ).finegrained_tp_config.lmhead_tensor_parallel_size > 0
 
 
 def embedding_tp_enable() -> bool:
     return get_ascend_config(
-    ).module_tp_config.embedding_tensor_parallel_size > 0
+    ).finegrained_tp_config.embedding_tensor_parallel_size > 0
 
 
 def oproj_tp_enable() -> bool:
-    return get_ascend_config().module_tp_config.oproj_tensor_parallel_size > 0
+    return get_ascend_config(
+    ).finegrained_tp_config.oproj_tensor_parallel_size > 0
 
 
 def mlp_tp_enable() -> bool:
-    return get_ascend_config().module_tp_config.mlp_tensor_parallel_size > 0
+    return get_ascend_config(
+    ).finegrained_tp_config.mlp_tensor_parallel_size > 0
 
 
 def matmul_allreduce_enable() -> bool:
@@ -966,7 +969,7 @@ def get_flashcomm2_config_and_validate(ascend_config, vllm_config):
         logger.warning_once(
             "It is recommended to enable FLASHCOMM1 simultaneously when starting FLASHCOMM2 for optimal performance."
         )
-    if ascend_config.module_tp_config.oproj_tensor_parallel_size > 0:
+    if ascend_config.finegrained_tp_config.oproj_tensor_parallel_size > 0:
         raise AssertionError(
             "flashcomm2_oproj_tensor_parallel_size cannot be enabled simultaneously with oproj_tensor_parallel_size"
         )
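
--
Reviewer notes:

The whole `finegrained_tp_config` surface is driven through `additional_config`.
A minimal usage sketch, assuming the usual vllm-ascend pattern of passing
`additional_config` into the `LLM` constructor (the model name and sizes are
illustrative only, not taken from these patches):

```python
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # hypothetical example model
    tensor_parallel_size=1,            # module TP groups are carved out of the DP axis
    data_parallel_size=8,
    additional_config={
        "finegrained_tp_config": {
            # Each non-zero size must divide data_parallel_size (validated
            # in FinegrainedTPConfig.__init__); 0 leaves that module on the
            # default TP group.
            "lmhead_tensor_parallel_size": 8,
            "embedding_tensor_parallel_size": 8,
        },
    },
)
```

`oproj_tensor_parallel_size` is deliberately left at 0 here: per the validation
in `FinegrainedTPConfig`, it additionally requires graph mode and a KV-consumer
(D node) deployment.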
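
The rank-grid math in `_create_or_get_group` is easiest to check with concrete
numbers. A standalone sketch of just the indexing (no process groups are
created; the pp/dp/tp values below are made up for illustration):

```python
import torch

world_size, pp, dp, tp = 8, 1, 4, 2
group_size = 2  # e.g. lmhead_tensor_parallel_size

rank_grid = torch.arange(world_size).reshape(pp, dp, tp)
num_chunks = dp // group_size
group_ranks = []
for pp_idx in range(pp):
    stage_ranks = rank_grid[pp_idx]  # shape (dp, tp)
    for chunk in range(num_chunks):
        for tp_idx in range(tp):
            # Consecutive DP ranks that share a TP index form one group.
            group_ranks.append(
                stage_ranks[chunk * group_size:(chunk + 1) * group_size,
                            tp_idx].tolist())

print(group_ranks)  # [[0, 2], [1, 3], [4, 6], [5, 7]]
```

Groups are carved along the data-parallel axis, which is why the config
validation insists that every module TP size divides `data_parallel_size`.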
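
`_forward_embed_tp` relies on an all_gather / masked-lookup / reduce_scatter
round trip summing to a plain embedding lookup. A single-process simulation of
that invariant, with the collectives replaced by their arithmetic (toy vocab
and hidden sizes; no GroupCoordinator involved):

```python
import torch

vocab, hidden, etp = 8, 4, 2
table = torch.randn(vocab, hidden)
shards = table.chunk(etp, dim=0)     # per-rank vocab shards (4 rows each)
tokens = torch.tensor([1, 6, 3, 4])  # token ids after the all_gather step

partial = []
for rank, shard in enumerate(shards):
    lo, hi = rank * (vocab // etp), (rank + 1) * (vocab // etp)
    mask = (tokens >= lo) & (tokens < hi)
    local = torch.where(mask, tokens - lo, torch.zeros_like(tokens))
    out = shard[local]  # lookup in the local shard (advanced indexing copies)
    out[~mask] = 0      # zero the rows owned by other ranks
    partial.append(out)

# reduce_scatter's sum step: shards contribute disjoint non-zero rows, so the
# group-wide sum equals a lookup in the full table.
assert torch.equal(torch.stack(partial).sum(0), table[tokens])
print("embedding-TP invariant holds")
```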