Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
bbe48fa
[None][feat] Add DWDP (Distributed Weight Data Parallelism) support f…
tianyuz-nv Mar 19, 2026
7e4a855
Merge upstream/main into dwdp_productization
tianyuz-nv Mar 23, 2026
f76d026
Register DWDP accuracy test in CI test lists
tianyuz-nv Mar 23, 2026
f153099
Fix CI: remove forbidden from_dict, init helper, apply pre-commit for…
tianyuz-nv Mar 23, 2026
38d6e92
Fix CI and add env var config for disaggregated benchmark
tianyuz-nv Mar 24, 2026
d52a3cf
Merge remote-tracking branch 'upstream/main' into dwdp_productization
tianyuz-nv Mar 24, 2026
f52716d
Remove unused helper variable to fix ruff F841
tianyuz-nv Mar 24, 2026
64a7c7d
Merge branch 'main' into dwdp_productization
Kefeng-Duan Mar 28, 2026
47618f3
Merge branch 'main' into dwdp_productization
Kefeng-Duan Mar 29, 2026
b10dc89
Improve DwdpConfig docstring per review feedback
tianyuz-nv Mar 30, 2026
7040370
Decouple DWDP from mainline disagg scripts per reviewer feedback
tianyuz-nv Mar 30, 2026
28390e5
fix(moe): remove commented-out barrier in moeA2AInitializeOp
tianyuz-nv Mar 30, 2026
a4085ea
fix(dwdp): move record_compute_and_prefetch_next to per-layer level
tianyuz-nv Mar 30, 2026
a8a0a8f
style: fix pre-commit formatting for DWDP files
tianyuz-nv Mar 30, 2026
26e3e9f
Merge remote-tracking branch 'upstream/main' into dwdp_productization
tianyuz-nv Mar 31, 2026
85def53
refactor(dwdp): address reviewer feedback on DWDP config and lifecycle
tianyuz-nv Mar 31, 2026
afcd735
test(api_stability): align llm.yaml dwdp_config with Optional + proto…
tianyuz-nv Mar 31, 2026
be12482
Merge remote-tracking branch 'upstream/main' into dwdp_productization
tianyuz-nv Apr 1, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp
Comment thread
tianyuz-nv marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ torch::Tensor moeA2AInitializeOp(torch::Tensor const& workspace, int64_t epRank,

// Synchronize among ranks
cudaDeviceSynchronize();
tensorrt_llm::mpi::MpiComm::world().barrier();
// tensorrt_llm::mpi::MpiComm::world().barrier();
Comment thread
tianyuz-nv marked this conversation as resolved.
Outdated
tensorrt_llm::mpi::MpiComm::session().barrier();

return metainfo;
}
Expand Down
5 changes: 5 additions & 0 deletions examples/disaggregated/slurm/benchmark/disaggr_torch.slurm
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,11 @@ replace_placeholder "${start_server_cmds_base_file}" "${all_nodes_str}" "${start
server_config_base_file=${full_logdir}/server_config_base.yaml
server_config_file=${full_logdir}/server_config.yaml
replace_placeholder "${server_config_base_file}" "${all_nodes_str}" "${server_config_file}"
mpi_worker_config_base_file=${full_logdir}/mpi_worker_config_base.yaml
Comment thread
tianyuz-nv marked this conversation as resolved.
Outdated
mpi_worker_config_file=${full_logdir}/mpi_worker_config.yaml
if [ -f "${mpi_worker_config_base_file}" ]; then
replace_placeholder "${mpi_worker_config_base_file}" "${all_nodes_str}" "${mpi_worker_config_file}"
fi
client_cmds_base_file=${full_logdir}/client_cmds_base.sh
client_cmds_file=${full_logdir}/client_cmds.sh
replace_placeholder "${client_cmds_base_file}" "${all_nodes_str}" "${client_cmds_file}"
Expand Down
61 changes: 61 additions & 0 deletions examples/disaggregated/slurm/benchmark/start_worker_dwdp.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#! /bin/bash
# Start `trtllm-serve disaggregated_mpi_worker` for one task of a DWDP job.
#
# Tasks whose SLURM_PROCID is lower than <num_ctx_gpus> use the CTX settings
# (env vars and profile range); all other tasks use the GEN settings.
#
# Usage:
#   start_worker_dwdp.sh <config_file> <numa_bind> <log_dir> <enable_nsys> \
#       <ctx_profile_range> <gen_profile_range> <num_ctx_gpus> \
#       <ctx_worker_env_var> <gen_worker_env_var>
#
#   numa_bind / enable_nsys   the literal string "true" enables the feature
#   *_worker_env_var          whitespace-separated NAME=VALUE pairs (may be empty)
set -u
set -e
set -x

