
Commit bc53e6e

kaiyux authored and lancelly committed
doc: Add README for wide EP (NVIDIA#6356)
Signed-off-by: Kaiyu Xie <[email protected]>
Signed-off-by: Lanyu Liao <[email protected]>
1 parent 1e612fe commit bc53e6e

File tree

5 files changed (+106, -28 lines)


docs/source/blogs/tech_blog/blog4_Scaling_Expert_Parallelism_in_TensorRT-LLM.md

Lines changed: 5 additions & 3 deletions
@@ -3,7 +3,7 @@
 By NVIDIA TensorRT-LLM Team
 
 ## Table of Contents
-- [Scaling Expert Parallelism in TensorRT-LLM (Part 1: Design and Implementation of Large-scale EP)](#scaling-expert-parallelism-in-tensorrt-llmpart-1-design-and-implementation-of-large-scale-ep)
+- [Scaling Expert Parallelism in TensorRT-LLM (Part 1: Design and Implementation of Large-scale EP)](#scaling-expert-parallelism-in-tensorrt-llm-part-1-design-and-implementation-of-large-scale-ep)
 - [Table of Contents](#table-of-contents)
 - [Motivation for large-scale EP](#motivation-for-large-scale-ep)
 - [Observations over one machine translation dataset](#observations-over-one-machine-translation-dataset)
@@ -15,8 +15,8 @@ By NVIDIA TensorRT-LLM Team
 - [EP Load Balancer](#ep-load-balancer)
 - [Python Interface](#python-interface)
 - [C++ extension](#c-extension)
-- [Core implementations of host side logics](#core-implementations-of-host-side-logics)
-- [Core implementations of GPU side logics](#core-implementations-of-gpu-side-logics)
+- [Core implementations of the host logic](#core-implementations-of-the-host-logic)
+- [Core implementations of the GPU logic](#core-implementations-of-the-gpu-logic)
 - [Online EP Load Balancer](#online-ep-load-balancer)
 - [Offline EP Load Balancer](#offline-ep-load-balancer)
 - [E2E evaluation](#e2e-evaluation)
@@ -516,7 +516,9 @@ Clearly in Figure 25, we can see that EPLB brings a clear performance improvemen
 
 ## Reproducing steps
 Currently to run through the reproducing steps described in this section, please, use this [feature branch](https://github.com/NVIDIA/TensorRT-LLM/tree/feat/large-ep/tensorrt_llm). It will get merged to the main branch soon.
+
 ### The effect of EP Load Balancer
+
 Please, refer to the [EP Load Balancer example](https://github.com/NVIDIA/TensorRT-LLM/tree/feat/large-ep/examples/ep_load_balancer) for how to reproduce the results for the offline EP Load Balancer.
 
 ##### Step 1: Run inference and collect statistics

examples/disaggregated/slurm/gen_yaml.py

Lines changed: 13 additions & 7 deletions
@@ -173,14 +173,16 @@ def gen_config_file(config_path: str,
 'max_batch_size': ctx_batch_size,
 'max_num_tokens': ctx_max_num_tokens,
 'max_seq_len': 1152,
-'free_gpu_memory_fraction': 0.85,
 'tensor_parallel_size': ctx_tp_size,
 'moe_expert_parallel_size': ctx_tp_size,
 'enable_attention_dp': ctx_enable_attention_dp,
 'pipeline_parallel_size': 1,
 'print_iter_log': True,
 'disable_overlap_scheduler': True,
-'kv_cache_dtype': 'fp8',
+'kv_cache_config': {
+    'free_gpu_memory_fraction': 0.85,
+    'dtype': 'fp8',
+},
 'cache_transceiver_config': {
     'backend': 'default',
     'max_tokens_in_buffer': 8320,
@@ -195,14 +197,18 @@ def gen_config_file(config_path: str,
 'max_batch_size': gen_batch_size,
 'max_num_tokens': gen_max_num_tokens,
 'max_seq_len': 2176,
-'free_gpu_memory_fraction': gen_gpu_memory_fraction,
 'cuda_graph_config': {
     'enable_padding': True,
     'batch_sizes': gen_cuda_graph_batch_sizes,
 },
 'print_iter_log': True,
-'kv_cache_dtype': 'fp8',
-'moe_backend': gen_moe_backend,
+'kv_cache_config': {
+    'free_gpu_memory_fraction': gen_gpu_memory_fraction,
+    'dtype': 'fp8',
+},
+'moe_config': {
+    'backend': gen_moe_backend,
+},
 'cache_transceiver_config': {
     'backend': 'default',
     'max_tokens_in_buffer': 8320,
@@ -242,8 +248,8 @@ def gen_config_file(config_path: str,
 f,
 default_flow_style=False,
 sort_keys=False)
-config['generation_servers'][
-    'moe_load_balancer'] = moe_load_balancer_file
+config['generation_servers']['moe_config'][
+    'load_balancer'] = moe_load_balancer_file
 
 if mtp_size > 0:
     config['context_servers']['speculative_config'] = {
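
For readers skimming the diff, the sketch below (not part of the commit; values are illustrative) shows the nested layout these changes produce for a generation server: the former flat keys `free_gpu_memory_fraction`, `kv_cache_dtype`, and `moe_backend` now live under `kv_cache_config` and `moe_config`.

```python
# Illustrative sketch only; values are examples, not taken from the script.
import yaml

generation_server = {
    'kv_cache_config': {
        'free_gpu_memory_fraction': 0.7,  # replaces the flat free_gpu_memory_fraction
        'dtype': 'fp8',                   # replaces kv_cache_dtype
    },
    'moe_config': {
        'backend': 'WIDEEP',              # replaces moe_backend
        # when a load balancer file is supplied, the script attaches it as
        # config['generation_servers']['moe_config']['load_balancer']
    },
}
print(yaml.dump({'generation_servers': generation_server}, sort_keys=False))
```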

examples/disaggregated/slurm/submit.sh

Lines changed: 1 addition & 3 deletions
@@ -1,8 +1,6 @@
 #!/bin/bash
 
-# !!!
-# Make sure that SLURM parameters are correctly set in `disaggr_torch.slurm` before executing this script.
-# !!!
+echo "Make sure that SLURM parameters are correctly set in \`disaggr_torch.slurm\` before executing this script."
 
 # concurrency 8
 concurrency=8

examples/wide_ep/README.md

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
# Wide Expert Parallelism (Wide-EP) in TensorRT-LLM

TensorRT-LLM's Wide Expert Parallelism (Wide-EP) feature enables efficient inference of large-scale Mixture-of-Experts (MoE) models by scaling expert parallelism beyond traditional limits. This feature addresses the inherent workload imbalance challenges in large-scale MoE models and provides both offline and online load balancing capabilities.

## Overview

Large-scale MoE models like DeepSeek-V3/R1, LLaMA4, and Qwen3 use fine-grained expert designs that introduce new challenges for inference systems:

- **High memory demands** for expert weights
- **Inherent expert-level workload imbalance** due to sparse execution patterns
- **Communication overhead** in distributed expert parallelism

Wide-EP addresses these challenges through:

- **Custom EP communication kernels** optimized for NVIDIA GB200 Multi-Node NVLink (MNNVL)
- **Expert Parallelism Load Balancer (EPLB)** with both offline and online modes
- **Dynamic expert placement and replication** strategies
- **Layer-wise weight redistribution** to minimize inference disruption

## Quick Start
### 1. Configurations

An example YAML file to enable Wide-EP:
```yaml
moe_config:
    backend: WIDEEP
    max_num_tokens: 9216
    load_balancer: moe_load_balancer.yaml # (optional) enable load balancer
```

| Parameter | Description | Default | Notes |
|-----------|-------------|---------|-------|
| `backend` | MoE backend type | `CUTLASS` | Set to `WIDEEP` to enable Wide-EP |
| `max_num_tokens` | If set, at most `max_num_tokens` tokens are sent to `torch.ops.trtllm.fused_moe` at a time | `None` | Inputs with more than `max_num_tokens` tokens are split into chunks and processed in a loop (see the sketch below) |
| `load_balancer` | Configuration for MoE load balancing | `None` | Set to the path of the load balancer YAML file |
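
To make the `max_num_tokens` chunking behavior concrete, here is a minimal sketch (not the TensorRT-LLM implementation); `fused_moe` is passed in as a stand-in for `torch.ops.trtllm.fused_moe`, and the helper name is hypothetical.

```python
# Hypothetical sketch of the chunking described above: inputs longer than
# `max_num_tokens` are split along the token dimension and the fused MoE op
# runs once per chunk.
import torch

def chunked_moe(hidden_states: torch.Tensor, fused_moe, max_num_tokens: int = 9216) -> torch.Tensor:
    if max_num_tokens is None or hidden_states.shape[0] <= max_num_tokens:
        return fused_moe(hidden_states)
    chunks = torch.split(hidden_states, max_num_tokens, dim=0)
    return torch.cat([fused_moe(chunk) for chunk in chunks], dim=0)
```
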
#### Load Balancer Configuration

An example `moe_load_balancer.yaml` file to configure the online EP load balancer:
```yaml
num_slots: 288
layer_updates_per_iter: 1
```

| Parameter | Description | Default | Notes |
|-----------|-------------|---------|-------|
| `num_slots` | Total number of expert slots | `None` | Must be ≥ the total number of experts |
| `layer_updates_per_iter` | Number of layers updated per iteration | `0` | `0` = offline, `>0` = online |

Refer to the [ep_load_balancer](./ep_load_balancer/) directory for more details on the EP load balancer; a quick sanity check of these numbers is sketched below.
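
As a worked example of the `num_slots` constraint, the arithmetic below assumes DeepSeek-R1's 256 routed experts and the 16-way expert parallelism used in the slurm_scripts examples; the variable names are illustrative.

```python
# Illustrative arithmetic only: 256 routed experts replicated into 288 slots
# across 16 EP ranks gives 18 slots per GPU and 32 redundant slots available
# for replicating hot experts.
num_experts = 256
num_slots = 288            # must be >= num_experts
ep_size = 16

assert num_slots >= num_experts and num_slots % ep_size == 0
slots_per_rank = num_slots // ep_size       # 18
redundant_slots = num_slots - num_experts   # 32
print(f"{slots_per_rank} slots per rank, {redundant_slots} redundant slots")
```
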
### 2. Execute Wide-EP on SLURM Clusters

Refer to the [slurm_scripts](./slurm_scripts/) directory, which reuses the [disaggregated slurm scripts](../disaggregated/slurm/) to automatically generate configuration files and submit jobs to SLURM clusters.

## Troubleshooting

### Transparent HugePages failure

If you encounter the exception `madvise(MADV_HUGEPAGE) failed.`, check whether Transparent HugePages is enabled.
```bash
$ cat /sys/kernel/mm/transparent_hugepage/enabled
always [madvise] never
$ cat /sys/kernel/mm/transparent_hugepage/defrag
always defer defer+madvise [madvise] never
```
If `never` is highlighted, enable Transparent HugePages with the following command.
```bash
echo madvise > /sys/kernel/mm/transparent_hugepage/enabled
```

### Disaggregated serving issues

Refer to the [Troubleshooting and FAQ](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/advanced/disaggregated-service.md#troubleshooting-and-faq) section of the Disaggregated Service documentation.

## References

- [Technical Blog: Scaling Expert Parallelism in TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/blogs/tech_blog/blog4_Scaling_Expert_Parallelism_in_TensorRT-LLM.md)

For detailed implementation examples and advanced usage, see the subdirectories:
- [`ep_load_balancer/`](ep_load_balancer/): Load balancing tools and examples
- [`slurm_scripts/`](slurm_scripts/): Cluster deployment scripts

examples/wide_ep/slurm_scripts/submit.sh

Lines changed: 4 additions & 15 deletions
@@ -1,30 +1,19 @@
 #!/bin/bash
 
-# !!!
-# Please find the `disaggr_torch.slurm` script in the `examples/disaggregated/slurm/` directory.
-# Make sure that SLURM parameters are correctly set in `disaggr_torch.slurm` before executing this script.
-# !!!
+echo "Please find the \`disaggr_torch.slurm\` script in the \`examples/disaggregated/slurm/\` directory."
+echo "Make sure that SLURM parameters are correctly set in \`disaggr_torch.slurm\` before executing this script."
 
 mtp_size=0
 ntasks_per_node=4 # 4 GPUs per GB200 node
 
-# dep8
-for b in 1 64 1024; do
-    concurrency=$((b * 8))
-    ctx_num=$(((concurrency + 5499)/5500))
-    total_node_num=$((ctx_num + 2))
-    ntasks=$((total_node_num * ntasks_per_node))
-    sbatch --nodes=${total_node_num} --ntasks=${ntasks} --ntasks-per-node=${ntasks_per_node} --segment=${total_node_num} disaggr_torch.slurm ${ctx_num} 4 4 4480 true 1 8 1024 1024 true "0.8" 0 "$mtp_size" "$concurrency"
-done
-
 # dep16 eplb0, 256, 288
 for b in 1 64 1024; do
     concurrency=$((b * 16))
     ctx_num=$(((concurrency + 5499)/5500))
     total_node_num=$((ctx_num + 4))
     ntasks=$((total_node_num * ntasks_per_node))
-    sbatch --nodes=${total_node_num} --ntasks=${ntasks} --ntasks-per-node=${ntasks_per_node} --segment=${total_node_num} disaggr_torch.slurm ${ctx_num} 4 4 4480 true 1 16 1024 1024 true "0.7" 0 "$mtp_size" "$concurrency"
-    sbatch --nodes=${total_node_num} --ntasks=${ntasks} --ntasks-per-node=${ntasks_per_node} --segment=${total_node_num} disaggr_torch.slurm ${ctx_num} 4 4 4480 true 1 16 1024 1024 true "0.7" 256 "$mtp_size" "$concurrency"
+    # sbatch --nodes=${total_node_num} --ntasks=${ntasks} --ntasks-per-node=${ntasks_per_node} --segment=${total_node_num} disaggr_torch.slurm ${ctx_num} 4 4 4480 true 1 16 1024 1024 true "0.7" 0 "$mtp_size" "$concurrency"
+    # sbatch --nodes=${total_node_num} --ntasks=${ntasks} --ntasks-per-node=${ntasks_per_node} --segment=${total_node_num} disaggr_torch.slurm ${ctx_num} 4 4 4480 true 1 16 1024 1024 true "0.7" 256 "$mtp_size" "$concurrency"
     sbatch --nodes=${total_node_num} --ntasks=${ntasks} --ntasks-per-node=${ntasks_per_node} --segment=${total_node_num} disaggr_torch.slurm ${ctx_num} 4 4 4480 true 1 16 1024 1024 true "0.7" 288 "$mtp_size" "$concurrency"
 done
 
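
For clarity, the loop's node-count arithmetic can be re-derived in a few lines of Python; this is only a sketch mirroring the dep16 case above (the interpretation of the extra 4 nodes as generation nodes is an assumption), nothing here is part of the scripts.

```python
# Mirrors the bash arithmetic above for the dep16 sweep: ctx_num is a ceiling
# division of the target concurrency by 5500, plus 4 additional nodes
# (assumed to host the dep16 generation workers at 4 GPUs per GB200 node).
ntasks_per_node = 4
for b in (1, 64, 1024):
    concurrency = b * 16
    ctx_num = (concurrency + 5499) // 5500     # ceiling(concurrency / 5500)
    total_node_num = ctx_num + 4
    ntasks = total_node_num * ntasks_per_node
    print(b, concurrency, ctx_num, total_node_num, ntasks)
```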
