diff --git a/tests/ut/eplb/core/test_eplb_device_transfer_loader.py b/tests/ut/eplb/core/test_eplb_device_transfer_loader.py
index 6a204dc0024..48777894774 100644
--- a/tests/ut/eplb/core/test_eplb_device_transfer_loader.py
+++ b/tests/ut/eplb/core/test_eplb_device_transfer_loader.py
@@ -47,8 +47,8 @@ def test_generate_task_and_state_flow(mock_adaptor):
     loader_obj.state = loader.ExpertWeightUpdateState.WAITING
     loader_obj.generate_expert_d2d_transfer_task([], [], {}, 0)
 
-    assert loader_obj.comm_op_list is None
-    assert loader_obj.state == loader.ExpertWeightUpdateState.WAITING
+    assert not loader_obj.comm_op_list
+    assert loader_obj.state == loader.ExpertWeightUpdateState.READY
 
 
 def test_asyn_transfer_and_update(mock_adaptor):
diff --git a/tests/ut/ops/test_expert_load_balancer.py b/tests/ut/ops/test_expert_load_balancer.py
deleted file mode 100644
index f7f68472131..00000000000
--- a/tests/ut/ops/test_expert_load_balancer.py
+++ /dev/null
@@ -1,140 +0,0 @@
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# This file is a part of the vllm-ascend project.
-#
-
-import json
-import os
-from typing import List, TypedDict
-from unittest import mock
-
-import torch
-
-from tests.ut.base import TestBase
-from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
-
-
-class Device(TypedDict):
-    device_id: int
-    device_expert: List[int]
-
-
-class Layer(TypedDict):
-    layer_id: int
-    device_count: int
-    device_list: List[Device]
-
-
-class MockData(TypedDict):
-    moe_layer_count: int
-    layer_list: List[Layer]
-
-
-class TestExpertLoadBalancer(TestBase):
-
-    def setUp(self):
-        _TEST_DIR = os.path.dirname(__file__)
-        json_file = _TEST_DIR + "/expert_map.json"
-        with open(json_file, 'r') as f:
-            self.expert_map: MockData = json.load(f)
-
-        self.expert_load_balancer = ExpertLoadBalancer(json_file, 8)
-
-    def test_init(self):
-
-        self.assertIsInstance(self.expert_load_balancer.expert_map_tensor,
-                              torch.Tensor)
-        self.assertEqual(self.expert_load_balancer.layers_num,
-                         self.expert_map["moe_layer_count"])
-        self.assertEqual(self.expert_load_balancer.ranks_num,
-                         self.expert_map["layer_list"][0]["device_count"])
-
-    def test_generate_index_dicts(self):
-        tensor_2d = torch.tensor([[7, 2, 0, 3, 5], [6, 1, 4, 7, 2]])
-        result = self.expert_load_balancer.generate_index_dicts(tensor_2d)
-        expected_result = [{
-            7: 0,
-            2: 1,
-            0: 2,
-            3: 3,
-            5: 4
-        }, {
-            6: 5,
-            1: 6,
-            4: 7,
-            7: 8,
-            2: 9
-        }]
-        self.assertEqual(result, expected_result)
-
-    def test_generate_expert_placement_map(self):
-        expert_placement_map = self.expert_load_balancer.generate_expert_placement_map(
-        )
-        self.assertEqual(expert_placement_map.shape,
-                         (self.expert_load_balancer.layers_num,
-                          self.expert_load_balancer.ranks_num, 10))
-        self.assertTrue(torch.all(expert_placement_map >= -1))
-
-    def test_generate_log2phy_expert_map(self):
-        layer_id = 0
-        log2phy_map = self.expert_load_balancer.generate_log2phy_expert_map(
-            layer_id)
-        self.assertEqual(log2phy_map.shape,
-                         (self.expert_load_balancer.ranks_num, 10))
-        self.assertTrue(torch.all(log2phy_map >= -1))
-
-    @mock.patch("torch_npu.npu._lazy_init")
-    @mock.patch("torch.npu.current_device", return_value="cpu")
-    def test_get_rank_placement_map(self, mock_current_device, mock_lazy_init):
-        layer_id = 0
-        rank_id = 0
-        rank_local_expert_num, rank_expert_map = self.expert_load_balancer.get_rank_placement_map(
-            layer_id, rank_id)
-        self.assertEqual(rank_local_expert_num, 5)
-        expected_tensor = torch.tensor([2, -1, 1, 3, -1, 4, -1, 0, -1, -1],
-                                       dtype=torch.int32).to(
-                                           rank_expert_map.device)
-        self.assertTrue(rank_expert_map.equal(expected_tensor))
-
-        rank_id = 1
-        rank_local_expert_num, rank_expert_map = self.expert_load_balancer.get_rank_placement_map(
-            layer_id, rank_id)
-        expected_tensor = torch.tensor([-1, 1, 4, -1, 2, -1, 0, 3, -1, -1],
-                                       dtype=torch.int32).to(
-                                           rank_expert_map.device)
-        self.assertTrue(rank_expert_map.equal(expected_tensor))
-
-    def test_get_rank_log2phy_map(self):
-        layer_id = 0
-        rank_id = 0
-        log2phy_map = self.expert_load_balancer.get_rank_log2phy_map(
-            layer_id, rank_id)
-        expected_tensor = torch.tensor([2, 6, 1, 3, 7, 4, 5, 0, -1, -1],
-                                       dtype=torch.int32).to(
-                                           log2phy_map.device)
-        self.assertTrue(log2phy_map.equal(expected_tensor))
-
-        rank_id = 1
-        log2phy_map = self.expert_load_balancer.get_rank_log2phy_map(
-            layer_id, rank_id)
-        expected_tensor = torch.tensor([2, 6, 9, 3, 7, 4, 5, 8, -1, -1],
-                                       dtype=torch.int32).to(
-                                           log2phy_map.device)
-        self.assertTrue(log2phy_map.equal(expected_tensor))
-
-    def test_get_global_redundant_expert_num(self):
-        redundant_expert_num = self.expert_load_balancer.get_global_redundant_expert_num(
-        )
-        expected_redundant_expert_num = len(self.expert_map["layer_list"][0]["device_list"][0]["device_expert"]) * \
-            self.expert_map["layer_list"][0]["device_count"] - 8
-        self.assertEqual(redundant_expert_num, expected_redundant_expert_num)
diff --git a/tests/ut/ops/test_moe_comm_method.py b/tests/ut/ops/test_moe_comm_method.py
index 7620999a159..5b4595559d4 100644
--- a/tests/ut/ops/test_moe_comm_method.py
+++ b/tests/ut/ops/test_moe_comm_method.py
@@ -24,7 +24,7 @@ def setUp(self):
         self.moe_config.tp_size = 1
         self.moe_config.ep_size = 1
         self.moe_config.dp_group = MagicMock()