# Fail with a readable message instead of a bare "unbound variable" error.
if [ "$#" -lt 9 ]; then
    echo "Usage: $0 <config_file> <numa_bind> <log_dir> <enable_nsys>" \
        "<ctx_profile_range> <gen_profile_range> <num_ctx_gpus>" \
        "<ctx_worker_env_var> <gen_worker_env_var>" >&2
    exit 1
fi

config_file=${1}
numa_bind=${2}
log_dir=${3}
enable_nsys=${4}
ctx_profile_range=${5}
gen_profile_range=${6}
num_ctx_gpus=${7}
ctx_worker_env_var=${8}
gen_worker_env_var=${9}

unset UCX_NET_DEVICES
unset UCX_TLS

: "${SLURM_PROCID:?SLURM_PROCID must be set}"

echo "SLURM_PROCID: ${SLURM_PROCID}, hostname: $(hostname)"

# `[ a -lt b ]` returns status 2 on non-integer operands; inside an `if`
# condition that does not trip `set -e`, so a malformed value would silently
# select the GEN branch. Reject anything that is not a non-negative integer.
for rank_value in "${SLURM_PROCID}" "${num_ctx_gpus}"; do
    case "${rank_value}" in
        '' | *[!0-9]*)
            echo "Error: SLURM_PROCID ('${SLURM_PROCID}') and num_ctx_gpus ('${num_ctx_gpus}') must be non-negative integers" >&2
            exit 1
            ;;
    esac
done

if [ "${SLURM_PROCID}" -lt "${num_ctx_gpus}" ]; then
    worker_role="CTX"
    worker_env_var=${ctx_worker_env_var}
    profile_range=${ctx_profile_range}
else
    worker_role="GEN"
    worker_env_var=${gen_worker_env_var}
    profile_range=${gen_profile_range}
fi

echo "worker_role: ${worker_role}, profile_range: ${profile_range}"

# Intentionally unquoted: the list is split on whitespace into NAME=VALUE items.
for env_var in ${worker_env_var}; do
    export "${env_var}"
    echo "Exported: ${env_var}"
done

if [ "${numa_bind}" = "true" ]; then
    echo "numactl -m 0,1 - Only allocate memory from nodes on GB200/GB300 NVL72"
else
    echo "Not binding memory. If on GB200/GB300 NVL72, use \"numactl -m 0,1\" to only allocate memory from nodes."
fi

echo "config_file: ${config_file}"

# Build the command line as an array so paths containing whitespace or glob
# characters reach the child process as single, unmodified arguments.
launch_cmd=()
if [ "${enable_nsys}" != "true" ]; then
    echo "nsys is not enabled, start normal flow"
else
    nsys_file=${log_dir}/nsys_worker_proc_${worker_role}_${SLURM_PROCID}
    export TLLM_PROFILE_RECORD_GC=1
    export TLLM_NVTX_DEBUG=1
    export NSYS_MPI_STORE_TEAMS_PER_RANK=1
    export TLLM_PROFILE_START_STOP=${profile_range}
    echo "nsys is enabled on ${worker_role} ranks, TLLM_PROFILE_START_STOP=${profile_range}"
    launch_cmd+=(nsys profile -o "${nsys_file}" -f true -t cuda,nvtx,python-gil
        -c cudaProfilerApi --cuda-graph-trace node --capture-range-end=stop
        --gpu-metrics-devices=none)
fi

# The numactl wrapper goes between the optional nsys prefix and the server.
if [ "${numa_bind}" = "true" ]; then
    launch_cmd+=(numactl -m 0,1)
