1 change: 1 addition & 0 deletions .gitignore
@@ -82,4 +82,5 @@ compile_commands.json
.devcontainer/docker-compose.override.yml

# Enroot sqsh files
enroot/sw-tensorrt-docker+*.sqsh
enroot/tensorrt_llm.devel.sqsh
95 changes: 95 additions & 0 deletions examples/layer_wise_benchmarks/README.md
@@ -0,0 +1,95 @@
# Layer-wise Benchmarks

## Generate profiles

### Run with MPI

**Step 1:** Start a container using Docker, Enroot, or another container runtime. Refer to `../../jenkins/current_image_tags.properties` for the Docker image URI.
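
For example, a minimal Docker sketch (the image URI is a placeholder; copy the real one from the properties file above, and adjust the mount paths to your checkout):

```bash
# Placeholder: replace with the URI from ../../jenkins/current_image_tags.properties
IMAGE=<docker-image-uri>
docker run --rm -it --gpus all --ipc=host \
    -v "$(pwd)/../..:/workspace/TensorRT-LLM" \
    -w /workspace/TensorRT-LLM/examples/layer_wise_benchmarks \
    "$IMAGE" bash
```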

**Step 2:** In the container, install `tensorrt_llm`:

```bash
pip install -e ../..
```
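
Optionally, a quick sanity check that the editable install is importable:

```bash
python3 -c "import tensorrt_llm; print(tensorrt_llm.__version__)"
```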

**Step 3:** In the container, run benchmarks and generate profiles:

```bash
# Run DeepSeek-R1
NP=4 ./mpi_launch.sh ./run_single.sh config_ctx.yaml
NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml

# Run DeepSeek-V3.2-Exp
NP=4 ./mpi_launch.sh ./run_single.sh config_ctx.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM
NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM

# Run DeepSeek-V3.2-Exp with 32k context length
NP=4 ./mpi_launch.sh ./run_single.sh config_ctx.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --max-seq-len $((32768 + 1024 + 4)) --max-num-tokens $((32768 + 1024 + 4)) --moe-backend DEEPGEMM --batch-size 1 --seq-len-q 32769
NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --max-seq-len $((32768 + 1024 + 4)) --moe-backend DEEPGEMM --seq-len-kv-cache 32769

# Run with attention TP
NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --no-enable-attention-dp
NP=4 ./mpi_launch.sh ./run_single.sh config_ctx.yaml --no-enable-attention-dp

# Run with attention TP and TRTLLMGen
NP=4 TRTLLM_ENABLE_PDL=1 ./mpi_launch.sh ./run_single.sh config_ctx.yaml --no-enable-attention-dp --moe-backend TRTLLM
NP=4 TRTLLM_ENABLE_PDL=1 ./mpi_launch.sh ./run_single.sh config_gen.yaml --no-enable-attention-dp --moe-backend TRTLLM

# Run with MTP3
NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --batch-size 32 --seq-len-q 4

# Run 4 layers
NP=4 ./mpi_launch.sh ./run_single.sh config_ctx.yaml --layer-indices 5,6,7,8
NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --layer-indices 5,6,7,8

# Scale a DEP=16 MNNVL configuration down to 4 GPUs: reduces the number of experts; uses MNNVL A2A if applicable
NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --scaled-from 16 --moe-backend WIDEEP

# Scale a TEP=16 configuration down to 4 GPUs: reduces the number of attention heads and experts
NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --scaled-from 16 --no-enable-attention-dp

# Run with DeepEP A2A
NP=4 TRTLLM_FORCE_ALLTOALL_METHOD=DeepEP ./mpi_launch.sh ./run_single.sh config_ctx.yaml --moe-backend WIDEEP
NP=4 TRTLLM_FORCE_ALLTOALL_METHOD=DeepEP ./mpi_launch.sh ./run_single.sh config_gen.yaml --moe-backend WIDEEP
```

### Run with Slurm

> Tip: If you already have a running job with the environment installed, skip steps 1 and 2 and go straight to step 3. In that case, the job must have been started with `--container-name aaa`, and if the container name is not "layer_wise_benchmarks", run `export CONTAINER_NAME=aaa`.
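
For example, a sketch of reusing an existing job (the container name and job id are placeholders):

```bash
export CONTAINER_NAME=aaa   # only needed if the name differs from "layer_wise_benchmarks"
SLURM_JOB_ID=<running job id> NODES=1 NP=4 ./slurm_launch.sh ./run_single.sh config_ctx.yaml
```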

**Step 1:** On the controller node, allocate one or multiple nodes, and record the `SLURM_JOB_ID`:

```bash
SLURM_JOB_ID=$(NODES=4 TIME=02:00:00 ./slurm_alloc.sh)
```

Please fill in the variables in `./slurm_alloc.sh` before running it.
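
The variables to fill in are the commented-out placeholders at the top of the script, for example:

```bash
# In slurm_alloc.sh (placeholder values)
ACCOUNT=my_account
PARTITION=my_partition
EXTRA_ARGS="--gres gpu:4"   # optional extra salloc arguments
```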

**Step 2:** Start a container and install `tensorrt_llm`. Run the following command on the controller node:

```bash
SLURM_JOB_ID=$SLURM_JOB_ID ./slurm_init_containers.sh
```

This script uses the image recorded in `../../jenkins/current_image_tags.properties`. The image is downloaded to `../../enroot/` once and reused for subsequent runs.

**Step 3:** Run benchmarks to generate profiles. Run the following commands on the controller node, where `NODES` must not exceed the number of allocated nodes:

```bash
# Run DeepSeek-R1 with wide EP: uses MNNVL A2A if applicable
SLURM_JOB_ID=$SLURM_JOB_ID NODES=4 NP=16 ./slurm_launch.sh ./run_single.sh config_gen.yaml --moe-backend WIDEEP

# Run with attention TP and TRTLLMGen
SLURM_JOB_ID=$SLURM_JOB_ID NODES=4 NP=16 TRTLLM_ENABLE_PDL=1 ./slurm_launch.sh ./run_single.sh config_gen.yaml --no-enable-attention-dp --moe-backend TRTLLM

# Run with DeepEPLowLatency
SLURM_JOB_ID=$SLURM_JOB_ID NODES=4 NP=16 TRTLLM_FORCE_ALLTOALL_METHOD=DeepEPLowLatency ./slurm_launch.sh ./run_single.sh config_gen.yaml --moe-backend WIDEEP

# You can run 4-GPU and 8-GPU tasks without reallocating the Slurm job
SLURM_JOB_ID=$SLURM_JOB_ID NODES=1 NP=4 ./slurm_launch.sh ./run_single.sh config_ctx.yaml
SLURM_JOB_ID=$SLURM_JOB_ID NODES=2 NP=8 ./slurm_launch.sh ./run_single.sh config_ctx.yaml
```

## Parse profiles

Coming soon.
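
Until a dedicated parser is available, the reports written to `profiles/run_single/` by `run_single.sh` can be inspected with Nsight Systems; a minimal sketch (report names and the exact file name depend on your `nsys` version and the `WORLD_SIZE`/`RANK` of the run):

```bash
nsys stats --report cuda_gpu_kern_sum profiles/run_single/run_single_ep4_rank0.nsys-rep
```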
21 changes: 21 additions & 0 deletions examples/layer_wise_benchmarks/config_ctx.yaml
@@ -0,0 +1,21 @@
model: nvidia/DeepSeek-R1-0528-FP4-v2
layer_indices: [5]
run_type: CTX
scaled_from: null

# KV cache related args
tokens_per_block: 32
max_seq_len: 9220 # 8192 + 1024 + 4
enable_attention_dp: true

# Model init args
max_num_tokens: 20480
moe_backend: CUTLASS
use_cuda_graph: false

# Per iteration args
batch_size: 1
seq_len_q: 8193
seq_len_kv_cache: 0
balance_method: Balanced
balance_ratio: null
21 changes: 21 additions & 0 deletions examples/layer_wise_benchmarks/config_gen.yaml
@@ -0,0 +1,21 @@
model: nvidia/DeepSeek-R1-0528-FP4-v2
layer_indices: [5]
run_type: GEN
scaled_from: null

# KV cache related args
tokens_per_block: 32
max_seq_len: 9220 # 8192 + 1024 + 4
enable_attention_dp: true

# Model init args
max_num_tokens: 4096 # MTP3 as max
moe_backend: CUTLASS
use_cuda_graph: true

# Per iteration args
batch_size: 128
seq_len_q: 1 # Set to (1 + MTP)
seq_len_kv_cache: 8193
balance_method: Balanced
balance_ratio: null
10 changes: 10 additions & 0 deletions examples/layer_wise_benchmarks/mpi_launch.sh
@@ -0,0 +1,10 @@
#!/bin/bash

set -euo pipefail

# Clear inherited Slurm and MPI environment variables
unset $(env | grep -i slurm | awk -F'=' '{print $1}')
unset $(env | grep MPI | awk -F'=' '{print $1}')

set -x
mpirun --allow-run-as-root --np ${NP} "$@"
159 changes: 159 additions & 0 deletions examples/layer_wise_benchmarks/run_single.py
@@ -0,0 +1,159 @@
import argparse

import numpy as np
import nvtx
import torch
import yaml

from tensorrt_llm._torch.autotuner import AutoTuner, autotune
from tensorrt_llm._torch.modules.multi_stream_utils import with_multi_stream
from tensorrt_llm._utils import local_mpi_rank, mpi_rank, mpi_world_size
from tensorrt_llm.tools.layer_wise_benchmarks.deepseekv3_runner import (
BalanceMethod, DeepSeekV3Runner)


def comma_separated_ints(s):
return [int(x) for x in s.split(",")]


# Parse cmdline
parser = argparse.ArgumentParser()
parser.add_argument("config_path", type=str)
parser.add_argument("--model", type=str, help="Pretrained model name or path")
parser.add_argument(
"--layer-indices",
type=comma_separated_ints,
help="Comma separated indices of layers, should be a contiguous range")
parser.add_argument("--run-type", type=str, choices=["CTX", "GEN"])
parser.add_argument("--scaled-from", type=int)
# KV cache related args
parser.add_argument("--tokens-per-block", type=int)
parser.add_argument("--max-seq-len", type=int)
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument("--enable-attention-dp",
action="store_true",
dest="enable_attention_dp")
group.add_argument("--no-enable-attention-dp",
action="store_false",
dest="enable_attention_dp")
parser.set_defaults(enable_attention_dp=None)
# Model init args
parser.add_argument("--max-num-tokens", type=int)
parser.add_argument("--moe-backend", type=str)
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument("--use-cuda-graph",
action="store_true",
dest="use_cuda_graph")
group.add_argument("--no-use-cuda-graph",
action="store_false",
dest="use_cuda_graph")
parser.set_defaults(use_cuda_graph=None)
# Per iteration args
parser.add_argument("--batch-size", type=int)
parser.add_argument("--seq-len-q", type=int)
parser.add_argument("--seq-len-kv-cache", type=int)
parser.add_argument("--balance-method", type=str)
parser.add_argument("--balance-ratio", type=float)
args = parser.parse_args()
with open(args.config_path) as f:
config = yaml.safe_load(f)
del args.config_path
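# Any option not given on the command line falls back to the value in the YAML config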
for k, v in vars(args).items():
if v is None:
setattr(args, k, config[k])
print(args)

# MPI args
rank = mpi_rank()
world_size = mpi_world_size()
local_rank = local_mpi_rank()
torch.cuda.set_device(local_rank)

# Create KV cache manager
mapping = DeepSeekV3Runner.create_mapping(
enable_attention_dp=args.enable_attention_dp)
max_batch_size = 2048
kv_cache_manager = DeepSeekV3Runner.create_kv_cache_manager(
args.model,
mapping,
tokens_per_block=args.tokens_per_block,
max_batch_size=max_batch_size,
max_seq_len=args.max_seq_len,
layer_indices=args.layer_indices)
attn_workspace = torch.empty((0, ), device="cuda", dtype=torch.int8)

# Create other global objects
AutoTuner.get().clear_cache()
capture_stream = torch.cuda.Stream()

# Create Runner
runner = DeepSeekV3Runner(args.model,
mapping,
moe_backend=args.moe_backend,
layer_indices=args.layer_indices,
scaled_from=args.scaled_from,
max_seq_len=args.max_seq_len,
max_num_tokens=args.max_num_tokens,
use_cuda_graph=args.use_cuda_graph)

# Warm up
assert args.batch_size <= max_batch_size
assert args.seq_len_q + args.seq_len_kv_cache <= args.max_seq_len
run_pack = runner.create_run_pack(args.run_type,
batch_size=args.batch_size,
seq_len_q=args.seq_len_q,
seq_len_kv_cache=args.seq_len_kv_cache,
kv_cache_manager=kv_cache_manager,
attn_workspace=attn_workspace)
runner.replace_routing_method(balance_method=BalanceMethod[args.balance_method],
balance_ratio=args.balance_ratio)
capture_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(capture_stream):
run_pack()
with autotune():
run_pack()
torch.cuda.current_stream().wait_stream(capture_stream)
torch.cuda.synchronize()

# Profile: capture graph and replay it
torch.cuda.cudart().cudaProfilerStart()
if args.use_cuda_graph:
with with_multi_stream(True):
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g,
stream=capture_stream,
capture_error_mode="global"):
run_pack()

# Record a CUDA event before each iteration (plus one trailing event); the per-iteration
# time is the gap between consecutive events, and warmup iterations are dropped below.
warmup_times = 20
run_times = 100
events = [
torch.cuda.Event(enable_timing=True)
for _ in range(warmup_times + run_times + 1)
]
for i in range(warmup_times + run_times):
events[i].record()
with nvtx.annotate(
f"b={args.batch_size} s={args.seq_len_q} EP{world_size}"):
if args.use_cuda_graph:
g.replay()
else:
run_pack()
events[-1].record()
torch.cuda.synchronize()

# Print statistics
# Print before `cudaProfilerStop` to ensure messages are included in the profile
time_list = [
start.elapsed_time(stop) for start, stop in zip(events, events[1:])
]
time_list = time_list[warmup_times:]
print(f"[RANK {rank}]"
f" min {np.min(time_list) * 1000:.1f}"
f" max {np.max(time_list) * 1000:.1f}"
f" mean {np.mean(time_list) * 1000:.1f}"
f" median {np.median(time_list) * 1000:.1f}"
f" P90 {np.percentile(time_list, 90) * 1000:.1f}"
f" (us)")

torch.cuda.cudart().cudaProfilerStop()
37 changes: 37 additions & 0 deletions examples/layer_wise_benchmarks/run_single.sh
@@ -0,0 +1,37 @@
#!/bin/bash

set -euo pipefail

if [ -v OMPI_COMM_WORLD_SIZE ]; then
export WORLD_SIZE=$OMPI_COMM_WORLD_SIZE
export RANK=$OMPI_COMM_WORLD_RANK
export LOCAL_RANK=$OMPI_COMM_WORLD_LOCAL_RANK
export NODE_RANK=$OMPI_COMM_WORLD_NODE_RANK
fi

if [ "$RANK" -eq 0 ]; then
export TLLM_LOG_LEVEL=INFO
fi

# PROFILE=1 wraps the run with Nsight Systems (nsys); GPU_METRICS=1 additionally samples GPU metrics
PROFILE=${PROFILE:-1}
GPU_METRICS=${GPU_METRICS:-0}
if [ "$PROFILE" -eq 1 ]; then
PROFILE_FOLDER=profiles/run_single
mkdir -p ${PROFILE_FOLDER}
PROFILE_CMD="nsys profile
-t cuda,nvtx -s none
--cpuctxsw none --cuda-event-trace false
--cuda-graph-trace node
-c cudaProfilerApi --capture-range-end stop
-o ${PROFILE_FOLDER}/run_single_ep${WORLD_SIZE}_rank${RANK}.nsys-rep
--force-overwrite true"
if [ "$GPU_METRICS" -eq 1 ]; then
PROFILE_CMD+=" --gpu-metrics-devices $LOCAL_RANK
--gpu-metrics-frequency 10000"
fi
else
PROFILE_CMD=
fi

set -x
$PROFILE_CMD python3 -u run_single.py "$@"
20 changes: 20 additions & 0 deletions examples/layer_wise_benchmarks/slurm_alloc.sh
@@ -0,0 +1,20 @@
#!/bin/bash

set -euo pipefail

# ACCOUNT=
# PARTITION=
# EXTRA_ARGS="--gres gpu:4"
TIME=${TIME:-01:00:00}

set -x
salloc -A "$ACCOUNT" \
-p "$PARTITION" \
-N "$NODES" \
--segment "$NODES" \
$EXTRA_ARGS \
-t "$TIME" \
--no-shell \
2>&1 \
| tee >(cat >&2) \
| awk '/Granted job allocation/ {print $NF}'