
Elastic EP Support (Milestone 1 & 2)#8961

Closed
UNIDY2002 wants to merge 44 commits into sgl-project:main from HanHan009527:hank_mxa_dev

Conversation


@UNIDY2002 UNIDY2002 commented Aug 8, 2025

SUPERSEDED BY #11657

Progress tracker:

Drafts:

Motivation

In the context of MoE models like DeepSeek V3, GPU resources must be dynamically expanded or reduced in response to the request rate. Additionally, the system must maintain correctness even when some GPUs fail, which is common in large-scale inference scenarios. Any changes or recovery in configuration should be completed within an acceptable time frame.

This PR aims to introduce elastic EP support for SGLang, as a part of #8210.

Roadmap

Milestone 1: fault tolerance in the case of masking out a certain rank
In this stage, we simulate a broken GPU by masking out a certain rank during inference. We expect most results (num_healthy_ranks out of num_total_ranks) to be correct. This milestone serves as a proof of concept for fault tolerance in the DeepSeek model, without requiring changes to the SGLang scheduler engine.

Milestone 2: fault tolerance in the case of killing some ranks
In this stage, we plan to enhance SGLang's scheduler to make it aware of underlying faults and dynamically redirect requests to healthy ranks. This requires adjustments to the P2P communication logic within the scheduler and other necessary modifications. This milestone will provide true fault tolerance.

Milestone 3: scaling up and scaling down
In this stage, we plan to introduce the ability to dynamically adjust GPU usage based on the system administrator's configuration, enabling seamless scaling of resources. This will include support for scaling the world size during inference, ensuring that the system remains efficient and responsive.

Modifications (Milestone 1)

Enhance the DeepEP Library to Support Fault Detection and Fault Tolerance

The current implementation of DeepEP assumes that all ranks are healthy and reachable. To handle faulty conditions, we modified the library to detect and manage faults. Specifically, we introduced a time limit for P2P communication during token dispatch, marking a peer as failed if a timeout occurs during the P2P receiving phase. This failure status is propagated to upper-layer modules to support EPLB's needs. To ensure consistency across all ranks, a master rank is elected to monitor the health of all peer ranks, which is critical for the EPLB module to gather statistics from all healthy ranks.
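The timeout-then-mark-failed flow can be sketched as follows. This is an illustrative Python sketch only; the real mechanism lives in the modified DeepEP C++/CUDA layer, and the names P2P_RECV_TIMEOUT_S, recv_ready, and mark_peer_failed are hypothetical.

```python
import time

# Hypothetical timeout value; the actual limit in the modified library is configurable.
P2P_RECV_TIMEOUT_S = 5.0

def recv_tokens_with_timeout(peer_rank, recv_ready, mark_peer_failed):
    """Wait for a dispatch from peer_rank, marking the peer failed on timeout."""
    deadline = time.monotonic() + P2P_RECV_TIMEOUT_S
    while not recv_ready(peer_rank):
        if time.monotonic() > deadline:
            # The failure status is what gets propagated to upper layers (e.g., EPLB).
            mark_peer_failed(peer_rank)
            return None
        time.sleep(0.001)
    return peer_rank  # in the real kernel, the received tokens would be consumed here
```

The key property is that the receiving phase can no longer block forever on a dead peer; it either completes or reports the peer as broken within a bounded time.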

The modified implementation can be found at https://github.com/kvcache-ai/Mooncake/tree/sunxun/ep-dev/mooncake-ep. Enter the mooncake-ep directory and run python setup.py install. (Integration into Mooncake's build workflow is on the way!) Set the environment variable SGLANG_USE_MXA_EP to 1 to enable this feature.

Extend EPLB to Redistribute Expert Weights Upon Fault Detection

When broken ranks are detected, their information is recorded in the global ExpertLocationMetadata with a new field. After each forwarding pass of the DeepseekV2ForCausalLM module, the EPLB manager checks for changes in broken ranks and triggers a rebalance if any new failures are detected.

To make EPLB work under faulty conditions, we made the following three modifications:

  • The EPLB algorithm is adjusted to only map logical experts to healthy ranks.
  • The original all-reduce across all EP ranks is replaced with an all-reduce among only the healthy ranks.
  • The original P2P expert weight swapping is replaced by local weight loading from the disk.
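The first adjustment (mapping logical experts only onto healthy ranks) can be illustrated with a toy sketch. The round-robin placement and the experts_per_rank layout below are simplifying assumptions, not the actual EPLB algorithm, which balances by measured expert load.

```python
# Toy sketch: constrain the logical-to-physical expert mapping to healthy ranks.
def map_experts_to_healthy_ranks(num_logical_experts, num_ranks,
                                 experts_per_rank, broken_ranks):
    healthy = [r for r in range(num_ranks) if r not in broken_ranks]
    assert healthy, "at least one healthy rank is required"
    mapping = {}
    for e in range(num_logical_experts):
        rank = healthy[e % len(healthy)]             # round-robin over healthy ranks only
        slot = (e // len(healthy)) % experts_per_rank
        mapping[e] = rank * experts_per_rank + slot  # global physical expert id
    return mapping
```

Because every logical expert resolves to a physical slot on a healthy rank, dispatch never routes tokens to a broken GPU.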

Mask Out a Certain Rank to Simulate a Broken GPU

We modify the forward method of DeepseekV2ForCausalLM by wrapping it in a conditional mask. After a predefined number of executions, the forward pass of the faulty rank always outputs a zero tensor instead of running the module, simulating a GPU failure during inference. Extending this to handle the breakdown of multiple ranks is straightforward and will be addressed in Milestone 2.

To enable this behavior, set the environment variable SGLANG_EP_AVOID_RANK to the rank ID of the GPU to simulate as broken. Optionally, set SGLANG_AVOID_EP_TRIGGER_AT to specify the number of healthy executions before the breakdown occurs (default is 100).

How to Test

To test this implementation, launch the SGLang decoder backend with the following configurations:

  • Set the environment variable SGLANG_USE_MXA_EP to 1.
  • Set SGLANG_EP_AVOID_RANK to the rank ID of the simulated broken GPU.
  • Set dp_size to the same value as tp_size.
  • Set moe_dense_tp_size to 1.
  • Set ep_num_redundant_experts to a value that can cover the case of one broken rank.
  • Disable CUDA graph.

Then, run the bench.sh script from the benchmark/elastic_ep directory.

You should observe that num_healthy_ranks out of num_total_ranks requests produce successful results.

Use test/srt/test_elastic_ep.py to check accuracy.

  • In an environment with four 8-card H20 machines, the reproduction commands are as follows (prefill server, decode server, then the load balancer):
GLOO_SOCKET_IFNAME=eth0 NCCL_IB_HCA=mlx5_ NCCL_IB_DISABLE=0 NCCL_SOCKET_IFNAME=eth0 NCCL_IB_GID_INDEX=3 \
SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK=1 \
NCCL_MIN_NCHANNELS=24 \
NCCL_IB_QPS_PER_CONNECTION=8 \
SGL_ENABLE_JIT_DEEPGEMM=1 \
python3 -m sglang.launch_server \
--disaggregation-ib-device  "mlx5_1,mlx5_2,mlx5_3,mlx5_4" \
--model-path /data00/DeepSeek-R1-0528 \
--tp 16 --disaggregation-mode prefill  \
--host 0.0.0.0 --port 30300 --trust-remote-code --enable-deepep-moe --deepep-mode normal  --disable-radix-cache  --max-running-requests 16  --moe-dense-tp-size 1 --chunked-prefill-size 0 \
--watchdog-timeout 1000000  \
--enable-dp-attention --dp-size 16 --mem-fraction-static 0.8 \
--show-time-cost  --enable-dp-lm-head --page-size 64 \
--nnodes 2 --node-rank 0 --dist-init-addr  192.168.0.173:5050 


EXPORT_EXPERT_METADATA=1 EXPERT_METADATA_OUTPUT_DIR=/data00/expert_location_metadata SGLANG_AVOID_EP_TRIGGER_AT=100 SGLANG_EP_AVOID_RANK=0 SGLANG_USE_MXA_EP=1 GLOO_SOCKET_IFNAME=eth0 NCCL_IB_HCA=mlx5_ NCCL_IB_DISABLE=0 NCCL_SOCKET_IFNAME=eth0 NCCL_IB_GID_INDEX=3 \
SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK=1 \
NCCL_MIN_NCHANNELS=24 \
NCCL_IB_QPS_PER_CONNECTION=8 \
SGL_ENABLE_JIT_DEEPGEMM=1 \
python3 -m sglang.launch_server \
--model-path /data00/DeepSeek-R1-0528 \
--tp 16 --disaggregation-mode decode  --disaggregation-ib-device  "mlx5_1,mlx5_2,mlx5_3,mlx5_4" \
--host 0.0.0.0 --port 30300 --trust-remote-code  --enable-deepep-moe --deepep-mode low_latency  --disable-radix-cache --mem-fraction-static 0.8 --max-running-requests 1024  --moe-dense-tp-size 1 --disable-cuda-graph  --watchdog-timeout 1000000 \
--context-length 6500 \
--page-size 64 \
--show-time-cost \
--enable-eplb --eplb-rebalance-num-iterations 10000 --expert-distribution-recorder-mode stat --enable-expert-distribution-metrics --ep-num-redundant-experts 32  \
--enable-dp-attention --dp-size 16 --enable-dp-lm-head \
--nnodes 2 --node-rank 0 --dist-init-addr  192.168.0.171:5050 


python3 -m sglang.srt.disaggregation.mini_lb --prefill http://192.168.0.173:30300 --decode http://192.168.0.171:30300 --port 8000 

Limitations

A major limitation is the need to mask out a certain rank to simulate a failure. Ideally, this would involve directly killing a rank, but we have not yet implemented this due to the need for a proper plan to modify the scheduler.

Another limitation is that our current implementation does not support CUDA graph. This is because the conditional masking logic is executed on the CPU, which is incompatible with CUDA graph capture. In Milestone 2, when the simulated rank failure is replaced with real rank termination, the need for conditional masking will be eliminated, resolving this limitation.

Plan for Milestone 2

The plan for Milestone 2 includes the following goals:

  • Enhance the SGLang scheduler to support fault-aware scheduling, including implementing a fault-tolerant version of broadcast_pyobj that uses P2P communication among only the healthy ranks.
  • Expand testing coverage, including:
    • Enabling CUDA graph.
    • Using TP size > 1 for self-attention modules.
    • Simulating failure of more than one rank to verify robustness under multiple faults.
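To make the fault-tolerant broadcast_pyobj idea concrete, here is a toy sketch of a P2P fan-out that skips ranks already known to be broken. The send callback stands in for the real transport; this is not the SGLang implementation, only an illustration of the scheduling intent.

```python
# Toy sketch: the source rank fans the object out via P2P, skipping broken ranks.
def broadcast_pyobj_fault_tolerant(obj, src_rank, world_size, broken_ranks, send):
    delivered = []
    for dst in range(world_size):
        if dst == src_rank or dst in broken_ranks:
            continue  # never block waiting on a rank known to be down
        send(dst, obj)
        delivered.append(dst)
    return delivered
```

The essential change versus a collective broadcast is that delivery to each healthy rank is independent, so one dead peer cannot stall the whole group.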

Open Discussion Points:

  • Scheduler modifications:
    • We need to decide on the best approach to modify the scheduler, including:
      • Which communication backend or protocol to use for fault-aware communication.
      • How to ensure compatibility with other models and existing SGLang workflows when introducing these changes.
  • Future roadmap for mooncake-ep:
    • Outline the next steps in the development of mooncake-ep.
    • Discuss whether and how to merge mooncake-ep into the DeepEP mainline for long-term maintenance and broader adoption.

Checklist


@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @UNIDY2002, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

I've implemented the first milestone of elastic Expert Parallelism (EP) support for SGLang, specifically focusing on fault tolerance for Mixture-of-Experts (MoE) models like DeepSeek V3. This work enables the system to continue functioning correctly even when a GPU is simulated as broken. The core changes involve enhancing the DeepEP library for fault detection, extending the Expert Parallelism Load Balancer (EPLB) to dynamically rebalance experts to healthy ranks, and introducing a method to simulate GPU failures for testing purposes. This ensures that a majority of requests can still be processed successfully despite a simulated hardware fault.

Highlights

  • DeepEP Library Enhancement for Fault Detection: The DeepEP library has been enhanced to detect and manage faults. This includes introducing a timeout for P2P communication during token dispatch, marking a peer as failed if a timeout occurs. This failure status is then propagated to higher-level modules, with a master rank elected to monitor the health of all peer ranks.
  • EPLB Extension for Fault Tolerance: The Expert Parallelism Load Balancer (EPLB) has been extended to redistribute expert weights upon fault detection. When broken ranks are identified, the EPLB algorithm is adjusted to map logical experts only to healthy ranks. The original all-reduce operation is replaced with one among only the healthy ranks, and expert weights are loaded directly from disk.
  • Simulated GPU Failure for Testing: A mechanism has been implemented to simulate a broken GPU by masking out a certain rank. By setting the SGLANG_EP_AVOID_RANK environment variable, a specified rank will output zero tensors after a configurable number of executions (SGLANG_AVOID_EP_TRIGGER_AT), mimicking a GPU failure during inference.
  • New Benchmarking and Testing Utilities: New benchmarking and testing utilities, including bench.sh, gen_eplb.py, q.json, and send_req.py, have been added to facilitate comprehensive testing of the elastic EP features. A new test case, test/srt/test_elastic_ep.py, has also been introduced.
  • Dynamic Expert Location Updates and Weight Reloading: The system now dynamically updates expert location metadata and reloads expert weights from disk when a change in broken nodes is detected. This ensures that the model can adapt to new healthy configurations in real-time, maintaining operational integrity despite simulated failures.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces the first milestone of elastic expert parallelism (EP) support, a crucial feature for fault tolerance in large-scale MoE model inference. The changes are extensive and well-structured, focusing on simulating a single GPU failure by masking a rank and ensuring the system continues to operate with the remaining healthy ranks. Key modifications include enhancing the DeepEP library for fault detection, extending the EPLB to redistribute expert weights considering faulty ranks, and adding logic to simulate failures for testing. The approach of simulating failure via masking as a first step is pragmatic. The code is generally of high quality, but I have identified a few areas for improvement, particularly in the testing and benchmark scripts, to enhance robustness and maintainability.

Comment on lines +20 to +146
    """Send a single request and return detailed results."""
    start_time = time.time()
    try:
        response = requests.post(
            url,
            headers=HEADERS,
            json=payload,
            timeout=300,  # generous timeout
        )
        response.raise_for_status()  # raise HTTPError on non-2xx status codes

        response_time = time.time() - start_time
        result = {
            "request_id": request_id,
            "timestamp": time.time(),
            "status_code": response.status_code,
            "response_time": response_time,
            "success": True,
            "body": response.json(),
            "error": None,
        }
        logging.info(f"Request #{request_id} succeeded in {response_time:.2f}s")

    except requests.exceptions.RequestException as e:
        response_time = time.time() - start_time
        logging.error(f"Request #{request_id} failed after {response_time:.2f}s: {e}")
        result = {
            "request_id": request_id,
            "timestamp": time.time(),
            "status_code": e.response.status_code if e.response else None,
            "response_time": response_time,
            "success": False,
            "body": None,
            "error": str(e),
        }

    return result


def main(args):
    """Run the concurrency test and report results."""
    logging.info(f"Starting concurrency test against {args.url}")
    logging.info(f"Config: concurrency={args.concurrency}, total requests={args.total_requests}")

    start_time = time.time()

    # Load the request payload
    try:
        with open(args.payload_file, "r", encoding="utf-8") as f:
            payload = json.load(f)
    except (FileNotFoundError, json.JSONDecodeError) as e:
        logging.error(f"Failed to load payload file: {e}")
        return

    # Issue concurrent requests via a thread pool
    with ThreadPoolExecutor(max_workers=args.concurrency) as executor:
        futures = [
            executor.submit(send_request, i, args.url, payload)
            for i in range(1, args.total_requests + 1)
        ]
        results = [
            future.result() for future in tqdm(futures, total=args.total_requests)
        ]

    total_time = time.time() - start_time

    # Aggregate statistics
    success_count = sum(1 for r in results if r["success"])
    failed_count = args.total_requests - success_count
    response_times = [r["response_time"] for r in results if r["success"]]
    avg_time = sum(response_times) / success_count if success_count > 0 else 0
    max_time = max(response_times) if response_times else 0
    min_time = min(response_times) if response_times else 0

    # Log the summary
    summary = {
        "total_time_s": round(total_time, 2),
        "total_requests": args.total_requests,
        "success_requests": success_count,
        "failed_requests": failed_count,
        "avg_response_time_s": round(avg_time, 2),
        "max_response_time_s": round(max_time, 2),
        "min_response_time_s": round(min_time, 2),
    }
    logging.info(f"Test finished: {json.dumps(summary, indent=2)}")

    # Save detailed results to a JSON file
    if args.output_file:
        with open(args.output_file, "w", encoding="utf-8") as f:
            json.dump(results, f, ensure_ascii=False, indent=2)
        logging.info(f"Detailed results saved to: {args.output_file}")

    # Log details of any failed requests
    if failed_count > 0:
        logging.warning("Failed requests detected:")
        for r in results:
            if not r["success"]:
                logging.warning(f" - Request #{r['request_id']}: {r['error']}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Concurrent request testing tool")
    parser.add_argument(
        "-u",
        "--url",
        type=str,
        default="http://127.0.0.1:30300/v1/chat/completions",
        help="Target URL",
    )
    parser.add_argument("-c", "--concurrency", type=int, default=10, help="Concurrency level")
    parser.add_argument(
        "-n", "--total-requests", type=int, default=100, help="Total number of requests"
    )
    parser.add_argument(
        "-p",
        "--payload-file",
        type=str,
        required=True,
        help="JSON file containing the request payload",
    )
    parser.add_argument(
        "-o",
        "--output-file",
        type=str,
        default=f"request_results_{time.strftime('%Y%m%d_%H%M%S')}.json",
        help="JSON file to save detailed results",
    )

high

There are a couple of issues in this script:

  1. The docstrings for send_request and main are in Chinese. For consistency with the rest of the codebase, they should be translated to English.
  2. The default value for --output-file is generated using time.strftime at module import time. This means the timestamp in the default filename will be fixed to when the script was first imported, not when it's run. This can lead to overwriting results if the script is called multiple times from another module within the same process. This should be generated within the main function.
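One way to apply the second suggestion (a sketch, not the PR's code) is to default the argument to None and resolve the timestamped filename at run time:

```python
import argparse
import time

# Sketch of the suggested fix: defer timestamp generation to run time by
# defaulting --output-file to None and resolving it inside main().
def build_parser():
    parser = argparse.ArgumentParser(description="Concurrent request testing tool")
    parser.add_argument("-o", "--output-file", type=str, default=None,
                        help="JSON file to save detailed results")
    return parser

def resolve_output_file(output_file):
    # Each invocation gets a fresh timestamp instead of one frozen at import time.
    if output_file is None:
        output_file = f"request_results_{time.strftime('%Y%m%d_%H%M%S')}.json"
    return output_file
```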

Comment on lines +11 to +17
@classmethod
def setUpClass(cls):
pass

@classmethod
def tearDownClass(cls):
pass

high

The test case TestElasticEpGsm8k requires manual setup of the service, as indicated by the TODO and empty setUpClass/tearDownClass methods. This prevents the test from being run automatically in a CI environment, which is critical for ensuring regressions are not introduced. The test should be made self-contained by programmatically starting and stopping the server within the test class.

Comment on lines +173 to +176
fault_tolerant = "--fault-tolerant" in sys.argv
if fault_tolerant:
# remove the flag so prepare_server_args doesn't see it
sys.argv.remove("--fault-tolerant")

medium

The way the --fault-tolerant flag is handled by directly manipulating sys.argv is fragile and can lead to unexpected behavior. A more robust approach would be to use argparse to handle this custom flag, separating it from the arguments intended for prepare_server_args.

Comment on lines +376 to +501
    """Computes a static dispatch map from logical to physical experts, prioritizing remote experts.

    This function creates a dispatch map where each (GPU, logical expert) pair is assigned a
    specific physical expert. The key difference from the default implementation is its preference
    for assigning tasks to experts on different GPUs (remote experts) to potentially improve
    workload distribution across the system, falling back to local experts only when no remote
    options are available.

    1. **Remote-First Assignment**: For each GPU, it identifies all available physical experts
       located on other GPUs. If such experts exist, it selects one with the lowest current
       load to handle the request.
    2. **Load Balancing**: It maintains a load counter for each physical expert to ensure that
       requests are distributed as evenly as possible among the available candidates.
    3. **Local Fallback**: If a logical expert has no physical replicas on other GPUs, the
       algorithm will assign a local expert (from the same GPU) instead.
    4. **Deterministic Tie-Breaking**: The process is made deterministic by using a fixed seed.
       When multiple experts have the same load, shuffling the candidates before selection
       ensures fair tie-breaking.

    Args:
        logical_to_all_physical_map (torch.Tensor): A 3D tensor mapping each logical expert
            to its physical replicas. Shape: `(num_layers, num_logical_experts, num_replicas)`.
        num_gpus (int): The total number of GPUs in the expert parallel group.
        num_physical_experts (int): The total number of physical experts.
        ep_rank (int): The rank of the current process within the expert parallel group.
        seed (int): A seed for the random number generator to ensure deterministic behavior.

    Returns:
        torch.Tensor: A 2D tensor for the current `ep_rank` that maps each logical expert
            to a physical expert. Shape: `(num_layers, num_logical_experts)`.
    """
    r = random.Random(seed)

    # Number of physical experts hosted on each GPU
    num_local_physical_experts = num_physical_experts // num_gpus
    # Dimensions of the mapping table
    num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape
    dtype = logical_to_all_physical_map.dtype

    # Tensor holding the final dispatch map, initialized with -1
    logical_to_rank_dispatch_physical_map = torch.full(
        size=(num_gpus, num_layers, num_logical_experts),
        fill_value=-1,
        dtype=dtype,
    )

    # Iterate over every layer
    for layer_id in range(num_layers):
        # Iterate over each logical expert in the layer
        for logical_expert_id in range(num_logical_experts):
            # All physical replica ids of the current logical expert
            candidate_physical_expert_ids = _logical_to_all_physical_raw(
                logical_to_all_physical_map, layer_id, logical_expert_id
            )
            # View of the dispatch map for this logical expert across all GPUs
            output_partial = logical_to_rank_dispatch_physical_map[
                :, layer_id, logical_expert_id
            ]

            # Initialize a load counter for each physical expert
            load = {p_id: 0 for p_id in candidate_physical_expert_ids}

            # Assign one expert to each GPU
            for gpu_id in range(num_gpus):
                # --- Remote-first selection ---
                # Physical experts not hosted on the current GPU (remote experts)
                remote_experts = [
                    p_id
                    for p_id in candidate_physical_expert_ids
                    if _compute_gpu_id_of_physical_expert(
                        p_id, num_local_physical_experts
                    )
                    != gpu_id
                ]

                # Prefer remote experts; otherwise fall back to all candidates (local fallback)
                if remote_experts:
                    experts_to_choose_from = remote_experts
                else:
                    experts_to_choose_from = candidate_physical_expert_ids

                # Shuffle candidates to break ties fairly when loads are equal
                r.shuffle(experts_to_choose_from)

                # --- Load-balanced selection ---
                # Pick the candidate with the lowest current load
                chosen_expert = min(experts_to_choose_from, key=lambda p_id: load[p_id])

                # Assign the chosen expert to the current GPU
                output_partial[gpu_id] = chosen_expert
                # Update the chosen expert's load counter
                load[chosen_expert] += 1

    # Ensure every entry has been assigned
    assert torch.all(logical_to_rank_dispatch_physical_map != -1)

    # Return the dispatch map for the current ep_rank on the original device
    device = logical_to_all_physical_map.device
    return logical_to_rank_dispatch_physical_map[ep_rank, :, :].to(device)


def compute_logical_to_rank_dispatch_physical_map_avoid_rank(
    logical_to_all_physical_map: torch.Tensor,
    num_gpus: int,
    num_physical_experts: int,
    ep_rank: int,
    avoid_rank: int,
    seed: int = 42,
):
    """Computes a static dispatch map that avoids scheduling onto a specific rank.

    This function creates a dispatch map where each (GPU, logical expert) pair is assigned a
    specific physical expert while avoiding experts located on `avoid_rank`. If no experts
    remain after filtering, it falls back to choosing among all available experts.

    1. **Rank Avoidance**: For each logical expert, all physical experts located on
       `avoid_rank` are filtered out first.
    2. **Load Balancing**: Among the remaining experts, a load counter ensures that requests
       are distributed evenly across the available candidates.
    3. **Fallback**: If no experts remain after filtering (e.g., all replicas live on
       `avoid_rank`), the algorithm falls back to selecting among all candidates, including
       those on `avoid_rank`, so that every assignment succeeds.
    4. **Determinism**: A fixed random seed keeps the entire process deterministic.

    Args:
        logical_to_all_physical_map (torch.Tensor): A 3D tensor mapping each logical expert
            to its physical replicas. Shape: `(num_layers, num_logical_experts, num_replicas)`.
        num_gpus (int): The total number of GPUs in the expert parallel group.
        num_physical_experts (int): The total number of physical experts.
        ep_rank (int): The rank of the current process within the expert parallel group.
        avoid_rank (int): The GPU rank to avoid when dispatching.
        seed (int): A seed for the random number generator to ensure deterministic behavior.

    Returns:
        torch.Tensor: A 2D tensor for the current `ep_rank` that maps each logical expert
            to a physical expert. Shape: `(num_layers, num_logical_experts)`.
    """

medium

The docstrings for the new functions compute_logical_to_rank_dispatch_physical_map_remote_first and compute_logical_to_rank_dispatch_physical_map_avoid_rank are in Chinese. To maintain consistency across the codebase and ensure they are understandable to all contributors, please translate them to English.

Comment on lines +1652 to +1657
gen = self.eplb_manager.rebalance()
while True:
try:
next(gen)
except StopIteration:
break

medium

The while True loop with a try/except StopIteration block to exhaust the generator is a bit unconventional and less readable than the standard Python idiom. A simple for loop would be more Pythonic and achieve the same result with better clarity.

Suggested change
gen = self.eplb_manager.rebalance()
while True:
try:
next(gen)
except StopIteration:
break
for _ in self.eplb_manager.rebalance():
pass

@UNIDY2002
Contributor Author

We are planning on restructuring the PR with a clearer abstraction and cleaner modifications. Stay tuned :)

@UNIDY2002 UNIDY2002 marked this pull request as draft August 12, 2025 15:56
@UNIDY2002 UNIDY2002 marked this pull request as ready for review September 9, 2025 02:20
@UNIDY2002 UNIDY2002 marked this pull request as draft September 9, 2025 02:21
@UNIDY2002
Contributor Author

Now we are halfway towards Milestone 2.

Highlights

  • A new distributed backend (Mooncake) that supports fault tolerance during collective communications.
  • An adapted scheduler that keeps inference running even when some ranks fail.
  • An EPLB module that stays consistent during rank failures.

Next Steps

The immediate focus is on code cleanup, turning the new features into configurable options, and splitting the work into several smaller, mergeable PRs:

  • PR 1/N: Introduce Mooncake backend
    • Add an option to opt into the Mooncake distributed backend as a replacement for NCCL/Gloo.
    • The main change is specifying backend="mooncake" when calling torch.distributed.new_group; the rest of the project remains largely untouched (non-intrusive).
    • Affected: parallel_state.py
  • PR 2/N: Introduce Mooncake EP
    • Add token_dispatcher/mooncake.py to wrap Mooncake EP, allowing users to opt into the Mooncake EP token dispatcher.
    • Affected: token_dispatcher/, expert_location.py
  • PR 3/N: Adapt EPLB algorithm for fault conditions
    • Update EPLB to avoid assigning experts to faulty ranks, using the broken-ranks status captured by Mooncake.
    • Affected: eplb_algorithms/
  • PR 4/N: Redistribute experts upon rank failures
    • On rank failure, trigger expert load balancing so that every expert has at least one replica on a healthy rank.
    • After this PR, killing a rank should not interrupt inference, and most requests should still produce correct results.
    • Affected: model_runner.py, scheduler.py
  • PR 5/N: Retry/re-dispatch in the scheduler
    • Make the scheduler aware of request failures caused by broken ranks, with automatic retries and rerouting of follow-up requests to healthy ranks.
    • Affected: scheduler.py, data_parallel_controller.py

We now have a preliminary proof-of-concept fault-tolerance implementation ready.
If this PR breakdown looks good, we will start the merging process in the coming weeks.

cc @zhyncs @fzyzcjy @ch-wan @james0zan

@UNIDY2002 UNIDY2002 changed the title Elastic EP Support (Milestone 1) Elastic EP Support (Milestone 1 & 2) Sep 9, 2025
HanHan009527 and others added 4 commits October 16, 2025 01:00
@UNIDY2002
Contributor Author

This draft PR is superseded by #11657