-        self.moe_config.num_global_redundant_experts = 0
+        self.moe_config.global_redundant_expert_num = 0
 
     @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
     @patch(
diff --git a/tests/ut/ops/test_token_dispatcher.py b/tests/ut/ops/test_token_dispatcher.py
index 140bae5cd20..698022ec7e2 100644
--- a/tests/ut/ops/test_token_dispatcher.py
+++ b/tests/ut/ops/test_token_dispatcher.py
@@ -144,48 +144,11 @@ def test_get_combine_mc_kwargs_with_quant(self):
         self.dispatcher.need_extra_args = True
         self.dispatcher.enable_dispatch_v2 = True
-
+        self.dispatcher.moe_expert_num = len(expert_map)
         kwargs = self.dispatcher.get_combine_mc_kwargs(hidden_states,
                                                        context_metadata)
         self.assertIn("tp_send_counts", kwargs)
 
-    def test_token_combine_with_shared_experts(self):
-        shared_experts = MagicMock()
-        shared_experts.down_proj.return_value = (torch.randn(10, 128),
-                                                 torch.tensor(1.0))
-
-        topk_ids = torch.randint(0, 8, (10, 1))
-        topk_weights = torch.randn(10, 1)
-        expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
-        ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
-        assist_info_for_combine = torch.arange(10)
-        tp_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
-
-        context_metadata = {
-            "topk_ids": topk_ids,
-            "topk_weights": topk_weights,
-            "expert_map": expert_map,
-            "ep_recv_counts": ep_recv_counts,
"mc2_mask": None, - "assist_info_for_combine": assist_info_for_combine, - "expand_scales": None, - "shared_experts": shared_experts, - "shared_act": torch.randn(10, 128), - "swiglu_out_scale": torch.randn(10, 1), - "tp_recv_counts": tp_recv_counts - } - - self.dispatcher.with_quant = True - self.dispatcher.need_extra_args = True - self.dispatcher.enable_dispatch_v2 = True - - hidden_states = torch.randn(10, 128) - with patch("torch_npu.npu_moe_distribute_combine_v2", - return_value=torch.randn(10, 128)): - result = self.dispatcher.token_combine(hidden_states, - context_metadata) - self.assertIsInstance(result, tuple) - class TestTokenDispatcherWithAllGather(TestBase): diff --git a/vllm_ascend/eplb/core/eplb_device_transfer_loader.py b/vllm_ascend/eplb/core/eplb_device_transfer_loader.py index 5c676cddb8f..ce1c3d73325 100644 --- a/vllm_ascend/eplb/core/eplb_device_transfer_loader.py +++ b/vllm_ascend/eplb/core/eplb_device_transfer_loader.py @@ -50,10 +50,6 @@ def generate_expert_d2d_transfer_task(self, expert_send_info, ) return - # If neither send nor receive task is needed for this layer on this rank, return - if not (expert_send_info or expert_recv_info): - return - self.updated_expert_map = updated_expert_map self.layer_id = layer_id diff --git a/vllm_ascend/ops/expert_load_balancer.py b/vllm_ascend/ops/expert_load_balancer.py deleted file mode 100644 index 7e8a9aefd28..00000000000 --- a/vllm_ascend/ops/expert_load_balancer.py +++ /dev/null @@ -1,118 +0,0 @@ -import json -import random -from typing import Dict, List - -import torch -import torch.distributed as dist - - -class ExpertLoadBalancer(object): - - def __init__(self, expert_map_path, num_experts): - self.expert_map_path = expert_map_path - self.num_experts = num_experts - self.tensor_data = [] - self.expert_map_tensor, self.layers_num, self.ranks_num = ( - self._expert_file_to_tensor()) - self.global_expert_num = num_experts + self.get_global_redundant_expert_num( - ) - self.expert_placement_map = self.generate_expert_placement_map() - - def _expert_file_to_tensor(self): - with open(self.expert_map_path, "r") as f: - data = json.load(f) - layers_num = data["moe_layer_count"] - gpus_num = data["layer_list"][0]["device_count"] - for layer in data["layer_list"]: - device_data = [] - for device in layer["device_list"]: - device_data.append(device["device_expert"]) - self.tensor_data.append(device_data) - expert_map_tensor = torch.tensor(self.tensor_data, dtype=torch.int32) - return expert_map_tensor, layers_num, gpus_num - - def generate_index_dicts(self, tensor_2d): - dict_list = [] - current_idx = 0 - - for row in tensor_2d: - value_to_index = {} - for i in range(row.size(0)): - value = row[i].item() - value_to_index[value] = current_idx + i - dict_list.append(value_to_index) - current_idx += row.size(0) - - return dict_list - - def generate_expert_placement_map(self): - expert_placement_map = torch.full( - (self.layers_num, self.ranks_num, self.global_expert_num), - -1, - dtype=torch.int32, - ) - for layer_id in range(self.layers_num): - for gpu_id in range(self.ranks_num): - e_ids = self.expert_map_tensor[layer_id, gpu_id] - expert_placement_map[layer_id, gpu_id, - e_ids] = torch.arange(len(e_ids), - dtype=torch.int32) - return expert_placement_map - - def generate_log2phy_expert_map(self, layer_id): - concatenated = torch.flatten(self.expert_map_tensor[layer_id]) - rank_expert_to_global = self.generate_index_dicts( - self.expert_map_tensor[layer_id]) - result_dict: Dict[int, List[int]] = {} - for idx, value in 
-            key = value.item()
-            if key not in result_dict:
-                result_dict[key] = []
-            result_dict[key].append(idx)
-
-        log2phy_map = torch.full((self.ranks_num, self.global_expert_num),
-                                 -1,
-                                 dtype=torch.int32)
-        for rank in range(self.ranks_num):
-            for key in result_dict:
-                indices_in_concat = result_dict[key]
-                if key in rank_expert_to_global[rank]:
-                    log2phy_map[rank][key] = rank_expert_to_global[rank][key]
-                else:
-                    chosen_index = random.choice(indices_in_concat)
-                    log2phy_map[rank][key] = chosen_index
-        return log2phy_map
-
-    def get_rank_placement_map(self, layer_id, rank_id):
-        layer_expert_map = self.expert_placement_map[layer_id]
-        rank_expert_map = layer_expert_map[rank_id].to(
-            torch.npu.current_device())
-        rank_local_expert_num = torch.sum(torch.ne(rank_expert_map, -1)).item()
-        return rank_local_expert_num, rank_expert_map
-
-    def get_rank_log2phy_map(self, layer_id, rank_id):
-        layer_log2phy_map = self.generate_log2phy_expert_map(layer_id)
-        return layer_log2phy_map[rank_id]
-
-    def get_global_redundant_expert_num(self):
-        global_redundant_expert_num = (
-            len(self.expert_map_tensor[0][0]) * self.ranks_num -
-            self.num_experts)
-        return global_redundant_expert_num
-
-    def check_expert_map_tensor(self):
-        if dist.is_initialized():
-            try:
-                rank = dist.get_rank()
-                world_size = dist.get_world_size()
-                all_expert_maps = [None for _ in range(world_size)]
-                dist.all_gather_object(all_expert_maps, self.tensor_data)
-                for rank_id, expert_map_tensor in enumerate(all_expert_maps):
-                    if self.tensor_data != expert_map_tensor:
-                        raise ValueError(
-                            f"The expert map of rank{rank} is not equal to rank{rank_id}"
-                        )
-                return True
-            except Exception as e:
-                raise ValueError(
-                    f"The expert maps of all ranks are inconsistency: {e}")
diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py
index a9547a5a0e1..1b1caa127fe 100644
--- a/vllm_ascend/ops/fused_moe/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe/fused_moe.py
@@ -234,7 +234,7 @@ def __init__(self, *args, **kwargs):
 
         self.moe_config.num_experts = self.global_num_experts
         self.moe_config.num_local_experts = self.local_num_experts
-        self.moe_config.original_num_experts = num_experts
+        self.moe_config.global_redundant_expert_num = self.global_redundant_expert_num
 
         moe_quant_params = {
             "num_experts": local_num_experts,
diff --git a/vllm_ascend/ops/fused_moe/moe_comm_method.py b/vllm_ascend/ops/fused_moe/moe_comm_method.py
index 30d1e5c1376..c0804106ee3 100644
--- a/vllm_ascend/ops/fused_moe/moe_comm_method.py
+++ b/vllm_ascend/ops/fused_moe/moe_comm_method.py
@@ -105,7 +105,6 @@ def fused_experts(
         dynamic_scale_for_share: Optional[Any] = None,
         # For load balance
         log2phy: torch.Tensor = None,
-        global_redundant_expert_num: int = 0,
         need_trans: bool = False,
         dynamic_eplb: bool = False,
         mc2_mask: torch.Tensor = None,
@@ -124,7 +123,8 @@ def fused_experts(
             topk_ids=topk_ids,
             expert_map=expert_map,
             log2phy=log2phy,
-            global_redundant_expert_num=global_redundant_expert_num,
+            global_redundant_expert_num=self.moe_config.
+            global_redundant_expert_num,
             shared_experts=shared_experts,
             quantized_x_for_share=quantized_x_for_share,
             dynamic_scale_for_share=dynamic_scale_for_share,
@@ -283,7 +283,6 @@ def fused_experts(
         dynamic_scale_for_share: Optional[Any] = None,
         # For load balance
         log2phy: torch.Tensor = None,
-        global_redundant_expert_num: int = 0,
         need_trans: bool = False,
         dynamic_eplb: bool = False,
         mc2_mask: torch.Tensor = None,
diff --git a/vllm_ascend/ops/fused_moe/token_dispatcher.py b/vllm_ascend/ops/fused_moe/token_dispatcher.py
index aeb751d0d8d..93cd51091b8 100644
--- a/vllm_ascend/ops/fused_moe/token_dispatcher.py
+++ b/vllm_ascend/ops/fused_moe/token_dispatcher.py
@@ -136,18 +136,14 @@ def get_dispatch_mc2_kwargs(
         mc2_mask: torch.Tensor,
         global_redundant_expert_num: int = 0,
     ):
-        if self.with_quant:
-            quant_mode = 2
-            moe_expert_num = len(expert_map)
-        else:
-            quant_mode = 0
-            moe_expert_num = len(expert_map)
+        quant_mode = 2 if self.with_quant else 0
+        self.moe_expert_num = len(expert_map) + global_redundant_expert_num
         kwargs_mc2 = {
             "x": hidden_states,
             "expert_ids": topk_ids,
             "expert_shard_type": 0,
             "shared_expert_rank_num": 0,
-            "moe_expert_num": moe_expert_num,
+            "moe_expert_num": self.moe_expert_num,
             "global_bs": self.global_bs,
             "expert_token_nums_type": 0,
         }
@@ -259,7 +255,6 @@ def get_combine_mc_kwargs(self, hidden_states: torch.Tensor,
         expand_scales = context_metadata["expand_scales"]
 
         assert expert_map is not None
-        moe_expert_num = len(expert_map)
 
         kwargs_mc2 = {
             "expand_x": hidden_states,
@@ -267,7 +262,7 @@ def get_combine_mc_kwargs(self, hidden_states: torch.Tensor,
             "expert_scales": topk_weights.to(torch.float32),
             "expert_shard_type": 0,
             "shared_expert_rank_num": 0,
-            "moe_expert_num": moe_expert_num,
+            "moe_expert_num": self.moe_expert_num,
             "global_bs": self.global_bs,
         }
@@ -369,7 +364,7 @@ def token_dispatch(self,
             hidden_states = hidden_states * \
                 topk_weights.to(hidden_states.dtype)
         if expert_map is not None:
-            global_num_experts = len(expert_map)
+            global_num_experts = len(expert_map) + global_redundant_expert_num
             mask = (expert_map[topk_ids] != -1)
             topk_weights = topk_weights * mask
             first_expert_idx = get_ep_group(
diff --git a/vllm_ascend/quantization/w4a16.py b/vllm_ascend/quantization/w4a16.py
index d15fa25aaa2..28f2ed6406b 100644
--- a/vllm_ascend/quantization/w4a16.py
+++ b/vllm_ascend/quantization/w4a16.py
@@ -242,7 +242,6 @@ def apply(
             use_int4_w4a16=True,
             expert_map=expert_map,
             log2phy=log2phy,
-            global_redundant_expert_num=global_redundant_expert_num,
             shared_experts=shared_experts,
             quantized_x_for_share=quantized_x_for_share,
             dynamic_scale_for_share=dynamic_scale_for_share,
diff --git a/vllm_ascend/quantization/w4a8_dynamic.py b/vllm_ascend/quantization/w4a8_dynamic.py
index 45a7bc18337..dfe831b6f5f 100644
--- a/vllm_ascend/quantization/w4a8_dynamic.py
+++ b/vllm_ascend/quantization/w4a8_dynamic.py
@@ -390,7 +390,6 @@ def apply(
             use_int4_w4a8=True,
             expert_map=expert_map,
             log2phy=log2phy,
-            global_redundant_expert_num=global_redundant_expert_num,
             shared_experts=shared_experts,
             quantized_x_for_share=quantized_x_for_share,
             dynamic_scale_for_share=dynamic_scale_for_share,
diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index 1c158d09edb..b7074cf0bea 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -264,7 +264,6 @@ def apply(
             use_int8_w8a8=True,
             expert_map=expert_map,
             log2phy=log2phy,
-            global_redundant_expert_num=global_redundant_expert_num,
             shared_experts=shared_experts,
             quantized_x_for_share=quantized_x_for_share,
             dynamic_scale_for_share=dynamic_scale_for_share,