fi

launch_cmd+=(trtllm-serve disaggregated_mpi_worker -c "${config_file}")

"${launch_cmd[@]}"
212 changes: 159 additions & 53 deletions examples/disaggregated/slurm/benchmark/submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,45 @@ def save_worker_config(worker_config, output_path):
yaml.dump(worker_config, f, default_flow_style=False)


def generate_mpi_worker_config(worker_config, allocations, env_config,
                               disagg_hostname, disagg_port, output_path):
    """Generate a config YAML compatible with ``trtllm-serve disaggregated_mpi_worker``.

    Args:
        worker_config: Mapping with ``'ctx'`` and ``'gen'`` entries. Each entry
            is shallow-copied into the output and extended with
            ``num_instances`` and ``urls``.
        allocations: Mapping ``server_type -> server_id -> instance`` where
            ``server_type`` is ``"CTX"`` or ``"GEN"`` and each instance has a
            ``"nodes"`` mapping and a ``"port"``. A missing server type yields
            an empty URL list.
        env_config: Mapping providing ``'model_path'``.
        disagg_hostname: Value written to the top-level ``hostname`` key.
        disagg_port: Value written to the top-level ``port`` key.
        output_path: Destination file. Its parent directory is created when
            the path has one.
    """

    def _build_urls(server_type):
        # One "<first node>:<port>" entry per instance, ordered by server id.
        instances = allocations.get(server_type, {})
        urls = []
        for server_id in sorted(instances):
            inst = instances[server_id]
            host = list(inst["nodes"])[0]
            urls.append(f"{host}:{inst['port']}")
        return urls

    def _build_section(config_key, server_type):
        # Copy so the caller's worker_config is not modified.
        urls = _build_urls(server_type)
        section = dict(worker_config[config_key])
        section['num_instances'] = len(urls)
        section['urls'] = urls
        return section

    # Build ctx before gen, matching the key order written below.
    ctx_section = _build_section('ctx', "CTX")
    gen_section = _build_section('gen', "GEN")

    config = {
        'model': env_config['model_path'],
        'hostname': disagg_hostname,
        'port': disagg_port,
        'backend': 'pytorch',
        'max_retries': 100,
        'context_servers': ctx_section,
        'generation_servers': gen_section,
    }

    # os.path.dirname() returns '' for a bare filename and os.makedirs('')
    # raises FileNotFoundError, so only create the directory when there is one.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # sort_keys=False keeps the key order defined above in the emitted YAML.
    with open(output_path, 'w') as f:
        yaml.dump(config, f, default_flow_style=False, sort_keys=False)


def calculate_nodes(world_size, num_servers, gpus_per_node):
    """Return ``world_size * num_servers / gpus_per_node`` rounded up to a whole number of nodes."""
    total_ranks = world_size * num_servers
    nodes_needed = math.ceil(total_ranks / gpus_per_node)
    return nodes_needed
