From 88ca5a0023c0afb92e953cdb3ed287a810245bd4 Mon Sep 17 00:00:00 2001
From: baishihao
Date: Thu, 3 Jul 2025 21:24:11 +0800
Subject: [PATCH 1/3] add ep fake balance

---
 .../common/fused_moe/grouped_fused_moe_ep.py | 12 ++++
 lightllm/common/fused_moe/topk_select.py     |  8 +++
 lightllm/utils/balance_utils.py              | 63 +++++++++++++++++++
 3 files changed, 83 insertions(+)
 create mode 100755 lightllm/utils/balance_utils.py

diff --git a/lightllm/common/fused_moe/grouped_fused_moe_ep.py b/lightllm/common/fused_moe/grouped_fused_moe_ep.py
index 3b5cc6b91..2f6120881 100644
--- a/lightllm/common/fused_moe/grouped_fused_moe_ep.py
+++ b/lightllm/common/fused_moe/grouped_fused_moe_ep.py
@@ -142,6 +142,11 @@ def fused_experts_impl(

         # scatter
         all_tokens = sum(num_recv_tokens_per_expert_list)  # calcu padding all nums.
+
+        # Important log for debugging load balance
+        #rank=dist.get_rank()
+        #logger.info(f"prefill, [{rank}], all_tokens = {all_tokens}, num_recv_tokens_per_expert_list: {num_recv_tokens_per_expert_list}")
+
         # gather_out shape [recive_num_tokens, hidden]
         gather_out = torch.empty_like(recv_x[0], device=hidden_states.device, dtype=hidden_states.dtype)
         if all_tokens > 0:
@@ -219,6 +224,13 @@ def fused_experts_impl(
             async_finish=False,
             return_recv_hook=False,
         )
+
+        # Important log for debugging load balance
+        # when decoding graph is open, we can not call logger. --profile can close cuda graph
+        #rank=dist.get_rank()
+        #all_tokens = sum(masked_m)
+        #logger.info(f"decode, [{rank}], all_tokens = {all_tokens}, expected_m = {expected_m}, num_recv_tokens_per_expert: {masked_m}")
+
         # deepgemm
         gemm_out_b = masked_group_gemm(recv_x, masked_m, hidden_states.dtype, w1, w1_scale, w2, w2_scale, expected_m)
         # low latency combine
diff --git a/lightllm/common/fused_moe/topk_select.py b/lightllm/common/fused_moe/topk_select.py
index ca8d22f48..11c5256f3 100644
--- a/lightllm/common/fused_moe/topk_select.py
+++ b/lightllm/common/fused_moe/topk_select.py
@@ -21,6 +21,7 @@
 import torch
 from lightllm.utils.sgl_utils import sgl_ops
 from lightllm.utils.light_utils import light_ops
+from lightllm.utils.balance_utils import BalancedTensor
 from typing import Callable, List, Optional, Tuple
 from lightllm.common.fused_moe.softmax_topk import softmax_topk

@@ -227,4 +228,11 @@ def select_experts(
             hidden_states=hidden_states, gating_output=router_logits, topk=top_k, renormalize=renormalize
         )

+    # EP fake load balance switch
+    if os.environ.get("EP_FAKE_BALANCE_ENABLED") == "true":
+        M, _ = hidden_states.shape
+        balanced_tensor_collection = BalancedTensor()
+        balance_topk_ids = balanced_tensor_collection.get_balance_topk_ids(M)
+        topk_ids.copy_(balance_topk_ids)
+
     return topk_weights, topk_ids
diff --git a/lightllm/utils/balance_utils.py b/lightllm/utils/balance_utils.py
new file mode 100755
index 000000000..c4ac3d9fb
--- /dev/null
+++ b/lightllm/utils/balance_utils.py
@@ -0,0 +1,63 @@
+import torch
+import os
+
+import threading
+
+def singleton_threadsafe(cls):
+    instances = {}
+    lock = threading.Lock()
+
+    def get_instance(*args, **kwargs):
+        with lock:
+            if cls not in instances:
+                instances[cls] = cls(*args, **kwargs)
+            return instances[cls]
+    return get_instance
+
+@singleton_threadsafe
+class BalancedTensor:
+    def __init__(self, num_experts=256, num_selected=8):
+        self.balanced_tensors = {}
+        self.num_experts = num_experts
+        self.num_selected = num_selected
+
+    def generate_balanced_tensor(self, length):
+        # Initialize an all-zero tensor of shape length * 8 on the GPU
+        tensor = torch.zeros((length, self.num_selected), dtype=torch.int, device='cuda')
+        # Initialize the load count of each expert
+        expert_load = torch.zeros(self.num_experts, dtype=torch.int, device='cuda')
+
+        for i in range(length):
+            available_experts = torch.arange(self.num_experts, device='cuda')
+            selected = []
+            for _ in range(self.num_selected):
+                # Compute the current load of each available expert
+                current_load = expert_load[available_experts]
+                # Choose the expert with the smallest load
+                min_load_indices = torch.where(current_load == current_load.min())[0]
+                if len(min_load_indices) > 1:
+                    # If there are multiple least-loaded experts, pick one at random
+                    chosen_index = torch.randint(0, len(min_load_indices), (1,), device='cuda').item()
+                    chosen_expert_index = min_load_indices[chosen_index]
+                else:
+                    chosen_expert_index = min_load_indices[0]
+                chosen_expert = available_experts[chosen_expert_index]
+                selected.append(chosen_expert)
+                # Remove the chosen expert from the available expert list
+                available_experts = torch.cat(
+                    [available_experts[:chosen_expert_index], available_experts[chosen_expert_index + 1:]])
+                # Update the load of this expert
+                expert_load[chosen_expert] += 1
+            tensor[i] = torch.tensor(selected, dtype=torch.int, device='cuda')
+        return tensor
+
+    def get_balance_topk_ids(self, length):
+        if self.balanced_tensors.get(length) is not None:
+            #print("find length ", length)
+            return self.balanced_tensors[length]
+        else:
+            #print("generate length ", length)
+            tensor = self.generate_balanced_tensor(length)
+            self.balanced_tensors[length] = tensor
+            return tensor
+

From 66c0f2deb9b38544421c15938545d59e9a2f1d10 Mon Sep 17 00:00:00 2001
From: baishihao
Date: Mon, 7 Jul 2025 20:34:12 +0800
Subject: [PATCH 2/3] fix pre-commit checks and gemini checks: more robust
 balance management, from env control to option control, better logger info
 control, better format

---
 .../common/fused_moe/grouped_fused_moe_ep.py | 26 +++++---
 lightllm/common/fused_moe/topk_select.py     | 11 ++--
 lightllm/server/api_cli.py                   |  3 +
 lightllm/server/core/objs/start_args_type.py |  1 +
 lightllm/utils/balance_utils.py              | 62 +++++++++++--------
 5 files changed, 63 insertions(+), 40 deletions(-)

diff --git a/lightllm/common/fused_moe/grouped_fused_moe_ep.py b/lightllm/common/fused_moe/grouped_fused_moe_ep.py
index 2f6120881..665c3e324 100644
--- a/lightllm/common/fused_moe/grouped_fused_moe_ep.py
+++ b/lightllm/common/fused_moe/grouped_fused_moe_ep.py
@@ -5,6 +5,7 @@
 import triton.language as tl
 from typing import Any, Callable, Dict, Optional, Tuple
 import torch.distributed as dist
+from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.utils.log_utils import init_logger
 from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd
 from lightllm.common.fused_moe.moe_silu_and_mul_mix_quant_ep import silu_and_mul_masked_post_quant_fwd
@@ -143,9 +144,13 @@ def fused_experts_impl(
         # scatter
         all_tokens = sum(num_recv_tokens_per_expert_list)  # calcu padding all nums.

-        # Important log for debugging load balance
-        #rank=dist.get_rank()
-        #logger.info(f"prefill, [{rank}], all_tokens = {all_tokens}, num_recv_tokens_per_expert_list: {num_recv_tokens_per_expert_list}")
+        if get_env_start_args().enable_ep_fake_balance:
+            rank = dist.get_rank()
+            if rank == 0:
+                logger.info(
+                    f"prefill, [{rank}], all_tokens = {all_tokens}, "
+                    f"num_recv_tokens_per_expert_list: {num_recv_tokens_per_expert_list}"
+                )

         # gather_out shape [recive_num_tokens, hidden]
         gather_out = torch.empty_like(recv_x[0], device=hidden_states.device, dtype=hidden_states.dtype)
@@ -225,11 +230,16 @@ def fused_experts_impl(
             return_recv_hook=False,
         )

-        # Important log for debugging load balance
-        # when decoding graph is open, we can not call logger. --profile can close cuda graph
-        #rank=dist.get_rank()
-        #all_tokens = sum(masked_m)
-        #logger.info(f"decode, [{rank}], all_tokens = {all_tokens}, expected_m = {expected_m}, num_recv_tokens_per_expert: {masked_m}")
+        # NOTE: when decoding graph is open, we can not call logger. Thus it can only be used when --disable_cudagraph
+        args = get_env_start_args()
+        if args.enable_ep_fake_balance and args.disable_cudagraph:
+            rank = dist.get_rank()
+            all_tokens = sum(masked_m)
+            if rank == 0:
+                logger.info(
+                    f"decode, [{rank}], all_tokens = {all_tokens}, "
+                    f"expected_m = {expected_m}, num_recv_tokens_per_expert: {masked_m}"
+                )

         # deepgemm
         gemm_out_b = masked_group_gemm(recv_x, masked_m, hidden_states.dtype, w1, w1_scale, w2, w2_scale, expected_m)
diff --git a/lightllm/common/fused_moe/topk_select.py b/lightllm/common/fused_moe/topk_select.py
index 11c5256f3..3f639aa4f 100644
--- a/lightllm/common/fused_moe/topk_select.py
+++ b/lightllm/common/fused_moe/topk_select.py
@@ -21,6 +21,7 @@
 import torch
 from lightllm.utils.sgl_utils import sgl_ops
 from lightllm.utils.light_utils import light_ops
+from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.utils.balance_utils import BalancedTensor
 from typing import Callable, List, Optional, Tuple
 from lightllm.common.fused_moe.softmax_topk import softmax_topk
@@ -228,11 +229,11 @@ def select_experts(
             hidden_states=hidden_states, gating_output=router_logits, topk=top_k, renormalize=renormalize
         )

-    # EP fake load balance switch
-    if os.environ.get("EP_FAKE_BALANCE_ENABLED") == "true":
-        M, _ = hidden_states.shape
-        balanced_tensor_collection = BalancedTensor()
-        balance_topk_ids = balanced_tensor_collection.get_balance_topk_ids(M)
+    # Enable EP fake balance
+    if get_env_start_args().enable_ep_fake_balance:
+        num_tokens, num_experts = router_logits.shape
+        balanced_tensor_collection = BalancedTensor(num_experts=num_experts, num_selected=top_k)
+        balance_topk_ids = balanced_tensor_collection.get_balance_topk_ids(num_tokens)
         topk_ids.copy_(balance_topk_ids)

     return topk_weights, topk_ids
diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index 3f3eaf96f..4d48af10f 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -333,6 +333,9 @@ def make_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--enable_monitor_auth", action="store_true", help="Whether to open authentication for push_gateway"
     )
+
+    parser.add_argument("--enable_ep_fake_balance", action="store_true", help="Enable the fake balance of the EP mode")
+
     parser.add_argument("--disable_cudagraph", action="store_true", help="Disable the cudagraph of the decoding stage")

     parser.add_argument(
diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py
index ec1eb427e..d7a5208e4 100644
--- a/lightllm/server/core/objs/start_args_type.py
+++ b/lightllm/server/core/objs/start_args_type.py
@@ -76,6 +76,7 @@ class StartArgs:
     visual_dp: int = field(default=1)
     visual_nccl_ports: List[int] = field(default_factory=lambda: [29500])
     enable_monitor_auth: bool = field(default=False)
+    enable_ep_fake_balance: bool = field(default=False)
     disable_cudagraph: bool = field(default=False)
     graph_max_batch_size: int = field(default=256)
     graph_split_batch_size: int = field(default=32)
diff --git a/lightllm/utils/balance_utils.py b/lightllm/utils/balance_utils.py
index c4ac3d9fb..61b831d18 100755
--- a/lightllm/utils/balance_utils.py
+++ b/lightllm/utils/balance_utils.py
@@ -3,17 +3,27 @@

 import threading

+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+
 def singleton_threadsafe(cls):
     instances = {}
     lock = threading.Lock()

     def get_instance(*args, **kwargs):
+        # A key that includes the arguments is needed for parameter-dependent singletons.
+        # Using a tuple of args and a frozenset of kwargs items makes it hashable.
+        key = (cls, args, frozenset(kwargs.items()))
         with lock:
-            if cls not in instances:
-                instances[cls] = cls(*args, **kwargs)
-            return instances[cls]
+            if key not in instances:
+                instances[key] = cls(*args, **kwargs)
+            return instances[key]
+
     return get_instance

+
 @singleton_threadsafe
 class BalancedTensor:
     def __init__(self, num_experts=256, num_selected=8):
@@ -21,43 +31,41 @@ def __init__(self, num_experts=256, num_selected=8):
         self.num_experts = num_experts
         self.num_selected = num_selected

-    def generate_balanced_tensor(self, length):
-        # Initialize an all-zero tensor of shape length * 8 on the GPU
-        tensor = torch.zeros((length, self.num_selected), dtype=torch.int, device='cuda')
-        # Initialize the load count of each expert
-        expert_load = torch.zeros(self.num_experts, dtype=torch.int, device='cuda')
+    def generate_balanced_tensor(self, num_tokens):
+        # Evenly distribute num_tokens to num_selected experts out of num_experts.
+        # Note that the num_selected experts activated by a token cannot be repeated.
+        # Performance is not that important, as it is only activated in special scenarios.
+        tensor = torch.zeros((num_tokens, self.num_selected), dtype=torch.int, device="cuda")
+        expert_load = torch.zeros(self.num_experts, dtype=torch.int, device="cuda")

-        for i in range(length):
-            available_experts = torch.arange(self.num_experts, device='cuda')
+        for i in range(num_tokens):
+            available_experts = torch.arange(self.num_experts, device="cuda")
             selected = []
             for _ in range(self.num_selected):
-                # Compute the current load of each available expert
                 current_load = expert_load[available_experts]
-                # Choose the expert with the smallest load
                 min_load_indices = torch.where(current_load == current_load.min())[0]
                 if len(min_load_indices) > 1:
-                    # If there are multiple least-loaded experts, pick one at random
-                    chosen_index = torch.randint(0, len(min_load_indices), (1,), device='cuda').item()
+                    # If there are multiple least-loaded experts, select one randomly
+                    chosen_index = torch.randint(0, len(min_load_indices), (1,), device="cuda").item()
                     chosen_expert_index = min_load_indices[chosen_index]
                 else:
                     chosen_expert_index = min_load_indices[0]
                 chosen_expert = available_experts[chosen_expert_index]
                 selected.append(chosen_expert)
-                # Remove the chosen expert from the available expert list
+                # Remove the selected expert from the list of available experts
                 available_experts = torch.cat(
-                    [available_experts[:chosen_expert_index], available_experts[chosen_expert_index + 1:]])
-                # Update the load of this expert
+                    [available_experts[:chosen_expert_index], available_experts[chosen_expert_index + 1 :]]
+                )
                 expert_load[chosen_expert] += 1
-            tensor[i] = torch.tensor(selected, dtype=torch.int, device='cuda')
+
+            tensor[i] = torch.tensor(selected, dtype=torch.int, device="cuda")
+
         return tensor

-    def get_balance_topk_ids(self, length):
-        if self.balanced_tensors.get(length) is not None:
-            #print("find length ", length)
-            return self.balanced_tensors[length]
-        else:
-            #print("generate length ", length)
-            tensor = self.generate_balanced_tensor(length)
-            self.balanced_tensors[length] = tensor
-            return tensor
+    def get_balance_topk_ids(self, num_tokens):
+        if num_tokens in self.balanced_tensors:
+            return self.balanced_tensors[num_tokens]
+        tensor = self.generate_balanced_tensor(num_tokens)
+        self.balanced_tensors[num_tokens] = tensor
+        return tensor


From 11b6ade57161c157e217713f3a7bf921ea3757cb Mon Sep 17 00:00:00 2001
From: root
Date: Tue, 8 Jul 2025 17:07:53 +0800
Subject: [PATCH 3/3] use faster balance algo

---
 lightllm/utils/balance_utils.py | 32 ++++++++++++++------------------
 1 file changed, 14 insertions(+), 18 deletions(-)

diff --git a/lightllm/utils/balance_utils.py b/lightllm/utils/balance_utils.py
index 61b831d18..556c0cdff 100755
--- a/lightllm/utils/balance_utils.py
+++ b/lightllm/utils/balance_utils.py
@@ -34,31 +34,27 @@ def __init__(self, num_experts=256, num_selected=8):
     def generate_balanced_tensor(self, num_tokens):
         # Evenly distribute num_tokens to num_selected experts out of num_experts.
         # Note that the num_selected experts activated by a token cannot be repeated.
-        # Performance is not that important, as it is only activated in special scenarios.
-        tensor = torch.zeros((num_tokens, self.num_selected), dtype=torch.int, device="cuda")
+        tensor = torch.empty((num_tokens, self.num_selected), dtype=torch.int, device="cuda")
         expert_load = torch.zeros(self.num_experts, dtype=torch.int, device="cuda")

         for i in range(num_tokens):
-            available_experts = torch.arange(self.num_experts, device="cuda")
-            selected = []
-            for _ in range(self.num_selected):
-                current_load = expert_load[available_experts]
-                min_load_indices = torch.where(current_load == current_load.min())[0]
+            selected_mask = torch.zeros(self.num_experts, dtype=torch.bool, device="cuda")
+            for j in range(self.num_selected):
+                # Use a large value for already selected experts to exclude them
+                load_view = torch.where(selected_mask, torch.iinfo(expert_load.dtype).max, expert_load)
+
+                min_load_indices = torch.where(load_view == load_view.min())[0]
+
                 if len(min_load_indices) > 1:
                     # If there are multiple least-loaded experts, select one randomly
-                    chosen_index = torch.randint(0, len(min_load_indices), (1,), device="cuda").item()
-                    chosen_expert_index = min_load_indices[chosen_index]
+                    rand_idx = torch.randint(0, len(min_load_indices), (1,), device="cuda").item()
+                    chosen_expert = min_load_indices[rand_idx]
                 else:
-                    chosen_expert_index = min_load_indices[0]
-                chosen_expert = available_experts[chosen_expert_index]
-                selected.append(chosen_expert)
-                # Remove the selected expert from the list of available experts
-                available_experts = torch.cat(
-                    [available_experts[:chosen_expert_index], available_experts[chosen_expert_index + 1 :]]
-                )
-                expert_load[chosen_expert] += 1
+                    chosen_expert = min_load_indices[0]

-            tensor[i] = torch.tensor(selected, dtype=torch.int, device="cuda")
+                tensor[i, j] = chosen_expert
+                expert_load[chosen_expert] += 1
+                selected_mask[chosen_expert] = True

         return tensor
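
A quick sanity check of the final BalancedTensor behaviour (a minimal sketch, not part of the patches above; it assumes a CUDA device is available, the lightllm tree is importable, and the sizes 256/8/300 are arbitrary example values):

    import torch
    from lightllm.utils.balance_utils import BalancedTensor

    num_experts, top_k, num_tokens = 256, 8, 300
    balanced = BalancedTensor(num_experts=num_experts, num_selected=top_k)
    topk_ids = balanced.get_balance_topk_ids(num_tokens)  # (num_tokens, top_k) int tensor on cuda

    # each token must be routed to top_k distinct experts
    assert topk_ids.shape == (num_tokens, top_k)
    assert all(len(set(row.tolist())) == top_k for row in topk_ids)

    # the greedy least-loaded assignment keeps per-expert load within one token of the minimum
    load = torch.bincount(topk_ids.flatten().to(torch.int64), minlength=num_experts)
    assert int(load.max() - load.min()) <= 1

This is the property the fake-balance mode relies on: with the router output overwritten by these ids, every EP rank should receive an almost identical number of tokens, which makes the prefill/decode logs added in patch 2 easy to interpret.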