diff --git a/.gitignore b/.gitignore
index c02a673aa93..2628489807f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -82,4 +82,5 @@ compile_commands.json
 .devcontainer/docker-compose.override.yml
 # Enroot sqsh files
+enroot/sw-tensorrt-docker+*.sqsh
 enroot/tensorrt_llm.devel.sqsh
diff --git a/examples/layer_wise_benchmarks/README.md b/examples/layer_wise_benchmarks/README.md
new file mode 100644
index 00000000000..5bfd6ecd441
--- /dev/null
+++ b/examples/layer_wise_benchmarks/README.md
@@ -0,0 +1,95 @@
+# Layer-wise Benchmarks
+
+## Generate profiles
+
+### Run with MPI
+
+**Step 1:** Start a container using Docker, Enroot, or another container runtime. Please refer to `../../jenkins/current_image_tags.properties` for the Docker image URI.
+
+**Step 2:** In the container, install `tensorrt_llm`:
+
+```bash
+pip install -e ../..
+```
+
+**Step 3:** In the container, run benchmarks and generate profiles:
+
+```bash
+# Run DeepSeek-R1
+NP=4 ./mpi_launch.sh ./run_single.sh config_ctx.yaml
+NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml
+
+# Run DeepSeek-V3.2-Exp
+NP=4 ./mpi_launch.sh ./run_single.sh config_ctx.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM
+NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM
+
+# Run DeepSeek-V3.2-Exp with 32k context length
+NP=4 ./mpi_launch.sh ./run_single.sh config_ctx.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --max-seq-len $((32768 + 1024 + 4)) --max-num-tokens $((32768 + 1024 + 4)) --moe-backend DEEPGEMM --batch-size 1 --seq-len-q 32769
+NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --max-seq-len $((32768 + 1024 + 4)) --moe-backend DEEPGEMM --seq-len-kv-cache 32769
+
+# Run with attention TP
+NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --no-enable-attention-dp
+NP=4 ./mpi_launch.sh ./run_single.sh config_ctx.yaml --no-enable-attention-dp
+
+# Run with attention TP and TRTLLMGen
+NP=4 TRTLLM_ENABLE_PDL=1 ./mpi_launch.sh ./run_single.sh config_ctx.yaml --no-enable-attention-dp --moe-backend TRTLLM
+NP=4 TRTLLM_ENABLE_PDL=1 ./mpi_launch.sh ./run_single.sh config_gen.yaml --no-enable-attention-dp --moe-backend TRTLLM
+
+# Run with MTP3
+NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --batch-size 32 --seq-len-q 4
+
+# Run 4 layers
+NP=4 ./mpi_launch.sh ./run_single.sh config_ctx.yaml --layer-indices 5,6,7,8
+NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --layer-indices 5,6,7,8
+
+# Scale DEP=16 MNNVL to 4 GPUs: reduces the number of experts; uses MNNVL A2A if applicable
+NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --scaled-from 16 --moe-backend WIDEEP
+
+# Scale TEP=16 to 4 GPUs: reduces the number of attention heads and experts
+NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --scaled-from 16 --no-enable-attention-dp
+
+# Run with DeepEP A2A
+NP=4 TRTLLM_FORCE_ALLTOALL_METHOD=DeepEP ./mpi_launch.sh ./run_single.sh config_ctx.yaml --moe-backend WIDEEP
+NP=4 TRTLLM_FORCE_ALLTOALL_METHOD=DeepEP ./mpi_launch.sh ./run_single.sh config_gen.yaml --moe-backend WIDEEP
+```
+
+### Run with Slurm
+
+> Tip: If you already have a running job with the environment installed, skip steps 1 and 2 and go straight to step 3. In this case, the job must have been started with `--container-name aaa`; if the container name is not "layer_wise_benchmarks", run `export CONTAINER_NAME=aaa`, as in the example below.
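+
+For example (the job ID below is a placeholder and `aaa` is a hypothetical container name), you can reuse an existing job and container and go straight to step 3:
+
+```bash
+export CONTAINER_NAME=aaa   # only needed when the name is not "layer_wise_benchmarks"
+SLURM_JOB_ID=<your job id> NODES=1 NP=4 ./slurm_launch.sh ./run_single.sh config_ctx.yaml
+```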
+
+**Step 1:** On the controller node, allocate one or more nodes and record the `SLURM_JOB_ID`:
+
+```bash
+SLURM_JOB_ID=$(NODES=4 TIME=02:00:00 ./slurm_alloc.sh)
+```
+
+Before running it, please fill in the variables in `./slurm_alloc.sh` (for example `ACCOUNT` and `PARTITION`).
+
+**Step 2:** Start a container and install `tensorrt_llm`. Run the following command on the controller node:
+
+```bash
+SLURM_JOB_ID=$SLURM_JOB_ID ./slurm_init_containers.sh
+```
+
+It uses the image recorded in `../../jenkins/current_image_tags.properties`. The image is downloaded to `../../enroot/` only once and reused on subsequent runs.
+
+**Step 3:** Run benchmarks to generate profiles. Run the following commands on the controller node, where `NODES` must not exceed the number of allocated nodes:
+
+```bash
+# Run DeepSeek-R1 with wide EP: uses MNNVL A2A if applicable
+SLURM_JOB_ID=$SLURM_JOB_ID NODES=4 NP=16 ./slurm_launch.sh ./run_single.sh config_gen.yaml --moe-backend WIDEEP
+
+# Run with attention TP and TRTLLMGen
+SLURM_JOB_ID=$SLURM_JOB_ID NODES=4 NP=16 TRTLLM_ENABLE_PDL=1 ./slurm_launch.sh ./run_single.sh config_gen.yaml --no-enable-attention-dp --moe-backend TRTLLM
+
+# Run with DeepEPLowLatency
+SLURM_JOB_ID=$SLURM_JOB_ID NODES=4 NP=16 TRTLLM_FORCE_ALLTOALL_METHOD=DeepEPLowLatency ./slurm_launch.sh ./run_single.sh config_gen.yaml --moe-backend WIDEEP
+
+# You can run 4-GPU and 8-GPU tasks without reallocating the Slurm job
+SLURM_JOB_ID=$SLURM_JOB_ID NODES=1 NP=4 ./slurm_launch.sh ./run_single.sh config_ctx.yaml
+SLURM_JOB_ID=$SLURM_JOB_ID NODES=2 NP=8 ./slurm_launch.sh ./run_single.sh config_ctx.yaml
+```
+
+## Parse profiles
+
+Coming soon.
diff --git a/examples/layer_wise_benchmarks/config_ctx.yaml b/examples/layer_wise_benchmarks/config_ctx.yaml
new file mode 100644
index 00000000000..13a637e1624
--- /dev/null
+++ b/examples/layer_wise_benchmarks/config_ctx.yaml
@@ -0,0 +1,21 @@
+model: nvidia/DeepSeek-R1-0528-FP4-v2
+layer_indices: [5]
+run_type: CTX
+scaled_from: null
+
+# KV cache related args
+tokens_per_block: 32
+max_seq_len: 9220 # 8192 + 1024 + 4
+enable_attention_dp: true
+
+# Model init args
+max_num_tokens: 20480
+moe_backend: CUTLASS
+use_cuda_graph: false
+
+# Per iteration args
+batch_size: 1
+seq_len_q: 8193
+seq_len_kv_cache: 0
+balance_method: Balanced
+balance_ratio: null
diff --git a/examples/layer_wise_benchmarks/config_gen.yaml b/examples/layer_wise_benchmarks/config_gen.yaml
new file mode 100644
index 00000000000..9ad86f8e594
--- /dev/null
+++ b/examples/layer_wise_benchmarks/config_gen.yaml
@@ -0,0 +1,21 @@
+model: nvidia/DeepSeek-R1-0528-FP4-v2
+layer_indices: [5]
+run_type: GEN
+scaled_from: null
+
+# KV cache related args
+tokens_per_block: 32
+max_seq_len: 9220 # 8192 + 1024 + 4
+enable_attention_dp: true
+
+# Model init args
+max_num_tokens: 4096 # MTP3 as max
+moe_backend: CUTLASS
+use_cuda_graph: true
+
+# Per iteration args
+batch_size: 128
+seq_len_q: 1 # Set to (1 + MTP)
+seq_len_kv_cache: 8193
+balance_method: Balanced
+balance_ratio: null
diff --git a/examples/layer_wise_benchmarks/mpi_launch.sh b/examples/layer_wise_benchmarks/mpi_launch.sh
new file mode 100755
index 00000000000..24f7643bebe
--- /dev/null
+++ b/examples/layer_wise_benchmarks/mpi_launch.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+
+set -euo pipefail
+
+# Clear Slurm and MPI related envs
+unset $(env | grep -i slurm | awk -F'=' '{print $1}')
+unset $(env | grep MPI | awk -F'=' '{print $1}')
+
+set -x
+mpirun --allow-run-as-root --np ${NP} "$@"
diff --git a/examples/layer_wise_benchmarks/run_single.py b/examples/layer_wise_benchmarks/run_single.py
new file mode 100644
index 00000000000..79d4bbe5019
--- /dev/null
+++ b/examples/layer_wise_benchmarks/run_single.py @@ -0,0 +1,159 @@ +import argparse + +import numpy as np +import nvtx +import torch +import yaml + +from tensorrt_llm._torch.autotuner import AutoTuner, autotune +from tensorrt_llm._torch.modules.multi_stream_utils import with_multi_stream +from tensorrt_llm._utils import local_mpi_rank, mpi_rank, mpi_world_size +from tensorrt_llm.tools.layer_wise_benchmarks.deepseekv3_runner import ( + BalanceMethod, DeepSeekV3Runner) + + +def comma_separated_ints(s): + return [int(x) for x in s.split(",")] + + +# Parse cmdline +parser = argparse.ArgumentParser() +parser.add_argument("config_path", type=str) +parser.add_argument("--model", type=str, help="Pretrained model name or path") +parser.add_argument( + "--layer-indices", + type=comma_separated_ints, + help="Comma separated indices of layers, should be a contiguous range") +parser.add_argument("--run-type", type=str, choices=["CTX", "GEN"]) +parser.add_argument("--scaled-from", type=int) +# KV cache related args +parser.add_argument("--tokens-per-block", type=int) +parser.add_argument("--max-seq-len", type=int) +group = parser.add_mutually_exclusive_group(required=False) +group.add_argument("--enable-attention-dp", + action="store_true", + dest="enable_attention_dp") +group.add_argument("--no-enable-attention-dp", + action="store_false", + dest="enable_attention_dp") +parser.set_defaults(enable_attention_dp=None) +# Model init args +parser.add_argument("--max-num-tokens", type=int) +parser.add_argument("--moe-backend", type=str) +group = parser.add_mutually_exclusive_group(required=False) +group.add_argument("--use-cuda-graph", + action="store_true", + dest="use_cuda_graph") +group.add_argument("--no-use-cuda-graph", + action="store_false", + dest="use_cuda_graph") +parser.set_defaults(use_cuda_graph=None) +# Per iteration args +parser.add_argument("--batch-size", type=int) +parser.add_argument("--seq-len-q", type=int) +parser.add_argument("--seq-len-kv-cache", type=int) +parser.add_argument("--balance-method", type=str) +parser.add_argument("--balance-ratio", type=float) +args = parser.parse_args() +with open(args.config_path) as f: + config = yaml.safe_load(f) +del args.config_path +for k, v in vars(args).items(): + if v is None: + setattr(args, k, config[k]) +print(args) + +# MPI args +rank = mpi_rank() +world_size = mpi_world_size() +local_rank = local_mpi_rank() +torch.cuda.set_device(local_rank) + +# Create KV cache manager +mapping = DeepSeekV3Runner.create_mapping( + enable_attention_dp=args.enable_attention_dp) +max_batch_size = 2048 +kv_cache_manager = DeepSeekV3Runner.create_kv_cache_manager( + args.model, + mapping, + tokens_per_block=args.tokens_per_block, + max_batch_size=max_batch_size, + max_seq_len=args.max_seq_len, + layer_indices=args.layer_indices) +attn_workspace = torch.empty((0, ), device="cuda", dtype=torch.int8) + +# Create other global objects +AutoTuner.get().clear_cache() +capture_stream = torch.cuda.Stream() + +# Create Runner +runner = DeepSeekV3Runner(args.model, + mapping, + moe_backend=args.moe_backend, + layer_indices=args.layer_indices, + scaled_from=args.scaled_from, + max_seq_len=args.max_seq_len, + max_num_tokens=args.max_num_tokens, + use_cuda_graph=args.use_cuda_graph) + +# Warm up +assert args.batch_size <= max_batch_size +assert args.seq_len_q + args.seq_len_kv_cache <= args.max_seq_len +run_pack = runner.create_run_pack(args.run_type, + batch_size=args.batch_size, + seq_len_q=args.seq_len_q, + seq_len_kv_cache=args.seq_len_kv_cache, + 
kv_cache_manager=kv_cache_manager, + attn_workspace=attn_workspace) +runner.replace_routing_method(balance_method=BalanceMethod[args.balance_method], + balance_ratio=args.balance_ratio) +capture_stream.wait_stream(torch.cuda.current_stream()) +with torch.cuda.stream(capture_stream): + run_pack() + with autotune(): + run_pack() +torch.cuda.current_stream().wait_stream(capture_stream) +torch.cuda.synchronize() + +# Profile: capture graph and replay it +torch.cuda.cudart().cudaProfilerStart() +if args.use_cuda_graph: + with with_multi_stream(True): + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g, + stream=capture_stream, + capture_error_mode="global"): + run_pack() + +warmup_times = 20 +run_times = 100 +events = [ + torch.cuda.Event(enable_timing=True) + for _ in range(warmup_times + run_times + 1) +] +for i in range(warmup_times + run_times): + events[i].record() + with nvtx.annotate( + f"b={args.batch_size} s={args.seq_len_q} EP{world_size}"): + if args.use_cuda_graph: + g.replay() + else: + run_pack() +events[-1].record() +torch.cuda.synchronize() + +# Print statistics +# Print before `cudaProfilerStop` to ensure messages are included in the profile +time_list = [ + start.elapsed_time(stop) for start, stop in zip(events, events[1:]) +] +time_list = time_list[warmup_times:] +print(f"[RANK {rank}]" + f" min {np.min(time_list) * 1000:.1f}" + f" max {np.max(time_list) * 1000:.1f}" + f" mean {np.mean(time_list) * 1000:.1f}" + f" median {np.median(time_list) * 1000:.1f}" + f" P90 {np.percentile(time_list, 90) * 1000:.1f}" + f" (us)") + +torch.cuda.cudart().cudaProfilerStop() diff --git a/examples/layer_wise_benchmarks/run_single.sh b/examples/layer_wise_benchmarks/run_single.sh new file mode 100755 index 00000000000..be9aa6e5a4d --- /dev/null +++ b/examples/layer_wise_benchmarks/run_single.sh @@ -0,0 +1,37 @@ +#!/bin/bash + +set -euo pipefail + +if [ -v OMPI_COMM_WORLD_SIZE ]; then + export WORLD_SIZE=$OMPI_COMM_WORLD_SIZE + export RANK=$OMPI_COMM_WORLD_RANK + export LOCAL_RANK=$OMPI_COMM_WORLD_LOCAL_RANK + export NODE_RANK=$OMPI_COMM_WORLD_NODE_RANK +fi + +if [ "$RANK" -eq 0 ]; then + export TLLM_LOG_LEVEL=INFO +fi + +PROFILE=${PROFILE:-1} +GPU_METRICS=${GPU_METRICS:-0} +if [ "$PROFILE" -eq 1 ]; then + PROFILE_FOLDER=profiles/run_single + mkdir -p ${PROFILE_FOLDER} + PROFILE_CMD="nsys profile + -t cuda,nvtx -s none + --cpuctxsw none --cuda-event-trace false + --cuda-graph-trace node + -c cudaProfilerApi --capture-range-end stop + -o ${PROFILE_FOLDER}/run_single_ep${WORLD_SIZE}_rank${RANK}.nsys-rep + --force-overwrite true" + if [ "$GPU_METRICS" -eq 1 ]; then + PROFILE_CMD+=" --gpu-metrics-devices $LOCAL_RANK + --gpu-metrics-frequency 10000" + fi +else + PROFILE_CMD= +fi + +set -x +$PROFILE_CMD python3 -u run_single.py "$@" diff --git a/examples/layer_wise_benchmarks/slurm_alloc.sh b/examples/layer_wise_benchmarks/slurm_alloc.sh new file mode 100755 index 00000000000..cb25f57fc34 --- /dev/null +++ b/examples/layer_wise_benchmarks/slurm_alloc.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +set -euo pipefail + +# ACCOUNT= +# PARTITION= +# EXTRA_ARGS="--gres gpu:4" +TIME=${TIME:-01:00:00} + +set -x +salloc -A "$ACCOUNT" \ + -p "$PARTITION" \ + -N "$NODES" \ + --segment "$NODES" \ + $EXTRA_ARGS \ + -t "$TIME" \ + --no-shell \ + 2>&1 \ + | tee >(cat >&2) \ + | awk '/Granted job allocation/ {print $NF}' diff --git a/examples/layer_wise_benchmarks/slurm_init_containers.sh b/examples/layer_wise_benchmarks/slurm_init_containers.sh new file mode 100755 index 00000000000..06a8d818e21 --- /dev/null +++ 
b/examples/layer_wise_benchmarks/slurm_init_containers.sh @@ -0,0 +1,53 @@ +#!/bin/bash + +set -euo pipefail + +# CONTAINER_IMAGE= +CONTAINER_MOUNTS=$(realpath "$(pwd)/../.."):$(realpath "$(pwd)/../..") + +if [ "${SLURM_JOB_ID:-}" == "" ]; then + echo "Please set SLURM_JOB_ID" + exit 1 +fi + +NODES=$(squeue -j $SLURM_JOB_ID -h -o "%D") + +if [ "${CONTAINER_IMAGE:-}" == "" ]; then + # Read Docker image from current_image_tags.properties + source ../../jenkins/current_image_tags.properties + MACHINE="$(uname -m)" + if [ "$MACHINE" == "x86_64" ]; then + DOCKER_IMAGE=$LLM_DOCKER_IMAGE + elif [ "$MACHINE" == "aarch64" ]; then + DOCKER_IMAGE=$LLM_SBSA_DOCKER_IMAGE + else + echo "Unsupported machine hardware name \"$MACHINE\"" + fi + + # Change "urm.nvidia.com/sw-tensorrt-docker/..." to "urm.nvidia.com#sw-tensorrt-docker/..." to bypass credentials + DOCKER_IMAGE="${DOCKER_IMAGE/\//#}" + echo "CONTAINER_IMAGE was not set, using Docker image $DOCKER_IMAGE" + + # Import to .sqsh file + SQSH_FILE_NAME=$(echo "$DOCKER_IMAGE" | + awk -F'#' '{print $2}' | + awk -F':' '{gsub(/\//,"+",$1); print $1"+"$2".sqsh"}') + CONTAINER_IMAGE="../../enroot/$SQSH_FILE_NAME" + if [ ! -f "$CONTAINER_IMAGE" ]; then + echo "Container image file $CONTAINER_IMAGE does not exist, importing ..." + srun -N 1 --pty enroot import -o "$CONTAINER_IMAGE" "docker://$DOCKER_IMAGE" + fi +fi + +WORKDIR=$(realpath "$(pwd)") + +set -x +srun -N "$NODES" \ + --ntasks-per-node 1 \ + --container-image "$CONTAINER_IMAGE" \ + --container-name "layer_wise_benchmarks" \ + --container-mounts "$CONTAINER_MOUNTS" \ + --container-workdir "$WORKDIR" \ +bash -c "pip install -U packaging && + pip install -r ../../requirements.txt --no-build-isolation && + pip install -e ../.." diff --git a/examples/layer_wise_benchmarks/slurm_launch.sh b/examples/layer_wise_benchmarks/slurm_launch.sh new file mode 100755 index 00000000000..36fe0a9a00f --- /dev/null +++ b/examples/layer_wise_benchmarks/slurm_launch.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +set -euo pipefail + +CONTAINER_NAME=${CONTAINER_NAME:-layer_wise_benchmarks} +CONTAINER_MOUNTS=$(realpath "$(pwd)/../.."):$(realpath "$(pwd)/../..") + +if [ "${SLURM_JOB_ID:-}" == "" ]; then + echo "Please set SLURM_JOB_ID" + exit 1 +fi + +WORKDIR=$(realpath "$(pwd)") + +set -x +srun --mpi=pmix \ + -N "$NODES" \ + --ntasks-per-node $(($NP / $NODES)) \ + --container-name "$CONTAINER_NAME" \ + --container-mounts "$CONTAINER_MOUNTS" \ + --container-workdir "$WORKDIR" \ + "$@" diff --git a/tensorrt_llm/tools/layer_wise_benchmarks/__init__.py b/tensorrt_llm/tools/layer_wise_benchmarks/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tensorrt_llm/tools/layer_wise_benchmarks/deepseekv3_runner.py b/tensorrt_llm/tools/layer_wise_benchmarks/deepseekv3_runner.py new file mode 100644 index 00000000000..c3fa4ebd054 --- /dev/null +++ b/tensorrt_llm/tools/layer_wise_benchmarks/deepseekv3_runner.py @@ -0,0 +1,413 @@ +import functools +import os +import weakref +from enum import IntEnum +from typing import List, Optional + +import torch + +import tensorrt_llm._torch.models.modeling_deepseekv3 +from tensorrt_llm._torch.attention_backend.utils import get_attention_backend +from tensorrt_llm._torch.metadata import KVCacheParams +from tensorrt_llm._torch.model_config import ModelConfig +from tensorrt_llm._torch.models.modeling_deepseekv3 import ( + DeepseekV3DecoderLayer, DeepseekV3Gate) +from tensorrt_llm._torch.modules.fused_moe.fused_moe_wide_ep import WideEPMoE +from tensorrt_llm._torch.modules.linear import 
Linear, WeightMode +from tensorrt_llm._torch.modules.rms_norm import RMSNorm +from tensorrt_llm._torch.pyexecutor._util import get_kv_cache_manager_cls +from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager +from tensorrt_llm._torch.utils import (AuxStreamType, get_model_extra_attrs, + model_extra_attrs) +from tensorrt_llm._utils import (local_mpi_size, mpi_rank, mpi_world_size, + torch_dtype_to_binding) +from tensorrt_llm.bindings.executor import KvCacheConfig +from tensorrt_llm.functional import AllReduceStrategy +from tensorrt_llm.mapping import Mapping +from tensorrt_llm.models.modeling_utils import QuantConfig + + +class BalanceMethod(IntEnum): + NotModified = 1 + Balanced = 2 + ImbalancedRanks = 3 + ImbalancedExperts = 4 + + +def ceil_div(a, b): + return (a + b - 1) // b + + +def round_up(a, b): + return ceil_div(a, b) * b + + +class RoutingMethod(DeepseekV3Gate): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.world_size = mpi_world_size() + self.rank = mpi_rank() + self.balance_method = None + self.balance_ratio = None + + def apply(self, router_logits) -> (torch.Tensor, torch.Tensor): + token_selected_experts, token_final_scales = super().apply( + router_logits) + num_experts = self.weight.shape[0] + if self.balance_method == BalanceMethod.NotModified: + pass + elif self.balance_method == BalanceMethod.Balanced: + token_selected_experts = RoutingMethod.get_balanced_selection( + token_selected_experts.shape[0], + token_selected_experts.shape[1], num_experts, + token_selected_experts.dtype, self.world_size, self.rank) + elif self.balance_method == BalanceMethod.ImbalancedRanks: + token_selected_experts = RoutingMethod.get_all_to_one_selection( + token_selected_experts.shape[0], + token_selected_experts.shape[1], num_experts, + self.balance_ratio, token_selected_experts.dtype, + self.world_size, self.rank) + elif self.balance_method == BalanceMethod.ImbalancedExperts: + token_selected_experts = RoutingMethod.get_balanced_rank_imbalanced_expert_selection( + token_selected_experts.shape[0], + token_selected_experts.shape[1], num_experts, + self.balance_ratio, token_selected_experts.dtype, + self.world_size, self.rank) + else: + raise NotImplementedError( + f"Not support balance_method {self.balance_method}") + return token_selected_experts, token_final_scales + + @functools.cache + @staticmethod + def get_balanced_selection(num_tokens, top_k, num_experts, dtype, + world_size, rank): + a = torch.arange(num_tokens * world_size * top_k, + dtype=dtype, + device="cuda").view(num_tokens, world_size, + top_k)[:, rank] + experts = (a * (num_experts // world_size + 1) + a // num_experts * + (num_experts // world_size)) % num_experts + return experts.contiguous() + + @staticmethod + def apply_balance_ratio(imbalanced_experts, num_experts, balance_ratio, + world_size, rank): + num_tokens, top_k = imbalanced_experts.shape + dtype = imbalanced_experts.dtype + balanced_experts = RoutingMethod.get_balanced_selection( + num_tokens, top_k, num_experts, dtype, world_size, rank) + num_balanced_tokens = round(num_tokens * balance_ratio) + if balance_ratio != 0: + # Activate all experts + num_balanced_tokens = max(num_balanced_tokens, + ceil_div(num_experts, world_size * top_k)) + mixed_experts = balanced_experts.clone() + mixed_experts[num_balanced_tokens:] = imbalanced_experts[ + num_balanced_tokens:] + return mixed_experts + + @functools.cache + @staticmethod + def get_all_to_one_selection(num_tokens, top_k, num_experts, balance_ratio, + dtype, 
world_size, rank): + assert num_experts // world_size >= top_k + imbalanced_experts = torch.arange( + num_tokens * top_k, dtype=dtype, device="cuda").view( + num_tokens, top_k) % (num_experts // world_size) + return RoutingMethod.apply_balance_ratio(imbalanced_experts, + num_experts, balance_ratio, + world_size, rank) + + @functools.cache + @staticmethod + def get_balanced_rank_imbalanced_expert_selection(num_tokens, top_k, + num_experts, + balance_ratio, dtype, + world_size, rank): + experts_per_rank = num_experts // world_size + activate_experts_per_rank = ceil_div(top_k, world_size) + a = torch.arange(num_tokens * top_k, dtype=dtype, + device="cuda").view(num_tokens, top_k) + narrow_experts = a % (activate_experts_per_rank * world_size) + imbalanced_experts = narrow_experts * experts_per_rank % num_experts + narrow_experts // world_size % experts_per_rank + return RoutingMethod.apply_balance_ratio(imbalanced_experts, + num_experts, balance_ratio, + world_size, rank) + + +class DeepSeekV3Runner: + + def __init__(self, pretrained_model_name_or_path: str, mapping: Mapping, *, + moe_backend: str, layer_indices: List[int], + scaled_from: Optional[int], max_seq_len: int, + max_num_tokens: int, use_cuda_graph: bool): + + # Temporally replace the gate class + gate_cls_orig = tensorrt_llm._torch.models.modeling_deepseekv3.DeepseekV3Gate + tensorrt_llm._torch.models.modeling_deepseekv3.DeepseekV3Gate = RoutingMethod + + self.model_config = ModelConfig.from_pretrained( + pretrained_model_name_or_path, + mapping=mapping, + enable_min_latency=False, + use_cuda_graph=use_cuda_graph, + force_dynamic_quantization=False, + spec_config=None, + sparse_attention_config=None, # To be loaded from config + max_num_tokens=max_num_tokens, + max_seq_len=max_seq_len, + moe_max_num_tokens=None, + moe_load_balancer=None, + lora_config=None, + allreduce_strategy=AllReduceStrategy.AUTO, + mm_encoder_only=False, + attn_backend="TRTLLM", + moe_backend=moe_backend, + moe_disable_finalize_fusion=False, + use_low_precision_moe_combine=False, + skip_create_weights_in_init=True, + ) + + pretrained_config = self.model_config.pretrained_config + if scaled_from is not None: + # To run the problem size of $B$ GPUs on $A$ GPUs, we need: + # (1) Attention: If TP, reduce the number of attention heads; If DP, nothing to change. + # (2) MoE: If EP, reduce the number of experts; If TP, reduce head size. + # Maintain the result of AllToAll method selection because it is affected by EP size. 
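+            # Worked example (illustrative numbers, not read from a checkpoint):
+            # scaling a 16-GPU problem down to world_size = 4 with 256 routed
+            # experts and 128 attention heads keeps per-GPU shard sizes unchanged:
+            #   n_routed_experts:    256 // 16 * 4 = 64  -> still 16 experts per rank
+            #   num_attention_heads: 128 // 16 * 4 = 32  -> still 8 heads per rank (attention TP)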
+ if not mapping.enable_attention_dp: + if hasattr(pretrained_config, "index_n_heads"): + raise NotImplementedError( + "Not support Indexer TP for weak scaling") + pretrained_config.num_attention_heads = pretrained_config.num_attention_heads // scaled_from * mapping.tp_size + pretrained_config.num_key_value_heads = pretrained_config.num_key_value_heads // scaled_from * mapping.tp_size + if mapping.moe_ep_size != mapping.world_size: + raise NotImplementedError("Not support MoE TP for weak scaling") + pretrained_config.n_routed_experts = pretrained_config.n_routed_experts // scaled_from * mapping.moe_ep_size + select_alltoall_method_type_orig = WideEPMoE.select_alltoall_method_type + + def select_alltoall_method_type(cls: type, mapping: Mapping, + top_k: int, *args, **kwargs): + # Replace the condition `mapping.moe_ep_size <= top_k` with `scaled_from <= top_k` + # by replacing `top_k` with `fake_top_k` + if scaled_from <= top_k: + fake_top_k = mapping.moe_ep_size + 1 + else: + fake_top_k = mapping.moe_ep_size - 1 + assert (mapping.moe_ep_size <= fake_top_k) == (scaled_from + <= top_k) + return select_alltoall_method_type_orig(mapping, fake_top_k, + *args, **kwargs) + + WideEPMoE.select_alltoall_method_type = select_alltoall_method_type + + aux_stream_list = [torch.cuda.Stream() for _ in range(2)] + aux_stream_dict = { + AuxStreamType.Attention: aux_stream_list[0], + AuxStreamType.MoeShared: aux_stream_list[0], + AuxStreamType.MoeChunkingOverlap: aux_stream_list[1], + } + + layers = [ + DeepseekV3DecoderLayer( + model_config=self.model_config, + layer_idx=layer_idx, + aux_stream_dict=aux_stream_dict, + ) for layer_idx in layer_indices + ] + next_layer_layernorm = RMSNorm( + hidden_size=pretrained_config.hidden_size, + eps=pretrained_config.rms_norm_eps, + dtype=pretrained_config.torch_dtype) + + # apply_quant_config_exclude_modules + # Please refer to tensorrt_llm/_torch/models/modeling_utils.py + quant_config = self.model_config.quant_config + new_quant_config = QuantConfig( + kv_cache_quant_algo=quant_config.kv_cache_quant_algo) + for layer in layers: + for name, module in layer.named_modules(): + name = f"model.layers.{layer.layer_idx}.{name}" + candidates = [name] + if isinstance(module, Linear): + weight_mode = module.weights_loading_config.weight_mode + if weight_mode == WeightMode.FUSED_GATE_UP_LINEAR: + # sometimes gate and up proj are not packed in the checkpoint, + # but they still share the same exclusion rule + candidates += [ + name.replace('gate_up_proj', 'gate_proj'), + name.replace('gate_up_proj', 'up_proj') + ] + elif weight_mode == WeightMode.FUSED_QKV_LINEAR: + # sometimes q_proj, k_proj and v_proj are not packed in the checkpoint, + # but they still share the same exclusion rule + candidates += [ + name.replace('qkv_proj', 'q_proj'), + name.replace('qkv_proj', 'k_proj'), + name.replace('qkv_proj', 'v_proj') + ] + is_excluded = any( + quant_config.is_module_excluded_from_quantization(n) + for n in candidates) + if is_excluded and getattr(module, "quant_config", + None) is not None: + module.quant_config = new_quant_config + for name, module in layer.named_modules(): + if callable(getattr(module, "create_weights", None)): + module.create_weights() + layer.cuda() + for name, module in layer.named_modules(): + if hasattr(module, 'post_load_weights') and not getattr( + module, '_weights_removed', False): + module.post_load_weights() + next_layer_layernorm.cuda() + for layer, next_layer in zip(layers[:-1], layers[1:]): + layer.next_layer_layernorm = next_layer.input_layernorm + 
layers[-1].next_layer_layernorm = next_layer_layernorm + + self.layers = layers + if scaled_from is not None: + WideEPMoE.select_alltoall_method_type = select_alltoall_method_type_orig + tensorrt_llm._torch.models.modeling_deepseekv3.DeepseekV3Gate = gate_cls_orig + + def create_run_pack(self, + run_type: str, + batch_size: int, + seq_len_q: int, + seq_len_kv_cache: int, + kv_cache_manager: Optional[KVCacheManager] = None, + attn_workspace: Optional[torch.Tensor] = None): + if self.model_config.moe_backend == "TRTLLM" and os.getenv( + "TRTLLM_ENABLE_PDL") != "1": + raise ValueError( + "Suggest to set TRTLLM_ENABLE_PDL=1 when moe_backend is TRTLLM") + world_size = mpi_world_size() + AttentionCls = get_attention_backend( + self.model_config.attn_backend, + self.model_config.sparse_attention_config) + attn_metadata = AttentionCls.Metadata( + seq_lens=torch.tensor([seq_len_q] * batch_size, dtype=torch.int), + request_ids=list(range(batch_size)), + max_num_requests=kv_cache_manager.max_batch_size, + num_contexts={ + "CTX": batch_size, + "GEN": 0 + }[run_type], + prompt_lens=[{ + "CTX": seq_len_q, + "GEN": seq_len_kv_cache + }[run_type]] * batch_size, + max_num_tokens=batch_size * seq_len_q, + kv_cache_manager=kv_cache_manager, + kv_cache_params=KVCacheParams( + use_cache=True, + num_cached_tokens_per_seq=[seq_len_kv_cache] * batch_size, + ), + workspace=attn_workspace, + mapping=self.model_config.mapping, + sparse_attention_config=self.model_config.sparse_attention_config, + ) + attn_metadata.all_rank_num_tokens = [batch_size * seq_len_q + ] * world_size + attn_metadata.prepare() + with model_extra_attrs(self.model_config.extra_attrs): + get_model_extra_attrs()["attention_metadata"] = weakref.ref( + attn_metadata) + hidden_size = self.model_config.pretrained_config.hidden_size + position_ids = torch.tensor([ + list(range(seq_len_kv_cache, seq_len_kv_cache + seq_len_q)) * + batch_size + ], + dtype=torch.int32, + device="cuda") + hidden_states = torch.rand((batch_size * seq_len_q, hidden_size), + dtype=torch.bfloat16, + device="cuda") + residual = torch.rand((batch_size * seq_len_q, hidden_size), + dtype=torch.bfloat16, + device="cuda") + + def run_pack(): + output = hidden_states, residual + with model_extra_attrs(self.model_config.extra_attrs): + with torch.inference_mode(): + for layer in self.layers: + output = layer(position_ids, output[0], attn_metadata, + output[1]) + return output + + return run_pack + + def replace_routing_method(self, balance_method: BalanceMethod, + balance_ratio: float): + if self.model_config.moe_backend not in [ + "CUTLASS", "DEEPGEMM", "TRTLLM", "WIDEEP" + ]: + raise NotImplementedError( + f"Not support replace routing method for moe_backend \"{self.model_config.moe_backend}\"," + f" please set balance_method to \"NotModified\"") + for layer in self.layers: + layer.mlp.gate.balance_method = balance_method + layer.mlp.gate.balance_ratio = balance_ratio + + @staticmethod + def create_kv_cache_manager(pretrained_model_name_or_path, mapping, + tokens_per_block, max_batch_size, max_seq_len, + layer_indices): + # Please refer to `tensorrt_llm/_torch/pyexecutor/py_executor_creator.py` for `tokens_per_block` + model_config = ModelConfig.from_pretrained( + pretrained_model_name_or_path) + if model_config.enable_flash_mla: + assert tokens_per_block == 64 + + # Please refer to `tensorrt_llm/_torch/pyexecutor/_util.py` for `kv_cache_manager` + kv_cache_manager_cls = get_kv_cache_manager_cls(model_config) + kv_cache_manager = kv_cache_manager_cls( + KvCacheConfig( + 
max_tokens=max_batch_size * + round_up(max_seq_len, tokens_per_block), + enable_block_reuse=False, + ), + tensorrt_llm.bindings.internal.batch_manager.CacheType.SELFKONLY, + num_layers=len(layer_indices), + num_kv_heads=1, + head_dim=model_config.pretrained_config.kv_lora_rank + + model_config.pretrained_config.qk_rope_head_dim, + tokens_per_block=tokens_per_block, + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + mapping=mapping, + dtype=torch_dtype_to_binding({ + None: torch.bfloat16, + "FP8": torch.float8_e4m3fn + }[model_config.quant_config.kv_cache_quant_algo]), + sparse_attn_config=model_config.sparse_attention_config, + ) + kv_cache_manager.layer_offsets = { + layer_idx: i + for i, layer_idx in enumerate(layer_indices) + } + kv_cache_manager.add_dummy_requests(list(range(max_batch_size)), + [max_seq_len] * max_batch_size) + return kv_cache_manager + + @staticmethod + def create_mapping(enable_attention_dp: bool): + world_size = mpi_world_size() + rank = mpi_rank() + mapping = Mapping( + world_size=world_size, + rank=rank, + gpus_per_node=local_mpi_size(), + cp_size=1, + tp_size=world_size, + pp_size=1, + moe_cluster_size=1, + moe_tp_size=1, + moe_ep_size=world_size, + attn_tp_size=world_size, + attn_cp_size=1, + enable_attention_dp=enable_attention_dp, + ) + return mapping diff --git a/tests/integration/test_lists/test-db/l0_dgx_b200.yml b/tests/integration/test_lists/test-db/l0_dgx_b200.yml index 0e216d4acce..5b11740ad4c 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_b200.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_b200.yml @@ -18,6 +18,8 @@ l0_dgx_b200: - unittest/_torch/multi_gpu_modeling -k "deepseek" - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEPLowLatency] - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[MNNVL] + - unittest/tools/test_layer_wise_benchmarks.py::test_deepseek_r1_ctx_tep + - unittest/tools/test_layer_wise_benchmarks.py::test_deepseek_r1_gen_scaled_from_16_dep - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[pp4-attn_backend=TRTLLM-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False] diff --git a/tests/unittest/tools/test_layer_wise_benchmarks.py b/tests/unittest/tools/test_layer_wise_benchmarks.py new file mode 100644 index 00000000000..ee282cf614a --- /dev/null +++ b/tests/unittest/tools/test_layer_wise_benchmarks.py @@ -0,0 +1,74 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os + +import pytest +import torch +from defs.trt_test_alternative import check_call +from utils.cpp_paths import llm_root # noqa: F401 + + +@pytest.mark.skipif(torch.cuda.device_count() < 4, + reason="needs 4 GPUs to run this test") +def test_deepseek_r1_ctx_tep(llm_root): + check_call([ + "./mpi_launch.sh", + "./run_single.sh", + "config_ctx.yaml", + "--no-enable-attention-dp", + "--moe-backend=TRTLLM", + ], + cwd=llm_root / "examples" / "layer_wise_benchmarks", + env={ + **os.environ, + "NP": "4", + "TRTLLM_ENABLE_PDL": "1", + }) + + +@pytest.mark.skipif(torch.cuda.device_count() < 4, + reason="needs 4 GPUs to run this test") +def test_deepseek_v32_ctx_dep(llm_root): + check_call([ + "./mpi_launch.sh", + "./run_single.sh", + "config_ctx.yaml", + "--model=deepseek-ai/DeepSeek-V3.2-Exp", + "--tokens-per-block=64", + "--moe-backend=DEEPGEMM", + ], + cwd=llm_root / "examples" / "layer_wise_benchmarks", + env={ + **os.environ, + "NP": "4", + }) + + +@pytest.mark.skipif(torch.cuda.device_count() < 4, + reason="needs 4 GPUs to run this test") +def test_deepseek_r1_gen_scaled_from_16_dep(llm_root): + check_call([ + "./mpi_launch.sh", + "./run_single.sh", + "config_gen.yaml", + "--scaled-from=16", + "--moe-backend=WIDEEP", + "--layer-indices=5,6", + ], + cwd=llm_root / "examples" / "layer_wise_benchmarks", + env={ + **os.environ, + "NP": "4", + })
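
For reference, the two CI entries added to `l0_dgx_b200.yml` above correspond to the following direct invocations of the example scripts (a sketch assuming a 4-GPU node and the container environment described in the README):

```bash
cd examples/layer_wise_benchmarks
# test_deepseek_r1_ctx_tep
NP=4 TRTLLM_ENABLE_PDL=1 ./mpi_launch.sh ./run_single.sh config_ctx.yaml --no-enable-attention-dp --moe-backend=TRTLLM
# test_deepseek_r1_gen_scaled_from_16_dep
NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --scaled-from=16 --moe-backend=WIDEEP --layer-indices=5,6
```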