Expand Down Expand Up @@ -105,10 +144,13 @@ def assign_servers(
server_allocations[server_type][i] = server_allocation
port += 1

assign_servers(allocations, "GEN", num_gen_servers, gen_world_size,
gpus_per_node)
# Keep the allocation order aligned with disagg_utils, which builds
# server_configs as ctx_cfgs + gen_cfgs and assigns rank offsets in that
# same order during split_world_comm().
assign_servers(allocations, "CTX", num_ctx_servers, ctx_world_size,
gpus_per_node)
assign_servers(allocations, "GEN", num_gen_servers, gen_world_size,
gpus_per_node)

return allocations

Expand Down Expand Up @@ -406,6 +448,13 @@ def submit_job(config, log_dir, dry_run):
total_nodes = ctx_nodes + gen_nodes
total_tasks = total_nodes * gpus_per_node

# Detect DWDP mode: when enabled, use a single srun with
# trtllm-serve disaggregated_mpi_worker instead of per-instance sruns
dwdp_enabled = worker_config.get('ctx', {}).get('dwdp_config',
{}).get('enabled', False)
dwdp_size = worker_config.get('ctx', {}).get('dwdp_config',
{}).get('dwdp_size', 1)

# Generate log directory path based on configuration
isl = benchmark_config['input_length']
osl = benchmark_config['output_length']
Expand Down Expand Up @@ -434,10 +483,13 @@ def submit_job(config, log_dir, dry_run):
log_base = os.path.join(log_base, f"{date_prefix}/{isl}-{osl}")

# Determine directory suffix based on attention_dp
if gen_enable_attention_dp:
dir_suffix = f"disagg_ctx{ctx_num}_gen{gen_num}_dep{gen_tp_size}_batch{gen_batch_size}_eplb{eplb_num_slots}_mtp{mtp_size}"
if dwdp_enabled:
Comment thread
tianyuz-nv marked this conversation as resolved.
Outdated
dir_suffix = f"disagg_ctx{ctx_num}_dwdp{dwdp_size}_gen{gen_num}_dep{gen_tp_size}_batch{gen_batch_size}_eplb{eplb_num_slots}_mtp{mtp_size}"
else:
dir_suffix = f"disagg_ctx{ctx_num}_gen{gen_num}_tep{gen_tp_size}_batch{gen_batch_size}_eplb{eplb_num_slots}_mtp{mtp_size}"
if gen_enable_attention_dp:
dir_suffix = f"disagg_ctx{ctx_num}_gen{gen_num}_dep{gen_tp_size}_batch{gen_batch_size}_eplb{eplb_num_slots}_mtp{mtp_size}"
else:
dir_suffix = f"disagg_ctx{ctx_num}_gen{gen_num}_tep{gen_tp_size}_batch{gen_batch_size}_eplb{eplb_num_slots}_mtp{mtp_size}"

# Create full log directory path
log_dir = os.path.join(log_base, dir_suffix)
Expand Down Expand Up @@ -506,54 +558,108 @@ def submit_job(config, log_dir, dry_run):
}
}

# Generate start worker commands with placeholder hostnames
for server_type in allocations.keys():
server_cfg = server_configs[server_type]

for server_id in allocations[server_type].keys():
allocation = allocations[server_type][server_id]
# Get GPU IDs for this server from allocation
# When multi-node, all nodes have same device list, so use first node [0]
gpu_ids = list(allocation["nodes"].values())[0]

# Build environment for this worker
cuda_devices = ','.join(map(str, gpu_ids))
worker_env = build_worker_environment(
worker_config=worker_config,
env_config=env_config,
role=server_type,
benchmark_mode=benchmark_config['mode'],
nsys_on=profiling_config['nsys_on'],
profile_range=server_cfg['profile_range'],
concurrency=benchmark_config['concurrency_list'].split(',')[0],
)
export_str = format_export_string(worker_env)

# Use script_dir for start_worker.sh
cmd = [
"srun -l",
f"--nodelist {','.join(allocation['nodes'].keys())}",
f"-N {len(allocation['nodes'])}",
f"--ntasks {server_cfg['world_size']}",
f"--ntasks-per-node {gpus_per_node}",
f"--export=\"{export_str}\"",
f"--container-image {env_config['container_image']}",
f"--container-name {container_name}",
f"--container-mounts {container_mount_str}",
"--no-container-mount-home --mpi=pmix --overlap",
f"bash {os.path.join(script_dir, 'start_worker.sh')}",
server_type,
str(server_id),
env_config['model_path'],
str(allocation["port"]),
str(slurm_config['numa_bind']).lower(),
log_dir,
str(profiling_config['nsys_on']).lower(),
server_cfg['config_path'],
cuda_devices,
f"&> {log_dir}/3_output_{server_type}_{server_id}.log &",
]
start_server_cmds.append(" ".join(cmd))
if dwdp_enabled:
# --- DWDP mode: single srun with disaggregated_mpi_worker ---
mpi_config_base_path = os.path.join(log_dir,
'mpi_worker_config_base.yaml')
mpi_config_path = os.path.join(log_dir, 'mpi_worker_config.yaml')
generate_mpi_worker_config(worker_config, allocations, env_config,
disagg_server_hostname, disagg_server_port,
mpi_config_base_path)

