[None][feat] Add layer wise benchmarks #8777
Merged
Changes from all commits (17 commits):
- ed497f0 Add layer-wise benchmarks (yuantailing)
- cd67da6 Print the time (yuantailing)
- d7a446a Add context test to benchmark (yuantailing)
- 12210e4 Update config.json (yuantailing)
- 8f3b6dc Add slurm scripts (yuantailing)
- 4d4e427 Import Docker image, set default CONTAINER_MOUNTS (yuantailing)
- f68b0a9 Polish the scripts (yuantailing)
- f8184e4 Refine README and scripts (yuantailing)
- e2fb753 Fix kv cache size (yuantailing)
- f0c7245 Make configs more general (yuantailing)
- c0c57ad Use yaml instead of hardcode (yuantailing)
- fcae76f Use oneline models (yuantailing)
- 15a9f1a Refactor the configs (yuantailing)
- f32371c Remove --max-batch-size option (yuantailing)
- 7bb476f Add --scaled-from and update examples (yuantailing)
- 791bfbc Add to CI (yuantailing)
- 84f5eee Add --tokens-per-block (yuantailing)
# Layer-wise Benchmarks

## Generate profiles

### Run with MPI

**Step 1:** Start a container using Docker, Enroot, or another runtime. Please refer to `../../jenkins/current_image_tags.properties` for the Docker image URI.

**Step 2:** In the container, install `tensorrt_llm`:

```bash
pip install -e ../..
```
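To confirm the editable install is picked up inside the container, a quick sanity check (a minimal sketch; the printed version string depends on your checkout):

```bash
python3 -c "import tensorrt_llm; print(tensorrt_llm.__version__)"
```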
**Step 3:** In the container, run benchmarks and generate profiles:

```bash
# Run DeepSeek-R1
NP=4 ./mpi_launch.sh ./run_single.sh config_ctx.yaml
NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml

# Run DeepSeek-V3.2-Exp
NP=4 ./mpi_launch.sh ./run_single.sh config_ctx.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM
NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM

# Run DeepSeek-V3.2-Exp with 32k context length
NP=4 ./mpi_launch.sh ./run_single.sh config_ctx.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --max-seq-len $((32768 + 1024 + 4)) --max-num-tokens $((32768 + 1024 + 4)) --moe-backend DEEPGEMM --batch-size 1 --seq-len-q 32769
NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --max-seq-len $((32768 + 1024 + 4)) --moe-backend DEEPGEMM --seq-len-kv-cache 32769

# Run with attention TP
NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --no-enable-attention-dp
NP=4 ./mpi_launch.sh ./run_single.sh config_ctx.yaml --no-enable-attention-dp

# Run with attention TP and TRTLLMGen
NP=4 TRTLLM_ENABLE_PDL=1 ./mpi_launch.sh ./run_single.sh config_ctx.yaml --no-enable-attention-dp --moe-backend TRTLLM
NP=4 TRTLLM_ENABLE_PDL=1 ./mpi_launch.sh ./run_single.sh config_gen.yaml --no-enable-attention-dp --moe-backend TRTLLM

# Run with MTP3
NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --batch-size 32 --seq-len-q 4

# Run 4 layers
NP=4 ./mpi_launch.sh ./run_single.sh config_ctx.yaml --layer-indices 5,6,7,8
NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --layer-indices 5,6,7,8

# Scale DEP=16 MNNVL to 4 GPUs: reduce the number of experts, uses MNNVL A2A if applicable
NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --scaled-from 16 --moe-backend WIDEEP

# Scale TEP=16 to 4 GPUs: reduce the number of attention heads and experts
NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --scaled-from 16 --no-enable-attention-dp

# Run with DeepEP A2A
NP=4 TRTLLM_FORCE_ALLTOALL_METHOD=DeepEP ./mpi_launch.sh ./run_single.sh config_ctx.yaml --moe-backend WIDEEP
NP=4 TRTLLM_FORCE_ALLTOALL_METHOD=DeepEP ./mpi_launch.sh ./run_single.sh config_gen.yaml --moe-backend WIDEEP
```
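Each rank prints per-iteration latency statistics (min/max/mean/median/P90, in microseconds) and, with profiling enabled (the default in `run_single.sh`), writes one Nsight Systems report per rank. A quick way to check the generated files, assuming the default profile folder:

```bash
ls profiles/run_single/  # expect one run_single_ep<NP>_rank<rank>.nsys-rep per rank
```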
### Run with Slurm

> Tip: If you already have a running job with the environment installed, skip steps 1 and 2 and go straight to step 3. In this case, your job must have been started with `--container-name aaa`, and if the container name is not "layer_wise_benchmarks", please `export CONTAINER_NAME=aaa`.

**Step 1:** On the controller node, allocate one or more nodes and record the `SLURM_JOB_ID`:

```bash
SLURM_JOB_ID=$(NODES=4 TIME=02:00:00 ./slurm_alloc.sh)
```

Please fill in the variables in `./slurm_alloc.sh`.
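For reference, the variables at the top of `slurm_alloc.sh` might be filled in like this (hypothetical values; account and partition names are cluster specific):

```bash
ACCOUNT=your_account        # hypothetical Slurm account
PARTITION=your_partition    # hypothetical GPU partition
EXTRA_ARGS="--gres gpu:4"   # optional extra salloc arguments
```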
**Step 2:** Start a container and install `tensorrt_llm`. Run the following command on the controller node:

```bash
SLURM_JOB_ID=$SLURM_JOB_ID ./slurm_init_containers.sh
```

It uses the image recorded in `../../jenkins/current_image_tags.properties`. The image is downloaded to `../../enroot/` only once.
**Step 3:** Run benchmarks to generate profiles. Run the following commands on the controller node, where `NODES` ≤ the number of allocated nodes:

```bash
# Run DeepSeek-R1 with wide EP: uses MNNVL A2A if applicable
SLURM_JOB_ID=$SLURM_JOB_ID NODES=4 NP=16 ./slurm_launch.sh ./run_single.sh config_gen.yaml --moe-backend WIDEEP

# Run with attention TP and TRTLLMGen
SLURM_JOB_ID=$SLURM_JOB_ID NODES=4 NP=16 TRTLLM_ENABLE_PDL=1 ./slurm_launch.sh ./run_single.sh config_gen.yaml --no-enable-attention-dp --moe-backend TRTLLM

# Run with DeepEPLowLatency
SLURM_JOB_ID=$SLURM_JOB_ID NODES=4 NP=16 TRTLLM_FORCE_ALLTOALL_METHOD=DeepEPLowLatency ./slurm_launch.sh ./run_single.sh config_gen.yaml --moe-backend WIDEEP

# You can run 4-GPU and 8-GPU tasks without reallocating the Slurm job
SLURM_JOB_ID=$SLURM_JOB_ID NODES=1 NP=4 ./slurm_launch.sh ./run_single.sh config_ctx.yaml
SLURM_JOB_ID=$SLURM_JOB_ID NODES=2 NP=8 ./slurm_launch.sh ./run_single.sh config_ctx.yaml
```
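When you are finished benchmarking, the allocation can be released with standard Slurm tooling (not part of the scripts in this PR):

```bash
scancel "$SLURM_JOB_ID"
```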
## Parse profiles

Coming soon.
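Until the parsing tooling lands, the generated reports can be inspected manually with Nsight Systems, for example (a sketch; report names vary across `nsys` versions):

```bash
nsys stats -r cuda_gpu_kern_sum profiles/run_single/run_single_ep4_rank0.nsys-rep
```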
`config_ctx.yaml` (defaults for the context-phase benchmark):

```yaml
model: nvidia/DeepSeek-R1-0528-FP4-v2
layer_indices: [5]
run_type: CTX
scaled_from: null

# KV cache related args
tokens_per_block: 32
max_seq_len: 9220  # 8192 + 1024 + 4
enable_attention_dp: true

# Model init args
max_num_tokens: 20480
moe_backend: CUTLASS
use_cuda_graph: false

# Per iteration args
batch_size: 1
seq_len_q: 8193
seq_len_kv_cache: 0
balance_method: Balanced
balance_ratio: null
```
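Every key in this file corresponds to a command-line option of `run_single.py`, and the YAML value is used only when the option is not given on the command line. For example, to benchmark two layers while keeping the other `config_ctx.yaml` defaults:

```bash
NP=4 ./mpi_launch.sh ./run_single.sh config_ctx.yaml --layer-indices 5,6
```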
`config_gen.yaml` (defaults for the generation-phase benchmark):

```yaml
model: nvidia/DeepSeek-R1-0528-FP4-v2
layer_indices: [5]
run_type: GEN
scaled_from: null

# KV cache related args
tokens_per_block: 32
max_seq_len: 9220  # 8192 + 1024 + 4
enable_attention_dp: true

# Model init args
max_num_tokens: 4096  # MTP3 as max
moe_backend: CUTLASS
use_cuda_graph: true

# Per iteration args
batch_size: 128
seq_len_q: 1  # Set to (1 + MTP)
seq_len_kv_cache: 8193
balance_method: Balanced
balance_ratio: null
```
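As the comments note, `seq_len_q` models speculative decoding in the generation phase: it is 1 plus the number of MTP draft tokens. The MTP3 example from the README overrides it accordingly:

```bash
# GEN phase with MTP3: 4 tokens per request
NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --batch-size 32 --seq-len-q 4
```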
`mpi_launch.sh` (launches `NP` ranks locally via `mpirun`):

```bash
#!/bin/bash

set -euo pipefail

# Clear any inherited Slurm / MPI environment variables
unset $(env | grep -i slurm | awk -F'=' '{print $1}')
unset $(env | grep MPI | awk -F'=' '{print $1}')

set -x
mpirun --allow-run-as-root --np ${NP} "$@"
```
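The wrapper clears Slurm/MPI variables inherited from the calling shell so the local `mpirun` launch is independent of any enclosing allocation, then starts `NP` ranks. Typical usage, as in the README:

```bash
NP=4 ./mpi_launch.sh ./run_single.sh config_ctx.yaml
```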
`run_single.py` (the benchmark driver invoked by `run_single.sh`):

```python
import argparse

import numpy as np
import nvtx
import torch
import yaml

from tensorrt_llm._torch.autotuner import AutoTuner, autotune
from tensorrt_llm._torch.modules.multi_stream_utils import with_multi_stream
from tensorrt_llm._utils import local_mpi_rank, mpi_rank, mpi_world_size
from tensorrt_llm.tools.layer_wise_benchmarks.deepseekv3_runner import (
    BalanceMethod, DeepSeekV3Runner)


def comma_separated_ints(s):
    return [int(x) for x in s.split(",")]


# Parse cmdline
parser = argparse.ArgumentParser()
parser.add_argument("config_path", type=str)
parser.add_argument("--model", type=str, help="Pretrained model name or path")
parser.add_argument(
    "--layer-indices",
    type=comma_separated_ints,
    help="Comma separated indices of layers, should be a contiguous range")
parser.add_argument("--run-type", type=str, choices=["CTX", "GEN"])
parser.add_argument("--scaled-from", type=int)
# KV cache related args
parser.add_argument("--tokens-per-block", type=int)
parser.add_argument("--max-seq-len", type=int)
# Tri-state boolean flags: the default None means "take the value from the YAML config"
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument("--enable-attention-dp",
                   action="store_true",
                   dest="enable_attention_dp")
group.add_argument("--no-enable-attention-dp",
                   action="store_false",
                   dest="enable_attention_dp")
parser.set_defaults(enable_attention_dp=None)
# Model init args
parser.add_argument("--max-num-tokens", type=int)
parser.add_argument("--moe-backend", type=str)
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument("--use-cuda-graph",
                   action="store_true",
                   dest="use_cuda_graph")
group.add_argument("--no-use-cuda-graph",
                   action="store_false",
                   dest="use_cuda_graph")
parser.set_defaults(use_cuda_graph=None)
# Per iteration args
parser.add_argument("--batch-size", type=int)
parser.add_argument("--seq-len-q", type=int)
parser.add_argument("--seq-len-kv-cache", type=int)
parser.add_argument("--balance-method", type=str)
parser.add_argument("--balance-ratio", type=float)
args = parser.parse_args()
# Any option left unset on the command line falls back to the YAML config value
with open(args.config_path) as f:
    config = yaml.safe_load(f)
del args.config_path
for k, v in vars(args).items():
    if v is None:
        setattr(args, k, config[k])
print(args)

# MPI args
rank = mpi_rank()
world_size = mpi_world_size()
local_rank = local_mpi_rank()
torch.cuda.set_device(local_rank)

# Create KV cache manager
mapping = DeepSeekV3Runner.create_mapping(
    enable_attention_dp=args.enable_attention_dp)
max_batch_size = 2048
kv_cache_manager = DeepSeekV3Runner.create_kv_cache_manager(
    args.model,
    mapping,
    tokens_per_block=args.tokens_per_block,
    max_batch_size=max_batch_size,
    max_seq_len=args.max_seq_len,
    layer_indices=args.layer_indices)
attn_workspace = torch.empty((0, ), device="cuda", dtype=torch.int8)

# Create other global objects
AutoTuner.get().clear_cache()
capture_stream = torch.cuda.Stream()

# Create Runner
runner = DeepSeekV3Runner(args.model,
                          mapping,
                          moe_backend=args.moe_backend,
                          layer_indices=args.layer_indices,
                          scaled_from=args.scaled_from,
                          max_seq_len=args.max_seq_len,
                          max_num_tokens=args.max_num_tokens,
                          use_cuda_graph=args.use_cuda_graph)

# Warm up: run once eagerly and once under autotuning, on the stream later used for graph capture
assert args.batch_size <= max_batch_size
assert args.seq_len_q + args.seq_len_kv_cache <= args.max_seq_len
run_pack = runner.create_run_pack(args.run_type,
                                  batch_size=args.batch_size,
                                  seq_len_q=args.seq_len_q,
                                  seq_len_kv_cache=args.seq_len_kv_cache,
                                  kv_cache_manager=kv_cache_manager,
                                  attn_workspace=attn_workspace)
runner.replace_routing_method(balance_method=BalanceMethod[args.balance_method],
                              balance_ratio=args.balance_ratio)
capture_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(capture_stream):
    run_pack()
    with autotune():
        run_pack()
torch.cuda.current_stream().wait_stream(capture_stream)
torch.cuda.synchronize()

# Profile: capture graph and replay it
torch.cuda.cudart().cudaProfilerStart()
if args.use_cuda_graph:
    with with_multi_stream(True):
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g,
                              stream=capture_stream,
                              capture_error_mode="global"):
            run_pack()

warmup_times = 20
run_times = 100
events = [
    torch.cuda.Event(enable_timing=True)
    for _ in range(warmup_times + run_times + 1)
]
for i in range(warmup_times + run_times):
    events[i].record()
    with nvtx.annotate(
            f"b={args.batch_size} s={args.seq_len_q} EP{world_size}"):
        if args.use_cuda_graph:
            g.replay()
        else:
            run_pack()
events[-1].record()
torch.cuda.synchronize()

# Print statistics
# Print before `cudaProfilerStop` to ensure messages are included in the profile
# Event.elapsed_time() returns milliseconds; statistics are reported in microseconds
time_list = [
    start.elapsed_time(stop) for start, stop in zip(events, events[1:])
]
time_list = time_list[warmup_times:]
print(f"[RANK {rank}]"
      f" min {np.min(time_list) * 1000:.1f}"
      f" max {np.max(time_list) * 1000:.1f}"
      f" mean {np.mean(time_list) * 1000:.1f}"
      f" median {np.median(time_list) * 1000:.1f}"
      f" P90 {np.percentile(time_list, 90) * 1000:.1f}"
      f" (us)")

torch.cuda.cudart().cudaProfilerStop()
```
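Flattening the wrappers, a single run boils down to something like the following (a sketch; normally you would go through `mpi_launch.sh` and `run_single.sh` so that logging and Nsight profiling are set up for you):

```bash
mpirun --allow-run-as-root --np 4 python3 -u run_single.py config_gen.yaml --moe-backend WIDEEP
```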
`run_single.sh` (per-rank wrapper that optionally profiles with Nsight Systems):

```bash
#!/bin/bash

set -euo pipefail

# Map Open MPI environment variables to generic names
if [ -v OMPI_COMM_WORLD_SIZE ]; then
    export WORLD_SIZE=$OMPI_COMM_WORLD_SIZE
    export RANK=$OMPI_COMM_WORLD_RANK
    export LOCAL_RANK=$OMPI_COMM_WORLD_LOCAL_RANK
    export NODE_RANK=$OMPI_COMM_WORLD_NODE_RANK
fi

# Verbose logging on rank 0 only
if [ "$RANK" -eq 0 ]; then
    export TLLM_LOG_LEVEL=INFO
fi

# PROFILE=1 (default) wraps the run in nsys; GPU_METRICS=1 additionally samples GPU metrics
PROFILE=${PROFILE:-1}
GPU_METRICS=${GPU_METRICS:-0}
if [ "$PROFILE" -eq 1 ]; then
    PROFILE_FOLDER=profiles/run_single
    mkdir -p ${PROFILE_FOLDER}
    PROFILE_CMD="nsys profile
        -t cuda,nvtx -s none
        --cpuctxsw none --cuda-event-trace false
        --cuda-graph-trace node
        -c cudaProfilerApi --capture-range-end stop
        -o ${PROFILE_FOLDER}/run_single_ep${WORLD_SIZE}_rank${RANK}.nsys-rep
        --force-overwrite true"
    if [ "$GPU_METRICS" -eq 1 ]; then
        PROFILE_CMD+=" --gpu-metrics-devices $LOCAL_RANK
            --gpu-metrics-frequency 10000"
    fi
else
    PROFILE_CMD=
fi

set -x
$PROFILE_CMD python3 -u run_single.py "$@"
```
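The two environment knobs can be toggled per run; for a local `mpi_launch.sh` launch the ranks inherit them from the calling shell:

```bash
# Skip nsys profiling entirely (timing statistics are still printed)
PROFILE=0 NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml

# Additionally sample GPU metrics into the report
GPU_METRICS=1 NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml
```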
`slurm_alloc.sh` (allocates nodes and prints the granted job ID):

```bash
#!/bin/bash

set -euo pipefail

# Fill in these variables for your cluster
# ACCOUNT=
# PARTITION=
# EXTRA_ARGS="--gres gpu:4"
TIME=${TIME:-01:00:00}

set -x
salloc -A "$ACCOUNT" \
    -p "$PARTITION" \
    -N "$NODES" \
    --segment "$NODES" \
    $EXTRA_ARGS \
    -t "$TIME" \
    --no-shell \
    2>&1 \
    | tee >(cat >&2) \
    | awk '/Granted job allocation/ {print $NF}'
```
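To confirm that the allocation returned by this script is active (standard Slurm, not part of this PR):

```bash
squeue -j "$SLURM_JOB_ID"
```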