diff --git a/lightllm/common/fused_moe/grouped_fused_moe_ep.py b/lightllm/common/fused_moe/grouped_fused_moe_ep.py
index 3b5cc6b91..665c3e324 100644
--- a/lightllm/common/fused_moe/grouped_fused_moe_ep.py
+++ b/lightllm/common/fused_moe/grouped_fused_moe_ep.py
@@ -5,6 +5,7 @@ import triton.language as tl
 from typing import Any, Callable, Dict, Optional, Tuple
 import torch.distributed as dist
+from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.utils.log_utils import init_logger
 from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd
 from lightllm.common.fused_moe.moe_silu_and_mul_mix_quant_ep import silu_and_mul_masked_post_quant_fwd
@@ -142,6 +143,15 @@ def fused_experts_impl(
         # scatter
         all_tokens = sum(num_recv_tokens_per_expert_list)  # calcu padding all nums.
+
+        if get_env_start_args().enable_ep_fake_balance:
+            rank = dist.get_rank()
+            if rank == 0:
+                logger.info(
+                    f"prefill, [{rank}], all_tokens = {all_tokens}, "
+                    f"num_recv_tokens_per_expert_list: {num_recv_tokens_per_expert_list}"
+                )
+
         # gather_out shape [recive_num_tokens, hidden]
         gather_out = torch.empty_like(recv_x[0], device=hidden_states.device, dtype=hidden_states.dtype)
         if all_tokens > 0:
@@ -219,6 +229,18 @@ def fused_experts_impl(
             async_finish=False,
             return_recv_hook=False,
         )
+
+        # NOTE: the logger cannot be called while the decode CUDA graph is enabled, so this only runs with --disable_cudagraph
+        args = get_env_start_args()
+        if args.enable_ep_fake_balance and args.disable_cudagraph:
+            rank = dist.get_rank()
+            all_tokens = sum(masked_m)
+            if rank == 0:
+                logger.info(
+                    f"decode, [{rank}], all_tokens = {all_tokens}, "
+                    f"expected_m = {expected_m}, num_recv_tokens_per_expert: {masked_m}"
+                )
+
         # deepgemm
         gemm_out_b = masked_group_gemm(recv_x, masked_m, hidden_states.dtype, w1, w1_scale, w2, w2_scale, expected_m)
         # low latency combine
diff --git a/lightllm/common/fused_moe/topk_select.py b/lightllm/common/fused_moe/topk_select.py
index ca8d22f48..3f639aa4f 100644
--- a/lightllm/common/fused_moe/topk_select.py
+++ b/lightllm/common/fused_moe/topk_select.py
@@ -21,6 +21,8 @@ import torch
 from lightllm.utils.sgl_utils import sgl_ops
 from lightllm.utils.light_utils import light_ops
+from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.utils.balance_utils import BalancedTensor
 from typing import Callable, List, Optional, Tuple
 from lightllm.common.fused_moe.softmax_topk import softmax_topk
@@ -227,4 +229,11 @@ def select_experts(
             hidden_states=hidden_states, gating_output=router_logits, topk=top_k, renormalize=renormalize
         )
 
+    # EP fake balance: override the routed expert ids with an evenly balanced assignment
+    if get_env_start_args().enable_ep_fake_balance:
+        num_tokens, num_experts = router_logits.shape
+        balanced_tensor_collection = BalancedTensor(num_experts=num_experts, num_selected=top_k)
+        balance_topk_ids = balanced_tensor_collection.get_balance_topk_ids(num_tokens)
+        topk_ids.copy_(balance_topk_ids)
+
     return topk_weights, topk_ids
diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index 3f3eaf96f..4d48af10f 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -333,6 +333,9 @@ def make_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--enable_monitor_auth", action="store_true", help="Whether to open authentication for push_gateway"
    )
+
+    parser.add_argument("--enable_ep_fake_balance", action="store_true", help="Enable fake balanced expert selection in EP mode")
+    parser.add_argument("--disable_cudagraph", action="store_true", help="Disable cuda graph for the decoding stage")
     parser.add_argument(
diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py
index ec1eb427e..d7a5208e4 100644
--- a/lightllm/server/core/objs/start_args_type.py
+++ b/lightllm/server/core/objs/start_args_type.py
@@ -76,6 +76,7 @@ class StartArgs:
     visual_dp: int = field(default=1)
     visual_nccl_ports: List[int] = field(default_factory=lambda: [29500])
     enable_monitor_auth: bool = field(default=False)
+    enable_ep_fake_balance: bool = field(default=False)
     disable_cudagraph: bool = field(default=False)
     graph_max_batch_size: int = field(default=256)
     graph_split_batch_size: int = field(default=32)
diff --git a/lightllm/utils/balance_utils.py b/lightllm/utils/balance_utils.py
new file mode 100755
index 000000000..556c0cdff
--- /dev/null
+++ b/lightllm/utils/balance_utils.py
@@ -0,0 +1,67 @@
+import torch
+import os
+
+import threading
+
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+
+def singleton_threadsafe(cls):
+    instances = {}
+    lock = threading.Lock()
+
+    def get_instance(*args, **kwargs):
+        # A key that includes the arguments is needed for parameter-dependent singletons.
+        # Using a tuple of args and a frozenset of kwargs items makes it hashable.
+        key = (cls, args, frozenset(kwargs.items()))
+        with lock:
+            if key not in instances:
+                instances[key] = cls(*args, **kwargs)
+        return instances[key]
+
+    return get_instance
+
+
+@singleton_threadsafe
+class BalancedTensor:
+    def __init__(self, num_experts=256, num_selected=8):
+        self.balanced_tensors = {}
+        self.num_experts = num_experts
+        self.num_selected = num_selected
+
+    def generate_balanced_tensor(self, num_tokens):
+        # Evenly distribute num_tokens tokens across num_experts experts, num_selected experts per token.
+        # The num_selected experts activated by a single token must all be distinct.
+        tensor = torch.empty((num_tokens, self.num_selected), dtype=torch.int, device="cuda")
+        expert_load = torch.zeros(self.num_experts, dtype=torch.int, device="cuda")
+
+        for i in range(num_tokens):
+            selected_mask = torch.zeros(self.num_experts, dtype=torch.bool, device="cuda")
+            for j in range(self.num_selected):
+                # Use a large value for already selected experts to exclude them
+                load_view = torch.where(selected_mask, torch.iinfo(expert_load.dtype).max, expert_load)
+
+                min_load_indices = torch.where(load_view == load_view.min())[0]
+
+                if len(min_load_indices) > 1:
+                    # If there are multiple least-loaded experts, select one randomly
+                    rand_idx = torch.randint(0, len(min_load_indices), (1,), device="cuda").item()
+                    chosen_expert = min_load_indices[rand_idx]
+                else:
+                    chosen_expert = min_load_indices[0]
+
+                tensor[i, j] = chosen_expert
+                expert_load[chosen_expert] += 1
+                selected_mask[chosen_expert] = True
+
+        return tensor
+
+    def get_balance_topk_ids(self, num_tokens):
+        if num_tokens in self.balanced_tensors:
+            return self.balanced_tensors[num_tokens]
+
+        tensor = self.generate_balanced_tensor(num_tokens)
+        self.balanced_tensors[num_tokens] = tensor
+        return tensor
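
Usage sketch (illustrative only, not part of the patch): the snippet below exercises the new BalancedTensor helper from lightllm/utils/balance_utils.py in isolation. It assumes lightllm is importable and a CUDA device is present, since the helper allocates its tensors on "cuda"; the variable names and the bincount check are introduced purely for this example.

# Quick sanity check of the balanced expert assignment.
import torch
from lightllm.utils.balance_utils import BalancedTensor

balancer = BalancedTensor(num_experts=256, num_selected=8)
topk_ids = balancer.get_balance_topk_ids(num_tokens=64)  # shape [64, 8], dtype int32

# Each row must contain 8 distinct expert ids (sorted rows strictly increase).
sorted_ids = topk_ids.sort(dim=1).values
assert (sorted_ids[:, 1:] > sorted_ids[:, :-1]).all()

# The greedy least-loaded selection keeps the per-expert load close to uniform;
# with 64 * 8 = 512 selections over 256 experts the expected load is 2 per expert.
load = torch.bincount(topk_ids.flatten().long(), minlength=256)
print(f"per-expert load: min={load.min().item()}, max={load.max().item()}")

# Results are cached per num_tokens, and @singleton_threadsafe returns the same
# instance for identical constructor arguments, so repeated calls are cheap.
assert balancer.get_balance_topk_ids(64) is topk_ids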