# Nodelist: CTX nodes first, then GEN nodes (matches
# split_world_comm order: server_configs = ctx_cfgs + gen_cfgs)
ctx_node_list = []
for sid in sorted(allocations.get("CTX", {}).keys()):
for node in allocations["CTX"][sid]["nodes"]:
if node not in ctx_node_list:
ctx_node_list.append(node)
gen_node_list = []
for sid in sorted(allocations.get("GEN", {}).keys()):
for node in allocations["GEN"][sid]["nodes"]:
if node not in gen_node_list:
gen_node_list.append(node)
mpi_nodelist = ctx_node_list + gen_node_list
total_mpi_tasks = ctx_num * ctx_world_size + gen_num * gen_world_size
mpi_num_nodes = len(mpi_nodelist)
num_ctx_gpus = ctx_num * ctx_world_size
worker_env_var = env_config.get('worker_env_var', '')
ctx_worker_env_var = env_config.get('ctx_worker_env_var', '')
gen_worker_env_var = env_config.get('gen_worker_env_var', '')
dwdp_ctx_worker_env_var = worker_env_var + \
(f" {ctx_worker_env_var}" if ctx_worker_env_var else "")
dwdp_gen_worker_env_var = worker_env_var + \
(f" {gen_worker_env_var}" if gen_worker_env_var else "")
Comment thread
tianyuz-nv marked this conversation as resolved.
Outdated

cmd = [
"srun -l",
f"--nodelist {','.join(mpi_nodelist)}",
f"-N {mpi_num_nodes}",
f"--ntasks {total_mpi_tasks}",
f"--ntasks-per-node {gpus_per_node}",
f"--container-image {env_config['container_image']}",
f"--container-name {container_name}",
f"--container-mounts {container_mount_str}",
"--no-container-mount-home --mpi=pmix --overlap",
f"bash {os.path.join(script_dir, 'start_worker_dwdp.sh')}",
mpi_config_path,
str(slurm_config['numa_bind']).lower(),
log_dir,
str(profiling_config['nsys_on']).lower(),
f"'{profiling_config['ctx_profile_range']}'",
f"'{profiling_config['gen_profile_range']}'",
str(num_ctx_gpus),
f"'{dwdp_ctx_worker_env_var}'",
f"'{dwdp_gen_worker_env_var}'",
f"&> {log_dir}/3_output_workers.log &",
]
start_server_cmds.append(" ".join(cmd))
else:
# --- Standard mode: per-instance srun ---
for server_type in allocations.keys():
server_cfg = server_configs[server_type]

for server_id in allocations[server_type].keys():
allocation = allocations[server_type][server_id]
gpu_ids = list(allocation["nodes"].values())[0]

cuda_devices = ','.join(map(str, gpu_ids))
worker_env = build_worker_environment(
worker_config=worker_config,
env_config=env_config,
role=server_type,
benchmark_mode=benchmark_config['mode'],
nsys_on=profiling_config['nsys_on'],
profile_range=server_cfg['profile_range'],
concurrency=benchmark_config['concurrency_list'].split(',')
[0],
)
export_str = format_export_string(worker_env)

cmd = [
"srun -l",
f"--nodelist {','.join(allocation['nodes'].keys())}",
f"-N {len(allocation['nodes'])}",
f"--ntasks {server_cfg['world_size']}",
f"--ntasks-per-node {gpus_per_node}",
f"--export=\"{export_str}\"",
f"--container-image {env_config['container_image']}",
f"--container-name {container_name}",
f"--container-mounts {container_mount_str}",
"--no-container-mount-home --mpi=pmix --overlap",
f"bash {os.path.join(script_dir, 'start_worker.sh')}",
server_type,
str(server_id),
env_config['model_path'],
str(allocation["port"]),
str(slurm_config['numa_bind']).lower(),
log_dir,
str(profiling_config['nsys_on']).lower(),
server_cfg['config_path'],
cuda_devices,
f"&> {log_dir}/3_output_{server_type}_{server_id}.log &",
]
start_server_cmds.append(" ".join(cmd))

# Generate start server commands (use script_dir for start_server.sh)
server_env = build_server_environment(env_config, benchmark_config['mode'])
Expand Down
Loading
Loading