diff --git a/examples/layer_wise_benchmarks/README.md b/examples/layer_wise_benchmarks/README.md
index af2663c9aa2..426dee02fd8 100644
--- a/examples/layer_wise_benchmarks/README.md
+++ b/examples/layer_wise_benchmarks/README.md
@@ -15,28 +15,29 @@ pip install -e ../..
 **Step 3:** In the container, run benchmarks and generate profiles:
 
 ```bash
-# Set autotune cache path
-export TLLM_AUTOTUNER_CACHE_PATH=autotuner_cache/cache
-
 # Run DeepSeek-R1 NVFP4
 NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml
 NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml
 
+# Run with weights loaded. Requires a local model directory
+NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml --model "$LLM_MODELS_ROOT/DeepSeek-R1/DeepSeek-R1-0528-FP4-v2" --load-format AUTO
+NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --model "$LLM_MODELS_ROOT/DeepSeek-R1/DeepSeek-R1-0528-FP4-v2" --load-format AUTO
+
 # Run DeepSeek-V3.2-Exp
 NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM
-NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM
+NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM --moe-backend-for-prefill DEEPGEMM
 
 # Run DeepSeek-V3.2-Exp with 32k context length
 NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM --batch-size 1 --seq-len-q 32769
-NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM --seq-len-kv-cache 32769
+NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM --moe-backend-for-prefill DEEPGEMM --seq-len-kv-cache 32769
 
 # Run with attention TP
 NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml --no-enable-attention-dp
 NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --no-enable-attention-dp
 
 # Run with attention TP and TRTLLMGen
-NP=4 ./mpi_launch.sh -x TRTLLM_ENABLE_PDL=1 ./run.sh config_ctx.yaml --no-enable-attention-dp --moe-backend TRTLLM
-NP=4 ./mpi_launch.sh -x TRTLLM_ENABLE_PDL=1 ./run.sh config_gen.yaml --no-enable-attention-dp --moe-backend TRTLLM
+NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml --no-enable-attention-dp --moe-backend TRTLLM
+NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --no-enable-attention-dp --moe-backend TRTLLM
 
 # Run with MTP3
 NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --batch-size 32 --seq-len-q 4
@@ -51,9 +52,13 @@ NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --scaled-from 16 --moe-backend WID
 # Scale TEP=16 to 4 GPUs: reduce the number of attention heads and experts
 NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --scaled-from 16 --no-enable-attention-dp
 
+# Run Nemotron-3-Nano
+NP=1 ./mpi_launch.sh ./run.sh config_ctx.yaml --model nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 --layer-indices 4,5,6 --mamba-ssm-cache-dtype float16
+NP=1 ./mpi_launch.sh ./run.sh config_gen.yaml --model nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 --layer-indices 4,5,6 --mamba-ssm-cache-dtype float16
+
 # Run Qwen3-Next
-NP=2 ./mpi_launch.sh ./run.sh config_ctx.yaml --model Qwen/Qwen3-Next-80B-A3B-Instruct --layer-indices 6,7 --no-enable-attention-dp --batch-size 4
-NP=2 ./mpi_launch.sh ./run.sh config_gen.yaml --model Qwen/Qwen3-Next-80B-A3B-Instruct --layer-indices 6,7 --no-enable-attention-dp --batch-size 512
+NP=2 ./mpi_launch.sh ./run.sh config_ctx.yaml --model Qwen/Qwen3-Next-80B-A3B-Instruct --layer-indices 6,7 --no-enable-attention-dp --mamba-ssm-cache-dtype float16 --batch-size 4
+NP=2 ./mpi_launch.sh ./run.sh config_gen.yaml --model Qwen/Qwen3-Next-80B-A3B-Instruct --layer-indices 6,7 --no-enable-attention-dp --mamba-ssm-cache-dtype float16 --batch-size 512
 
 # Run with DeepEP A2A
 NP=4 ./mpi_launch.sh -x TRTLLM_FORCE_ALLTOALL_METHOD=DeepEP ./run.sh config_ctx.yaml --moe-backend WIDEEP
@@ -112,14 +117,11 @@ python3 scripts/build_wheel.py --cuda_architectures native --no-venv --skip_buil
 **Step 3:** Run benchmarks to generate profiles. Run the following command on the controller node, where `NODES` ≤ the number of allocated nodes:
 
 ```bash
-# Set autotune cache path
-export TLLM_AUTOTUNER_CACHE_PATH=autotuner_cache/cache
-
 # Run DeepSeek-R1 NVFP4 with wide ep: uses MNNVL A2A if applicable
 NODES=4 NP=16 ./slurm_launch.sh ./run.sh config_gen.yaml --moe-backend WIDEEP
 
 # Run with TRTLLMGen
-NODES=4 NP=16 TRTLLM_ENABLE_PDL=1 ./slurm_launch.sh ./run.sh config_gen.yaml --moe-backend TRTLLM
+NODES=4 NP=16 ./slurm_launch.sh ./run.sh config_gen.yaml --moe-backend TRTLLM
 
 # Run with DeepEPLowLatency
 NODES=4 NP=16 TRTLLM_FORCE_ALLTOALL_METHOD=DeepEPLowLatency ./slurm_launch.sh ./run.sh config_gen.yaml --moe-backend WIDEEP
@@ -172,7 +174,9 @@ You will receive three reports, each containing kernel timing statistics grouped
 ## Developer utilities
 
 1. Less startup time when debug a model
-   1. Disable autotuner: add `--no-enable-autotuner` option
+   1. Set autotuner cache or disable autotuner
+      1. Set autotuner cache: add `TLLM_AUTOTUNER_CACHE_PATH=autotuner_cache/cache` environment variable. Use this at your own risk; you may need to delete the cache if `NP` or the code changes
+      2. Disable autotuner: add `--no-enable-autotuner` option
    2. Disable nsys profile: set `PROFILE=0` environment variable
 2. Capture more information
    1. Enable GPU metrics: set `GPU_METRICS=1` environment variable
@@ -182,4 +186,8 @@ You will receive three reports, each containing kernel timing statistics grouped
 
 1. Error `fp8 blockscale gemm only support Hopper` on Blackwell.
 
-   The default MoE backend "CUTLASS" does not support FP8 weights. Please choose the same MoE backend as your end-to-end config. A typical choice is adding `--moe-backend DEEPGEMM`, `--moe-backend TRTLLM`, or `--moe-backend WIDEEP` option.
+   The default MoE backend "CUTLASS" does not support FP8 weights. Please choose the same MoE backend as your end-to-end config. A typical choice is adding the `--moe-backend DEEPGEMM` (or `TRTLLM`, `WIDEEP`) and `--moe-backend-for-prefill DEEPGEMM` (or `WIDEEP`) options.
+
+2. Error `huggingface_hub.errors.HfHubHTTPError: 429 Client Error: Too Many Requests for url: https://huggingface.co/nvidia/DeepSeek-R1-0528-FP4-v2/resolve/main/config.json`.
+
+   Please use a local model through the `--model` option, or follow Hugging Face's instructions: "We had to rate limit your IP. To continue using our service, create a HF account or login to your existing account, and make sure you pass a HF_TOKEN if you're using the API."
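The autotuner-cache note and the rate-limit workaround above can be combined in one invocation; the sketch below reuses the weight-loading command from Step 3 and assumes `$LLM_MODELS_ROOT` points at a local checkpoint mirror:

```bash
# Optional autotuner cache: delete it whenever NP or the code changes
export TLLM_AUTOTUNER_CACHE_PATH=autotuner_cache/cache

# A local --model path avoids Hugging Face downloads (and HTTP 429 rate limits)
NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml \
    --model "$LLM_MODELS_ROOT/DeepSeek-R1/DeepSeek-R1-0528-FP4-v2" --load-format AUTO
```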
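For the newly supported Nemotron-3-Nano configuration, the end-to-end flow of generating and then parsing a profile mirrors `test_nemotron_gen_dep` in the unit tests; a sketch (the profile directory name is illustrative):

```bash
# Generate a profile into PROFILE_DIR, then post-process it with parse.py
PROFILE_DIR=profiles/nemotron_gen NP=1 ./mpi_launch.sh ./run.sh config_gen.yaml \
    --model nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 \
    --layer-indices 4,5,6 --mamba-ssm-cache-dtype float16
python3 parse.py --profile-dir profiles/nemotron_gen --world-size 1
```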
diff --git a/examples/layer_wise_benchmarks/parse.py b/examples/layer_wise_benchmarks/parse.py index 3ebd4799225..c878574da63 100644 --- a/examples/layer_wise_benchmarks/parse.py +++ b/examples/layer_wise_benchmarks/parse.py @@ -6,7 +6,6 @@ import sqlite3 import subprocess import sys -from collections import defaultdict from pathlib import Path import jinja2 @@ -139,7 +138,7 @@ def shortest_common_supersequence(a, b): "runs": [], "runs_end": [], "ranges": [], - "range_in_module": [], + "kernel_count_per_range": [], } ) @@ -161,28 +160,7 @@ def shortest_common_supersequence(a, b): problem_set[problem_id]["runs_end"].append(end) else: problem_set[problem_id]["ranges"].append((start, end, text)) - -# Determine whether each range is the first range that matches `args.module`, -# and store the result in `problem["range_in_module"]` -for problem in problem_set: - if args.module is not None: - problem["range_in_module"] = [False] * len(problem["ranges"]) - run_ids = [bisect.bisect(problem["runs"], start) - 1 for start, _, _ in problem["ranges"]] - run2ranges = defaultdict(list) - for i, run_id in enumerate(run_ids): - run2ranges[run_id].append(i) - for run_id, ranges in run2ranges.items(): - ranges = sorted(ranges, key=lambda i: problem["ranges"][i][0]) - num_matches = 0 - for range_id in ranges: - if problem["ranges"][range_id][2] == args.module: - problem["range_in_module"][range_id] = True - num_matches += 1 - if num_matches != 1: - raise ValueError( - f'Module "{args.module}" appears {num_matches} times' - f' in "{problem["text"]}"\'s {run_id + 1}-th run' - ) + problem_set[problem_id]["kernel_count_per_range"].append(0) query = """SELECT name FROM sqlite_master WHERE type = ?""" df = pd.read_sql_query(query, conn, params=("table",)) @@ -228,19 +206,17 @@ def shortest_common_supersequence(a, b): problem_id = bisect.bisect(problem_start, start) - 1 problem = problem_set[problem_id] run_id = bisect.bisect(problem["runs"], runtime_start) - 1 - if ( - run_id == -1 - or run_id == len(problem["runs"]) - or runtime_start >= problem["runs_end"][run_id] - ): - run_id = -1 + if run_id == -1 or runtime_start >= problem["runs_end"][run_id]: + continue ranges = [ i for i, (range_start, range_end, text) in enumerate(problem["ranges"]) if capture_start >= range_start and capture_end <= range_end ] - if args.module is None or any(problem["range_in_module"][i] for i in ranges): - range_names = [problem["ranges"][i][2] for i in ranges] + for range_id in ranges: + problem["kernel_count_per_range"][range_id] += 1 + range_names = [problem["ranges"][i][2] for i in ranges] + if args.module is None or args.module in range_names: kernel_list.append( ( problem_id, @@ -262,6 +238,22 @@ def shortest_common_supersequence(a, b): conn.close() +# Check ambiguous modules +if args.module: + for problem in problem_set: + num_matches_per_run = [0] * (len(problem["runs"]) + 1) + for (range_start, _, text), kernel_count in zip( + problem["ranges"], problem["kernel_count_per_range"] + ): + if text == args.module and kernel_count > 0: + num_matches_per_run[bisect.bisect(problem["runs"], range_start)] += 1 + for run_id_plus_one, num_matches in enumerate(num_matches_per_run): + if num_matches > 1: + raise ValueError( + f'Module is ambiguous: "{args.module}" appears {num_matches} times' + f' in "{problem["text"]}"\'s {run_id_plus_one}-th run' + ) + kernel_list.sort(key=lambda t: (t[6], t[8])) kernels = [[[] for _ in problem["runs"]] for problem in problem_set] for ( @@ -276,8 +268,7 @@ def shortest_common_supersequence(a, b): 
capture_start, capture_end, ) in kernel_list: - if run_id != -1: - kernels[problem_id][run_id].append((demangledName, start, end, ranges)) + kernels[problem_id][run_id].append((demangledName, start, end, ranges)) for problem_id in range(len(kernels)): required_seq = [demangledName for demangledName, _, _, _ in kernels[problem_id][0]] for run_id in range(len(kernels[problem_id])): @@ -287,86 +278,8 @@ def shortest_common_supersequence(a, b): parser_keywords = [ ("cuBLASGemm", "nvjet"), - ("splitKreduce", "splitKreduce_kernel"), - ("fusedAGemm", "fused_a_gemm_kernel"), - ("RMSNorm", "RMSNormKernel"), - ("torchCat", "CatArrayBatchedCopy"), - ("applyMLARope", "applyMLARope"), - ("fmhaSm100f", "fmhaSm100fKernel_Qkv"), - ("fmhaReduction", "fmhaReductionKernel"), - ("quant", "quantize_with_block_size"), - ("AllGather", "ncclDevKernel_AllGather_"), - ("ReduceScatter", "ncclDevKernel_ReduceScatter_"), - ("allreduce_oneshot", "allreduce_fusion_kernel_oneshot_lamport"), - ("allreduce_twoshot", "allreduce_fusion_kernel_twoshot_sync"), - ("expandInput", "expandInputRowsKernel"), - ("computeStrides", "computeStridesTmaWarpSpecializedKernel"), ("cutlassGroupGemm", "cutlass::device_kernel", "at::native::CUDAFunctorOnSelf_add"), - ("convert_req_index", "_convert_req_index_to_global_index_kernel_with_stride_factor"), - ("preprocess_after_permute", "_preprocess_after_permute_kernel"), - ("masked_index_copy_quant", "_masked_index_copy_group_quant_fp8"), - ("swiglu_quant", "_silu_and_mul_post_quant_kernel"), - ("masked_index_gather", "masked_index_gather_kernel"), - ("finalizeMoeRouting", "tensorrt_llm::kernels::cutlass_kernels::finalizeMoeRoutingKernel<"), - ("fused_qkvzba_split", "fused_qkvzba_split_reshape_cat_kernel"), - ("causal_conv1d_update", "tensorrt_llm::kernels::causal_conv1d::causal_conv1d_update_kernel<"), - ("fused_delta_rule_update", "fused_sigmoid_gating_delta_rule_update_kernel"), - ("layer_norm_fwd_1pass", "_layer_norm_fwd_1pass_kernel"), - ("torchGatherTopK", "at::native::sbtopk::gatherTopK<"), - ("softmax_warp_forward", "softmax_warp_forward<"), - ("torchSigmoid", "at::native::sigmoid_kernel_cuda"), - ("torchMul", "at::native::binary_internal::MulFunctor<"), - ("computeSeqAndPaddingOffsets", "tensorrt_llm::kernels::computeSeqAndPaddingOffsets<"), - ("applyBiasRopeUpdateKVCache", "tensorrt_llm::kernels::applyBiasRopeUpdateKVCacheV2<"), - ("routingIndicesHistogramScores", "routingRenormalize::routingIndicesHistogramScoresKernel<"), - ("routingIndicesHistogram", "routingIndicesHistogramKernel<"), - ("routingIndicesOffsets", "routingIndicesOffsetsKernel<"), - ("torchReduceSum", ["at::native::reduce_kernel<", "at::native::sum_functor<"]), ("CuteDSLMoePermute", "cute_dsl::moePermuteKernel"), ( "CuteDSLGemm", @@ -380,6 +293,19 @@ def shortest_common_supersequence(a, b): "CuteDSLGroupedGemmFinalize", ["cute_dsl_kernels", "blockscaled_contiguous_grouped_gemm_finalize_fusion"], ), + ("torchAdd", "at::native::CUDAFunctorOnSelf_add"), + ("torchAdd", "CUDAFunctor_add"), + ("torchClamp", "at::native::::launch_clamp_scalar("), + ("torchCompare", "at::native::::CompareFunctor<"), + ("torchCopy", "at::native::bfloat16_copy_kernel_cuda"), + ("torchCopy", "at::native::direct_copy_kernel_cuda("), + ("torchFill", "at::native::FillFunctor"), + ("torchIndexPut", "at::native::index_put_kernel_impl<"), + ("torchMul", "at::native::binary_internal::MulFunctor<"), + ("torchPow", "at::native::::pow_tensor_scalar_kernel_impl<"), + ("torchReduceSum", ["at::native::reduce_kernel<", "at::native::sum_functor<"]), + 
("torchSigmoid", "at::native::sigmoid_kernel_cuda"), + ("torchWhere", "at::native::::where_kernel_impl("), ] warned_names = set() @@ -395,15 +321,19 @@ def parse_kernel_name(demangledName): src = [src] if all(keyword in name for keyword in src): return dst - if name not in warned_names: - print(f"Unknown kernel name: {name}", file=sys.stderr) - warned_names.add(name) - if args.error_on_unknown_kernel: - raise NotImplementedError(f"Unknown kernel name: {name}") + if re.search(r"at::native::.*elementwise_kernel<", name): + if name not in warned_names: + print(f"Not parsed torch kernel name: {name}", file=sys.stderr) + warned_names.add(name) + assert "!unnamed!" not in name + name = name.replace("", "!unnamed!") if "<" in name: name = name[: name.index("<")] if "(" in name: name = name[: name.index("(")] + if "::" in name: + name = name[name.rindex("::") + 2 :] + name = name.replace("!unnamed!", "") return name @@ -438,6 +368,8 @@ def parse_kernel_name(demangledName): converted_seq.append((("Space",), np.mean(space_list[warmup_times:]).tolist())) converted_seq.append((("Total",), sum(t for _, t in converted_seq))) converted_seqs.append(converted_seq) +if args.error_on_unknown_kernel and warned_names: + raise ValueError("Unknown kernel names encountered") merged_title = [] for converted_seq in converted_seqs: @@ -459,7 +391,7 @@ def parse_kernel_name(demangledName): for problem in problem_set: print( f'- "{problem["text"]}" {len(problem["runs"])} runs' - f" Ranges: [{', '.join(text for _, _, text in problem['ranges'])}]" + f" Ranges: [{', '.join(text for _, end, text in problem['ranges'] if end <= problem['runs_end'][0])}]" ) stack = [] diff --git a/examples/layer_wise_benchmarks/run.py b/examples/layer_wise_benchmarks/run.py index d84525c1d33..889cbd81f3b 100644 --- a/examples/layer_wise_benchmarks/run.py +++ b/examples/layer_wise_benchmarks/run.py @@ -11,12 +11,10 @@ from tensorrt_llm._torch.autotuner import AutoTuner, autotune from tensorrt_llm._torch.distributed import MPIDist, TorchDist -from tensorrt_llm._torch.modules.fused_moe.fused_moe_cutlass import CutlassFusedMoE -from tensorrt_llm._torch.modules.fused_moe.interface import AlltoallMethodType from tensorrt_llm._torch.modules.multi_stream_utils import with_multi_stream from tensorrt_llm._utils import local_mpi_rank, mpi_disabled, mpi_rank, mpi_world_size from tensorrt_llm.logger import logger -from tensorrt_llm.tools.layer_wise_benchmarks import BalanceMethod, get_runner_cls, mark_ranges +from tensorrt_llm.tools.layer_wise_benchmarks import BalanceMethod, Runner, mark_ranges def comma_separated_ints(s): @@ -46,9 +44,17 @@ def comma_separated_floats(s): group.add_argument("--enable-attention-dp", action="store_true", dest="enable_attention_dp") group.add_argument("--no-enable-attention-dp", action="store_false", dest="enable_attention_dp") parser.set_defaults(enable_attention_dp=None) +parser.add_argument("--kv-cache-dtype", type=str, choices=["fp8", "nvfp4", "auto"]) +parser.add_argument( + "--mamba-ssm-cache-dtype", type=str, choices=["auto", "float16", "bfloat16", "float32"] +) # Model init args +parser.add_argument("--load-format", type=str, choices=["AUTO", "DUMMY"]) parser.add_argument("--max-num-tokens", type=int) parser.add_argument("--moe-backend", type=str) +parser.add_argument( + "--moe-backend-for-prefill", type=str, choices=["CUTLASS", "DEEPGEMM", "WIDEEP"] +) parser.add_argument("--moe-max-num-tokens", type=int) group = parser.add_mutually_exclusive_group() group.add_argument( @@ -110,8 +116,16 @@ def 
comma_separated_floats(s): args.max_seq_len = max(args.seq_len_q_list) + max(args.seq_len_kv_cache_list) if args.enable_attention_dp is None: args.enable_attention_dp = False +if args.kv_cache_dtype is None: + args.kv_cache_dtype = "auto" +if args.mamba_ssm_cache_dtype is None: + args.mamba_ssm_cache_dtype = "auto" +if args.load_format is None: + args.load_format = "DUMMY" if args.max_num_tokens is None: args.max_num_tokens = args.max_batch_size * max(args.seq_len_q_list) +if args.moe_backend_for_prefill is None: + args.moe_backend_for_prefill = "CUTLASS" if args.use_low_precision_moe_combine is None: args.use_low_precision_moe_combine = False if args.enable_autotuner is None: @@ -128,7 +142,6 @@ def comma_separated_floats(s): # Create KV cache manager logger.info("Layer-wise benchmarks: Create KV cache manager") -Runner = get_runner_cls(args.model) mapping = Runner.create_mapping(enable_attention_dp=args.enable_attention_dp) kv_cache_manager = Runner.create_kv_cache_manager( args.model, @@ -136,6 +149,8 @@ def comma_separated_floats(s): tokens_per_block=args.tokens_per_block, max_batch_size=args.max_batch_size, max_seq_len=args.max_seq_len, + kv_cache_dtype=args.kv_cache_dtype, + mamba_ssm_cache_dtype=args.mamba_ssm_cache_dtype, layer_indices=args.layer_indices, ) attn_workspace = torch.empty((0,), device="cuda", dtype=torch.int8) @@ -151,12 +166,15 @@ def comma_separated_floats(s): runner = Runner( args.model, mapping, + load_format=args.load_format, moe_backend=args.moe_backend, layer_indices=args.layer_indices, scaled_from=args.scaled_from, max_seq_len=args.max_seq_len, max_num_tokens=args.max_num_tokens, moe_max_num_tokens=args.moe_max_num_tokens, + kv_cache_dtype=args.kv_cache_dtype, + mamba_ssm_cache_dtype=args.mamba_ssm_cache_dtype, use_low_precision_moe_combine=args.use_low_precision_moe_combine, use_cuda_graph=args.use_cuda_graph, ) @@ -190,18 +208,19 @@ def comma_separated_floats(s): max(1, 20480 // ctx_seq_len_q), ) ctx_attn_workspace = torch.empty((0,), device="cuda", dtype=torch.int8) - with mock.patch.object( - CutlassFusedMoE, "select_alltoall_method_type", return_value=AlltoallMethodType.NotEnabled - ): + with mock.patch.dict(os.environ, {"TRTLLM_FORCE_ALLTOALL_METHOD": "NotEnabled"}, clear=False): ctx_runner = Runner( args.model, mapping, - moe_backend="CUTLASS", + load_format=args.load_format, + moe_backend=args.moe_backend_for_prefill, layer_indices=args.layer_indices, scaled_from=args.scaled_from, max_seq_len=args.max_seq_len, max_num_tokens=ctx_batch_size * ctx_seq_len_q, moe_max_num_tokens=16384, + kv_cache_dtype=args.kv_cache_dtype, + mamba_ssm_cache_dtype=args.mamba_ssm_cache_dtype, use_low_precision_moe_combine=args.use_low_precision_moe_combine, use_cuda_graph=False, ) @@ -221,10 +240,7 @@ def comma_separated_floats(s): kv_cache_manager=kv_cache_manager, attn_workspace=ctx_attn_workspace, ) - with ctx_runner.replace_routing_method_ctx( - balance_method=BalanceMethod.Balanced, balance_ratio=None - ): - run_pack(check=True) + run_pack(check=True) del ctx_runner del ctx_attn_workspace logger.info("Layer-wise benchmarks: Prefill KV cache ... 
Done") @@ -293,6 +309,7 @@ def comma_separated_floats(s): with runner.replace_routing_method_ctx( balance_method=BalanceMethod[args.balance_method], balance_ratio=balance_ratio ): + run_pack() if args.use_cuda_graph: with with_multi_stream(True): g = torch.cuda.CUDAGraph() diff --git a/tensorrt_llm/_torch/models/modeling_qwen3_next.py b/tensorrt_llm/_torch/models/modeling_qwen3_next.py index 13318b1e4f2..799d4076ab5 100644 --- a/tensorrt_llm/_torch/models/modeling_qwen3_next.py +++ b/tensorrt_llm/_torch/models/modeling_qwen3_next.py @@ -833,7 +833,7 @@ def _compute_projected_states_ba(): return output -class Qwen3NextLinearDecoderLayer(nn.Module): +class Qwen3NextLinearDecoderLayer(DecoderLayer): def __init__( self, @@ -1255,6 +1255,7 @@ def load_weights(self, weights: dict, weight_mapper: BaseWeightMapper): new_weights = weight_mapper.preprocess_weights(weights) super().load_weights(new_weights, weight_mapper) + def post_load_weights(self): for idx, layer in enumerate( self.model.layers[:self.config.num_hidden_layers]): if idx == self.config.num_hidden_layers - 1: diff --git a/tensorrt_llm/tools/layer_wise_benchmarks/__init__.py b/tensorrt_llm/tools/layer_wise_benchmarks/__init__.py index e347df3ca8e..f68e0884220 100644 --- a/tensorrt_llm/tools/layer_wise_benchmarks/__init__.py +++ b/tensorrt_llm/tools/layer_wise_benchmarks/__init__.py @@ -1,5 +1,4 @@ from .mark_utils import mark_ranges -from .runner_factory import get_runner_cls -from .runner_interface import BalanceMethod +from .runner import BalanceMethod, Runner -__all__ = ["BalanceMethod", "get_runner_cls", "mark_ranges"] +__all__ = ["BalanceMethod", "Runner", "mark_ranges"] diff --git a/tensorrt_llm/tools/layer_wise_benchmarks/deepseekv3_runner.py b/tensorrt_llm/tools/layer_wise_benchmarks/deepseekv3_runner.py deleted file mode 100644 index 0e173fc779d..00000000000 --- a/tensorrt_llm/tools/layer_wise_benchmarks/deepseekv3_runner.py +++ /dev/null @@ -1,101 +0,0 @@ -from typing import List, Optional - -import torch - -from tensorrt_llm._torch.model_config import ModelConfig -from tensorrt_llm._torch.models.modeling_deepseekv3 import DeepseekV3DecoderLayer -from tensorrt_llm._torch.modules.rms_norm import RMSNorm -from tensorrt_llm._torch.pyexecutor.model_loader import initialize_dummy_weights -from tensorrt_llm._torch.utils import AuxStreamType -from tensorrt_llm.functional import AllReduceStrategy -from tensorrt_llm.mapping import Mapping - -from .runner_interface import RunnerBase -from .runner_utils import RunnerMixin - - -class DeepSeekV3Runner(RunnerMixin, RunnerBase): - @staticmethod - def has_mamba_metadata() -> bool: - return False - - def __init__( - self, - pretrained_model_name_or_path: str, - mapping: Mapping, - *, - moe_backend: str, - layer_indices: List[int], - scaled_from: Optional[int], - max_seq_len: int, - max_num_tokens: int, - moe_max_num_tokens: int, - use_low_precision_moe_combine: bool, - use_cuda_graph: bool, - ): - super().__init__() - self.model_config = ModelConfig.from_pretrained( - pretrained_model_name_or_path, - mapping=mapping, - enable_min_latency=False, - use_cuda_graph=use_cuda_graph, - force_dynamic_quantization=False, - spec_config=None, - sparse_attention_config=None, # To be loaded from config - max_num_tokens=max_num_tokens, - max_seq_len=max_seq_len, - moe_max_num_tokens=moe_max_num_tokens, - moe_load_balancer=None, - lora_config=None, - allreduce_strategy=AllReduceStrategy.AUTO, - mm_encoder_only=False, - attn_backend="TRTLLM", - moe_backend=moe_backend, - moe_disable_finalize_fusion=False, - 
use_low_precision_moe_combine=use_low_precision_moe_combine, - skip_create_weights_in_init=True, - ) - pretrained_config = self.model_config.pretrained_config - - with self.scaled_from_ctx(scaled_from, mapping, pretrained_config): - aux_stream_list = [torch.cuda.Stream() for _ in range(2)] - aux_stream_dict = { - AuxStreamType.Attention: aux_stream_list[0], - AuxStreamType.MoeShared: aux_stream_list[0], - AuxStreamType.MoeChunkingOverlap: aux_stream_list[1], - } - - layers = [ - DeepseekV3DecoderLayer( - model_config=self.model_config, - layer_idx=layer_idx, - aux_stream_dict=aux_stream_dict, - ) - for layer_idx in layer_indices - ] - next_layer_layernorm = RMSNorm( - hidden_size=pretrained_config.hidden_size, - eps=pretrained_config.rms_norm_eps, - dtype=pretrained_config.torch_dtype, - ) - - # TODO: apply_layerwise_quant_config - self.apply_quant_config_exclude_modules(layers, self.model_config.quant_config) - for layer in layers: - for module in layer.modules(): - if callable(getattr(module, "create_weights", None)): - module.create_weights() - layer.cuda() - initialize_dummy_weights(layer) - for module in layer.modules(): - if hasattr(module, "post_load_weights") and not getattr( - module, "_weights_removed", False - ): - module.post_load_weights() - next_layer_layernorm.cuda() - initialize_dummy_weights(next_layer_layernorm) - for layer, next_layer in zip(layers[:-1], layers[1:]): - layer.next_layer_layernorm = next_layer.input_layernorm - layers[-1].next_layer_layernorm = next_layer_layernorm - - self.layers = layers diff --git a/tensorrt_llm/tools/layer_wise_benchmarks/mark_utils.py b/tensorrt_llm/tools/layer_wise_benchmarks/mark_utils.py index 7ebde93e088..6b380c12ec2 100644 --- a/tensorrt_llm/tools/layer_wise_benchmarks/mark_utils.py +++ b/tensorrt_llm/tools/layer_wise_benchmarks/mark_utils.py @@ -1,6 +1,7 @@ import nvtx from tensorrt_llm._torch.models.modeling_deepseekv3 import DeepseekV3Gate, Deepseekv3MoE +from tensorrt_llm._torch.models.modeling_nemotron_h import MLPLayer, NemotronHMOE from tensorrt_llm._torch.models.modeling_qwen3_next import ( Qwen3NextGatedDeltaNet, Qwen3NextSparseMoeBlock, @@ -8,11 +9,14 @@ from tensorrt_llm._torch.modules.attention import MLA, Attention from tensorrt_llm._torch.modules.fused_moe.interface import MoE from tensorrt_llm._torch.modules.gated_mlp import GatedMLP +from tensorrt_llm._torch.modules.mamba.mamba2_mixer import Mamba2Mixer def mark_ranges(): DeepseekV3Gate.forward = nvtx.annotate("DeepseekV3Gate")(DeepseekV3Gate.forward) Deepseekv3MoE.forward = nvtx.annotate("Deepseekv3MoE")(Deepseekv3MoE.forward) + MLPLayer.forward = nvtx.annotate("MLPLayer")(MLPLayer.forward) + NemotronHMOE.forward = nvtx.annotate("NemotronHMOE")(NemotronHMOE.forward) Qwen3NextGatedDeltaNet.forward = nvtx.annotate("Qwen3NextGatedDeltaNet")( Qwen3NextGatedDeltaNet.forward ) @@ -23,3 +27,4 @@ def mark_ranges(): Attention.forward = nvtx.annotate("Attention")(Attention.forward) MoE.forward = nvtx.annotate("MoE")(MoE.forward) GatedMLP.forward = nvtx.annotate("GatedMLP")(GatedMLP.forward) + Mamba2Mixer.forward = nvtx.annotate("Mamba2Mixer")(Mamba2Mixer.forward) diff --git a/tensorrt_llm/tools/layer_wise_benchmarks/qwen3_next_runner.py b/tensorrt_llm/tools/layer_wise_benchmarks/qwen3_next_runner.py deleted file mode 100644 index 3ebd7487fdb..00000000000 --- a/tensorrt_llm/tools/layer_wise_benchmarks/qwen3_next_runner.py +++ /dev/null @@ -1,95 +0,0 @@ -from typing import List, Optional - -import torch - -from tensorrt_llm._torch.model_config import ModelConfig -from 
tensorrt_llm._torch.models.modeling_qwen3_next import ALL_DECODER_LAYER_TYPES -from tensorrt_llm._torch.modules.rms_norm import RMSNorm -from tensorrt_llm._torch.pyexecutor.model_loader import initialize_dummy_weights -from tensorrt_llm.functional import AllReduceStrategy -from tensorrt_llm.mapping import Mapping - -from .runner_interface import RunnerBase -from .runner_utils import RunnerMixin - - -class Qwen3NextRunner(RunnerMixin, RunnerBase): - @staticmethod - def has_mamba_metadata() -> bool: - return True - - def __init__( - self, - pretrained_model_name_or_path: str, - mapping: Mapping, - *, - moe_backend: str, - layer_indices: List[int], - scaled_from: Optional[int], - max_seq_len: int, - max_num_tokens: int, - moe_max_num_tokens: int, - use_low_precision_moe_combine: bool, - use_cuda_graph: bool, - ): - super().__init__() - self.model_config = ModelConfig.from_pretrained( - pretrained_model_name_or_path, - mapping=mapping, - enable_min_latency=False, - use_cuda_graph=use_cuda_graph, - force_dynamic_quantization=False, - spec_config=None, - sparse_attention_config=None, # To be loaded from config - max_num_tokens=max_num_tokens, - max_seq_len=max_seq_len, - moe_max_num_tokens=moe_max_num_tokens, - moe_load_balancer=None, - lora_config=None, - allreduce_strategy=AllReduceStrategy.AUTO, - mm_encoder_only=False, - attn_backend="TRTLLM", - moe_backend=moe_backend, - moe_disable_finalize_fusion=False, - use_low_precision_moe_combine=use_low_precision_moe_combine, - skip_create_weights_in_init=True, - ) - pretrained_config = self.model_config.pretrained_config - - with self.scaled_from_ctx(scaled_from, mapping, pretrained_config): - aux_stream = torch.cuda.Stream() - layers = [ - ALL_DECODER_LAYER_TYPES[pretrained_config.layer_types[layer_idx]]( - self.model_config, - layer_idx, - aux_stream, - ) - for layer_idx in layer_indices - ] - next_layer_layernorm = RMSNorm( - hidden_size=pretrained_config.hidden_size, - eps=pretrained_config.rms_norm_eps, - dtype=pretrained_config.torch_dtype, - use_gemma=True, - ) - - # TODO: apply_layerwise_quant_config - self.apply_quant_config_exclude_modules(layers, self.model_config.quant_config) - for layer in layers: - for module in layer.modules(): - if callable(getattr(module, "create_weights", None)): - module.create_weights() - layer.cuda() - initialize_dummy_weights(layer) - for module in layer.modules(): - if hasattr(module, "post_load_weights") and not getattr( - module, "_weights_removed", False - ): - module.post_load_weights() - next_layer_layernorm.cuda() - initialize_dummy_weights(next_layer_layernorm) - for layer, next_layer in zip(layers[:-1], layers[1:]): - layer.next_layer_layernorm = next_layer.input_layernorm - layers[-1].next_layer_layernorm = next_layer_layernorm - - self.layers = layers diff --git a/tensorrt_llm/tools/layer_wise_benchmarks/runner_utils.py b/tensorrt_llm/tools/layer_wise_benchmarks/runner.py similarity index 68% rename from tensorrt_llm/tools/layer_wise_benchmarks/runner_utils.py rename to tensorrt_llm/tools/layer_wise_benchmarks/runner.py index 93d6d84e115..4b6f9050e64 100644 --- a/tensorrt_llm/tools/layer_wise_benchmarks/runner_utils.py +++ b/tensorrt_llm/tools/layer_wise_benchmarks/runner.py @@ -1,34 +1,49 @@ import contextlib import functools import itertools -import os import unittest.mock import weakref -from abc import ABC, abstractmethod +from enum import IntEnum from typing import Optional import torch +import tensorrt_llm._torch.model_config +import tensorrt_llm.bindings from 
tensorrt_llm._torch.attention_backend.utils import get_attention_backend from tensorrt_llm._torch.metadata import KVCacheParams from tensorrt_llm._torch.model_config import ModelConfig +from tensorrt_llm._torch.models.modeling_utils import PostInitCaller, skip_forward from tensorrt_llm._torch.modules.fused_moe.fused_moe_cutlass import CutlassFusedMoE from tensorrt_llm._torch.modules.fused_moe.fused_moe_trtllm_gen import TRTLLMGenFusedMoE from tensorrt_llm._torch.modules.fused_moe.fused_moe_wide_ep import WideEPMoE -from tensorrt_llm._torch.modules.linear import Linear, WeightMode from tensorrt_llm._torch.modules.mamba.mamba2_metadata import Mamba2Metadata from tensorrt_llm._torch.pyexecutor._util import get_kv_cache_manager_cls -from tensorrt_llm._torch.pyexecutor.config_utils import is_mla, is_qwen3_next +from tensorrt_llm._torch.pyexecutor.config_utils import ( + is_mla, + is_nemotron_hybrid, + is_qwen3_next, + load_pretrained_config, +) +from tensorrt_llm._torch.pyexecutor.model_loader import ( + ModelLoader, + _construct_checkpoint_loader, + validate_and_set_kv_cache_quant, + validate_and_set_mamba_ssm_cache_dtype, +) from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager from tensorrt_llm._torch.utils import get_model_extra_attrs, model_extra_attrs from tensorrt_llm._utils import local_mpi_size, mpi_rank, mpi_world_size, torch_dtype_to_binding -from tensorrt_llm.bindings.executor import KvCacheConfig -from tensorrt_llm.bindings.internal.batch_manager import CacheType +from tensorrt_llm.llmapi.llm_args import KvCacheConfig, MoeConfig, TorchLlmArgs from tensorrt_llm.logger import logger from tensorrt_llm.mapping import Mapping -from tensorrt_llm.models.modeling_utils import QuantConfig -from .runner_interface import BalanceMethod + +class BalanceMethod(IntEnum): + NotModified = 1 + Balanced = 2 + ImbalancedRanks = 3 + ImbalancedExperts = 4 def ceil_div(a, b): @@ -102,31 +117,28 @@ def test_get_balanced_selection(): raise ValueError("tokens per expert is not balanced") -def apply_balance_ratio(imbalanced_experts, num_experts, balance_ratio, dp_size, dp_rank, ep_size): - num_tokens, top_k = imbalanced_experts.shape - dtype = imbalanced_experts.dtype - device = imbalanced_experts.device - balanced_experts = get_balanced_selection_no_cache( - num_tokens, top_k, num_experts, dtype, device, dp_size, dp_rank, ep_size - ) +def get_num_balanced_tokens(num_tokens, top_k, num_experts, dp_size, balance_ratio): if balance_ratio == 0.0: - num_balanced_tokens = 0 + return 0 else: # Activate all experts min_num_balanced_tokens = min(num_tokens, ceil_div(num_experts, dp_size * top_k)) - num_balanced_tokens = min_num_balanced_tokens + round( + return min_num_balanced_tokens + round( (num_tokens - min_num_balanced_tokens) * balance_ratio ) - mixed_experts = torch.cat( - [balanced_experts[:num_balanced_tokens], imbalanced_experts[num_balanced_tokens:]] - ) - return mixed_experts @functools.cache def get_all_to_one_selection( num_tokens, top_k, num_experts, balance_ratio, dtype, device, dp_size, dp_rank, ep_size ): + num_balanced_tokens = get_num_balanced_tokens( + num_tokens, top_k, num_experts, dp_size, balance_ratio + ) + balanced_experts = get_balanced_selection_no_cache( + num_balanced_tokens, top_k, num_experts, dtype, device, dp_size, dp_rank, ep_size + ) + num_imbalanced_tokens = num_tokens - num_balanced_tokens experts_per_rank = num_experts // ep_size if top_k > experts_per_rank: raise ValueError( @@ -134,29 +146,34 @@ def get_all_to_one_selection( ) imbalanced_experts = ( 
torch.arange( - dp_rank * num_tokens * top_k, - (dp_rank + 1) * num_tokens * top_k, + dp_rank * num_imbalanced_tokens * top_k, + (dp_rank + 1) * num_imbalanced_tokens * top_k, dtype=dtype, device=device, - ).view(num_tokens, top_k) + ).view(num_imbalanced_tokens, top_k) % experts_per_rank ) - imbalanced_experts = imbalanced_experts.sort(dim=-1).values - return apply_balance_ratio( - imbalanced_experts, num_experts, balance_ratio, dp_size, dp_rank, ep_size - ) + mixed_experts = torch.cat([balanced_experts, imbalanced_experts]) + return mixed_experts.sort(dim=-1).values @functools.cache def get_balanced_rank_imbalanced_expert_selection( num_tokens, top_k, num_experts, balance_ratio, dtype, device, dp_size, dp_rank, ep_size ): + num_balanced_tokens = get_num_balanced_tokens( + num_tokens, top_k, num_experts, dp_size, balance_ratio + ) + balanced_experts = get_balanced_selection_no_cache( + num_balanced_tokens, top_k, num_experts, dtype, device, dp_size, dp_rank, ep_size + ) + num_imbalanced_tokens = num_tokens - num_balanced_tokens experts_per_rank = num_experts // ep_size active_experts_per_rank = ceil_div(top_k, ep_size) # Select expert from [0, active_experts_per_rank * ep_size), # then scale to [0, experts_per_rank * ep_size) narrow_experts = get_balanced_selection_no_cache( - num_tokens, + num_imbalanced_tokens, top_k, active_experts_per_rank * ep_size, dtype, @@ -169,9 +186,8 @@ def get_balanced_rank_imbalanced_expert_selection( narrow_experts // active_experts_per_rank * experts_per_rank + narrow_experts % active_experts_per_rank ) - return apply_balance_ratio( - imbalanced_experts, num_experts, balance_ratio, dp_size, dp_rank, ep_size - ) + mixed_experts = torch.cat([balanced_experts, imbalanced_experts]) + return mixed_experts.sort(dim=-1).values def make_balanced_routing_method( @@ -339,36 +355,91 @@ def forward_impl(*args, **kwargs): return forward_impl -class RunnerMixin(ABC): - @staticmethod - @abstractmethod - def has_mamba_metadata() -> bool: - pass +class Runner: + def __init__( + self, + pretrained_model_name_or_path: str, + mapping: Mapping, + *, + load_format: str, + moe_backend: str, + layer_indices: list[int], + scaled_from: Optional[int], + max_seq_len: int, + max_num_tokens: int, + moe_max_num_tokens: int, + kv_cache_dtype, + mamba_ssm_cache_dtype: str, + use_low_precision_moe_combine: bool, + use_cuda_graph: bool, + ): + super().__init__() + + checkpoint_loader = _construct_checkpoint_loader("pytorch", None, "HF") + # Please refer to `tensorrt_llm/_torch/pyexecutor/model_loader.py` for effective args + llm_args = TorchLlmArgs( + model=pretrained_model_name_or_path, + load_format=load_format, + **{} if use_cuda_graph else {"cuda_graph_config": None}, + moe_config=MoeConfig( + backend=moe_backend, + max_num_tokens=moe_max_num_tokens, + disable_finalize_fusion=False, + use_low_precision_moe_combine=use_low_precision_moe_combine, + ), + attn_backend="TRTLLM", + kv_cache_config=KvCacheConfig( + dtype=kv_cache_dtype, mamba_ssm_cache_dtype=mamba_ssm_cache_dtype + ), + ) + model_loader = ModelLoader( + llm_args=llm_args, + mapping=mapping, + spec_config=None, + sparse_attention_config=None, + max_num_tokens=max_num_tokens, + max_seq_len=max_seq_len, + ) + + with self.scaled_from_ctx(scaled_from, mapping), self.skip_unused_layers_ctx(layer_indices): + model, _ = model_loader.load( + checkpoint_dir=pretrained_model_name_or_path, checkpoint_loader=checkpoint_loader + ) + + self.layers = [model.model.layers[i] for i in layer_indices] + self.model_config = model.model_config 
@staticmethod @contextlib.contextmanager - def scaled_from_ctx(scaled_from, mapping, pretrained_config): + def scaled_from_ctx(scaled_from, mapping): if scaled_from is None: yield return - # To run the problem size of $B$ GPUs on $A$ GPUs, we need: - # (1) Attention: If TP, reduce the number of attention heads; If DP, nothing to change. - # (2) MoE: If EP, reduce the number of experts; If TP, reduce head size. - # Maintain the result of AllToAll method selection because it is affected by EP size. - if not mapping.enable_attention_dp: - if hasattr(pretrained_config, "index_n_heads"): - raise NotImplementedError("Not support Indexer TP for weak scaling") - pretrained_config.num_attention_heads = ( - pretrained_config.num_attention_heads // scaled_from * mapping.tp_size - ) - pretrained_config.num_key_value_heads = ( - pretrained_config.num_key_value_heads // scaled_from * mapping.tp_size - ) - if mapping.moe_ep_size != mapping.world_size: - raise NotImplementedError("Not support MoE TP for weak scaling") - pretrained_config.n_routed_experts = ( - pretrained_config.n_routed_experts // scaled_from * mapping.moe_ep_size - ) + + def make_load_pretrained_config(mapping, load_pretrained_config_orig): + # To run the problem size of $B$ GPUs on $A$ GPUs, we need: + # (1) Attention: If TP, reduce the number of attention heads; If DP, nothing to change. + # (2) MoE: If EP, reduce the number of experts; If TP, reduce head size. + # Maintain the result of AllToAll method selection because it is affected by EP size. + def load_pretrained_config(*args, **kwargs): + pretrained_config = load_pretrained_config_orig(*args, **kwargs) + if not mapping.enable_attention_dp: + if hasattr(pretrained_config, "index_n_heads"): + raise NotImplementedError("Not support Indexer TP for weak scaling") + pretrained_config.num_attention_heads = ( + pretrained_config.num_attention_heads // scaled_from * mapping.tp_size + ) + pretrained_config.num_key_value_heads = ( + pretrained_config.num_key_value_heads // scaled_from * mapping.tp_size + ) + if mapping.moe_ep_size != mapping.tp_size: + raise NotImplementedError("Not support MoE TP for weak scaling") + pretrained_config.n_routed_experts = ( + pretrained_config.n_routed_experts // scaled_from * mapping.moe_ep_size + ) + return pretrained_config + + return load_pretrained_config def make_select_alltoall_method_type(select_alltoall_method_type_orig): def select_alltoall_method_type( @@ -408,6 +479,9 @@ def select_alltoall_method_type(self): select_alltoall_method_type_cutlass = CutlassFusedMoE.select_alltoall_method_type select_alltoall_method_type_trtllm_gen = TRTLLMGenFusedMoE.select_alltoall_method_type select_alltoall_method_type_wide_ep = WideEPMoE.select_alltoall_method_type + tensorrt_llm._torch.model_config.load_pretrained_config = make_load_pretrained_config( + mapping, load_pretrained_config + ) CutlassFusedMoE.select_alltoall_method_type = make_select_alltoall_method_type_2( select_alltoall_method_type_cutlass ) @@ -420,40 +494,50 @@ def select_alltoall_method_type(self): try: yield finally: + tensorrt_llm._torch.model_config.load_pretrained_config = load_pretrained_config CutlassFusedMoE.select_alltoall_method_type = select_alltoall_method_type_cutlass TRTLLMGenFusedMoE.select_alltoall_method_type = select_alltoall_method_type_trtllm_gen WideEPMoE.select_alltoall_method_type = select_alltoall_method_type_wide_ep @staticmethod - def apply_quant_config_exclude_modules(layers, quant_config): - # Please refer to tensorrt_llm/_torch/models/modeling_utils.py - 
new_quant_config = QuantConfig(kv_cache_quant_algo=quant_config.kv_cache_quant_algo) - for layer in layers: - for name, module in layer.named_modules(): - name = f"model.layers.{layer.layer_idx}.{name}" - candidates = [name] - if isinstance(module, Linear): - weight_mode = module.weights_loading_config.weight_mode - if weight_mode == WeightMode.FUSED_GATE_UP_LINEAR: - # sometimes gate and up proj are not packed in the checkpoint, - # but they still share the same exclusion rule - candidates += [ - name.replace("gate_up_proj", "gate_proj"), - name.replace("gate_up_proj", "up_proj"), - ] - elif weight_mode == WeightMode.FUSED_QKV_LINEAR: - # sometimes q_proj, k_proj and v_proj are not packed in the checkpoint, - # but they still share the same exclusion rule - candidates += [ - name.replace("qkv_proj", "q_proj"), - name.replace("qkv_proj", "k_proj"), - name.replace("qkv_proj", "v_proj"), - ] - is_excluded = any( - quant_config.is_module_excluded_from_quantization(n) for n in candidates + @contextlib.contextmanager + def skip_unused_layers_ctx(layer_indices): + call_orig = PostInitCaller.__call__ + + def call_new(cls, *args, **kwargs): + model = call_orig(cls, *args, **kwargs) + for module in ( + model.prologue + model.model.prologue + model.model.epilogue + model.epilogue + ): + skip_forward(module) + num_hidden_layers = model.model_config.pretrained_config.num_hidden_layers + if hasattr(model.model, "embed_tokens"): + skip_forward(model.model.embed_tokens) + for layer_idx in range(num_hidden_layers): + layer = model.model.layers[layer_idx] + if layer_idx not in layer_indices: + # keep next layer's input_layernorm's weights for fusion + skip_forward( + layer, + ignore_modules=[layer.input_layernorm] + if layer_idx - 1 in layer_indices + and hasattr(model.model.layers[layer_idx - 1], "next_layer_layernorm") + else None, + ) + if hasattr(model.model, "norm"): + skip_forward( + model.model.norm, + ignore_modules=[model.model.norm] + if num_hidden_layers - 1 in layer_indices + else None, ) - if is_excluded and getattr(module, "quant_config", None) is not None: - module.quant_config = new_quant_config + return model + + PostInitCaller.__call__ = call_new + try: + yield + finally: + PostInitCaller.__call__ = call_orig def create_run_pack( self, @@ -466,9 +550,8 @@ def create_run_pack( kv_cache_manager: KVCacheManager, attn_workspace: Optional[torch.Tensor] = None, ): - if self.model_config.moe_backend == "TRTLLM" and os.getenv("TRTLLM_ENABLE_PDL") != "1": - raise ValueError("Suggest to set TRTLLM_ENABLE_PDL=1 when moe_backend is TRTLLM") world_size = mpi_world_size() + pretrained_config = self.model_config.pretrained_config AttentionCls = get_attention_backend( self.model_config.attn_backend, self.model_config.sparse_attention_config ) @@ -499,7 +582,7 @@ def create_run_pack( ) attn_metadata.all_rank_num_tokens = [batch_size * seq_len_q] * world_size attn_metadata.prepare() - hidden_size = self.model_config.pretrained_config.hidden_size + hidden_size = pretrained_config.hidden_size position_ids = torch.tensor( [list(range(seq_len_kv_cache, seq_len_kv_cache + seq_len_q)) * batch_size], dtype=torch.int32, @@ -513,9 +596,14 @@ def create_run_pack( ) kwargs = {} - if self.has_mamba_metadata(): - # Please refer to `tensorrt_llm/_torch/models/modeling_qwen3_next.py` for `mamba_metadata` - mamba_metadata = Mamba2Metadata(attn_metadata.max_num_requests, chunk_size=128) + if is_nemotron_hybrid(pretrained_config) or is_qwen3_next(pretrained_config): + # Please refer to 
`tensorrt_llm/_torch/models/modeling_qwen3_next.py` for the magic number chunk_size=128 + mamba_metadata = Mamba2Metadata( + attn_metadata.max_num_requests, + chunk_size=128 + if is_qwen3_next(pretrained_config) + else pretrained_config.chunk_size, + ) mamba_metadata.prepare(attn_metadata) kwargs["mamba_metadata"] = mamba_metadata @@ -524,8 +612,15 @@ def run_pack(*, check=False): with model_extra_attrs(self.model_config.extra_attrs): get_model_extra_attrs()["attention_metadata"] = weakref.ref(attn_metadata) with torch.inference_mode(): + # TODO: to be more general, we should call DecoderModel.forward for layer in self.layers: - output = layer(position_ids, output[0], attn_metadata, output[1], **kwargs) + residual_fusion = hasattr(layer, "next_layer_layernorm") + if residual_fusion: + output = layer( + position_ids, output[0], attn_metadata, output[1], **kwargs + ) + else: + output = layer(position_ids, output[0], attn_metadata, **kwargs), None if check: if output[0].isnan().any(): raise ValueError("Has nan, please fix weights initialization") @@ -554,12 +649,20 @@ def replace_routing_method_ctx(self, balance_method: BalanceMethod, balance_rati f' please set balance_method to "NotModified"' ) original_methods = [] - dp_rank = self.model_config.mapping.rank // ( - self.model_config.mapping.world_size // self.model_config.mapping.dp_size + dp_rank = ( + self.model_config.mapping.tp_rank + if self.model_config.mapping.enable_attention_dp + else 0 ) + moe_modules = [] for layer in self.layers: - moe_module = layer.mlp.experts + if layer.__class__.__name__ == "NemotronHLayer": + if layer.layer_type == "E": + moe_modules.append(layer.mixer.experts) + else: + moe_modules.append(layer.mlp.experts) + for moe_module in moe_modules: # Replace `routing_method.apply` for normal cases apply_method_orig = moe_module.routing_method.apply moe_module.routing_method.apply = make_balanced_routing_method( @@ -579,8 +682,8 @@ def replace_routing_method_ctx(self, balance_method: BalanceMethod, balance_rati moe_module.run_moe = make_balanced_run_moe( moe_module, run_moe_orig, - layer.mlp.experts.routing_method.top_k, - layer.mlp.experts.num_experts, + moe_module.routing_method.top_k, + moe_module.num_experts, balance_method, balance_ratio, self.model_config.mapping.dp_size, @@ -598,10 +701,9 @@ def replace_routing_method_ctx(self, balance_method: BalanceMethod, balance_rati try: yield finally: - for layer, (apply_method_orig, run_moe_orig, forward_impl_orig) in zip( - self.layers, original_methods + for moe_module, (apply_method_orig, run_moe_orig, forward_impl_orig) in zip( + moe_modules, original_methods ): - moe_module = layer.mlp.experts moe_module.routing_method.apply = apply_method_orig if isinstance(moe_module, TRTLLMGenFusedMoE): moe_module.run_moe = run_moe_orig @@ -614,10 +716,14 @@ def create_kv_cache_manager( tokens_per_block, max_batch_size, max_seq_len, + kv_cache_dtype, + mamba_ssm_cache_dtype, layer_indices, ): # Please refer to `tensorrt_llm/_torch/pyexecutor/py_executor_creator.py` for `tokens_per_block` model_config = ModelConfig.from_pretrained(pretrained_model_name_or_path) + validate_and_set_kv_cache_quant(model_config, kv_cache_dtype) + validate_and_set_mamba_ssm_cache_dtype(model_config, mamba_ssm_cache_dtype) if model_config.enable_flash_mla: assert tokens_per_block == 64 @@ -628,18 +734,17 @@ def create_kv_cache_manager( max_tokens=max_batch_size * round_up(max_seq_len, tokens_per_block), enable_block_reuse=False, ) - kv_cache_dtype = torch_dtype_to_binding( - { - None: torch.bfloat16, - 
"FP8": torch.float8_e4m3fn, - }[model_config.quant_config.kv_cache_quant_algo] - ) + kv_cache_dtype = { + "FP8": tensorrt_llm.bindings.DataType.FP8, + "NVFP4": tensorrt_llm.bindings.DataType.NVFP4, + None: torch_dtype_to_binding(config.torch_dtype), + }[model_config.quant_config.kv_cache_quant_algo] if is_mla(config): layer_mask = [i in layer_indices for i in range(config.num_hidden_layers)] num_layers = sum(layer_mask) kv_cache_manager = kv_cache_manager_cls( kv_cache_config, - CacheType.SELFKONLY, + tensorrt_llm.bindings.internal.batch_manager.CacheType.SELFKONLY, num_layers=num_layers, num_kv_heads=1, head_dim=model_config.pretrained_config.kv_lora_rank @@ -649,9 +754,46 @@ def create_kv_cache_manager( max_batch_size=max_batch_size, mapping=mapping, dtype=kv_cache_dtype, + spec_config=None, layer_mask=layer_mask, sparse_attn_config=model_config.sparse_attention_config, ) + elif is_nemotron_hybrid(config): + mamba_layer_mask = [ + i in layer_indices and char == "M" + for i, char in enumerate(config.hybrid_override_pattern) + ] + layer_mask = [ + i in layer_indices and char == "*" + for i, char in enumerate(config.hybrid_override_pattern) + ] + num_mamba_layers = sum(mamba_layer_mask) + num_layers = sum(layer_mask) + kv_cache_manager = kv_cache_manager_cls( + # mamba cache parameters + config.ssm_state_size, + config.conv_kernel, + config.mamba_num_heads, + config.n_groups, + config.mamba_head_dim, + num_mamba_layers, + mamba_layer_mask, + config.torch_dtype, + model_config.quant_config.mamba_ssm_cache_dtype, + # kv cache parameters + kv_cache_config, + tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF, + num_layers=num_layers, + layer_mask=layer_mask, + num_kv_heads=config.num_key_value_heads, + head_dim=config.head_dim, + tokens_per_block=tokens_per_block, + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + mapping=mapping, + dtype=kv_cache_dtype, + spec_config=None, + ) elif is_qwen3_next(config): mamba_layer_mask = [ i in layer_indices @@ -680,7 +822,7 @@ def create_kv_cache_manager( model_config.quant_config.mamba_ssm_cache_dtype, # kv cache parameters kv_cache_config, - CacheType.SELF, + tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF, num_layers=num_layers, layer_mask=layer_mask, num_kv_heads=config.num_key_value_heads, diff --git a/tensorrt_llm/tools/layer_wise_benchmarks/runner_factory.py b/tensorrt_llm/tools/layer_wise_benchmarks/runner_factory.py deleted file mode 100644 index b45d1e8e5ba..00000000000 --- a/tensorrt_llm/tools/layer_wise_benchmarks/runner_factory.py +++ /dev/null @@ -1,13 +0,0 @@ -from tensorrt_llm._torch.pyexecutor.config_utils import load_pretrained_config - -from .deepseekv3_runner import DeepSeekV3Runner -from .qwen3_next_runner import Qwen3NextRunner - - -def get_runner_cls(pretrained_model_name_or_path: str) -> type: - pretrained_config = load_pretrained_config(pretrained_model_name_or_path) - return { - "deepseek_v3": DeepSeekV3Runner, - "deepseek_v32": DeepSeekV3Runner, - "qwen3_next": Qwen3NextRunner, - }[pretrained_config.model_type] diff --git a/tensorrt_llm/tools/layer_wise_benchmarks/runner_interface.py b/tensorrt_llm/tools/layer_wise_benchmarks/runner_interface.py deleted file mode 100644 index 9451124e20c..00000000000 --- a/tensorrt_llm/tools/layer_wise_benchmarks/runner_interface.py +++ /dev/null @@ -1,49 +0,0 @@ -from abc import ABC, abstractmethod -from enum import IntEnum -from typing import Optional - -import torch - -from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager - - -class 
BalanceMethod(IntEnum): - NotModified = 1 - Balanced = 2 - ImbalancedRanks = 3 - ImbalancedExperts = 4 - - -class RunnerBase(ABC): - @abstractmethod - def create_run_pack( - self, - run_type: str, - batch_size: int, - seq_len_q: int, - seq_len_kv_cache: int, - kv_cache_manager: KVCacheManager, - attn_workspace: Optional[torch.Tensor] = None, - ): - pass - - @abstractmethod - def replace_routing_method_ctx(self, balance_method: BalanceMethod, balance_ratio: float): - pass - - @staticmethod - @abstractmethod - def create_kv_cache_manager( - pretrained_model_name_or_path, - mapping, - tokens_per_block, - max_batch_size, - max_seq_len, - layer_indices, - ): - pass - - @staticmethod - @abstractmethod - def create_mapping(enable_attention_dp: bool): - pass diff --git a/tests/integration/test_lists/test-db/l0_b200.yml b/tests/integration/test_lists/test-db/l0_b200.yml index 98860efd0b8..d7a434d2762 100644 --- a/tests/integration/test_lists/test-db/l0_b200.yml +++ b/tests/integration/test_lists/test-db/l0_b200.yml @@ -83,6 +83,7 @@ l0_b200: - unittest/_torch/modeling -k "modeling_mixtral" - unittest/_torch/modeling -k "modeling_gpt_oss" - unittest/tools/test_layer_wise_benchmarks.py::test_deepseek_r1_ctx_dep[1] + - unittest/tools/test_layer_wise_benchmarks.py::test_nemotron_gen_dep[1] - unittest/tools/test_layer_wise_benchmarks.py::test_qwen3_next_gen_tep[1] - unittest/_torch/modeling/test_modeling_exaone4.py::TestEXAONE4::test_llm_load_1_FP8 - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_nvfp4[enable_configurable_moe-disable_finalize_fusion-TRTLLM-dtype1] diff --git a/tests/unittest/tools/test_layer_wise_benchmarks.py b/tests/unittest/tools/test_layer_wise_benchmarks.py index b5f7e58a901..23310b6854c 100644 --- a/tests/unittest/tools/test_layer_wise_benchmarks.py +++ b/tests/unittest/tools/test_layer_wise_benchmarks.py @@ -54,7 +54,6 @@ def test_deepseek_r1_ctx_tep(llm_root, world_size): **os.environ, "NP": f"{world_size:d}", "PROFILE_DIR": profile_dir, - "TRTLLM_ENABLE_PDL": "1", }, ) check_call( @@ -122,6 +121,35 @@ def test_deepseek_r1_gen_scaled_from_16_dep(llm_root, world_size): ) +@pytest.mark.parametrize("world_size", [1, 4]) +def test_nemotron_gen_dep(llm_root, world_size): + if torch.cuda.device_count() < world_size: + pytest.skip(f"needs {world_size:d} GPUs to run this test") + model_root = llm_models_root(check=True) + profile_dir = f"profiles/test_nemotron_gen_dep_{world_size}" + check_call( + [ + "./mpi_launch.sh", + "./run.sh", + "config_gen.yaml", + "--model", + model_root / "NVIDIA-Nemotron-3-Nano-30B-A3B-BF16", + "--layer-indices=4,5,6", + "--mamba-ssm-cache-dtype=float16", + ], + cwd=llm_root / "examples" / "layer_wise_benchmarks", + env={ + **os.environ, + "NP": f"{world_size:d}", + "PROFILE_DIR": profile_dir, + }, + ) + check_call( + ["python3", "parse.py", "--profile-dir", profile_dir, f"--world-size={world_size}"], + cwd=llm_root / "examples" / "layer_wise_benchmarks", + ) + + @pytest.mark.parametrize("world_size", [1, 4]) def test_qwen3_next_gen_tep(llm_root, world_size): if torch.cuda.device_count() < world_size: @@ -137,6 +165,7 @@ def test_qwen3_next_gen_tep(llm_root, world_size): model_root / "Qwen3" / "Qwen3-Next-80B-A3B-Instruct", "--layer-indices=6,7", "--no-enable-attention-dp", + "--mamba-ssm-cache-dtype=float16", "--moe-backend=TRTLLM", ], cwd=llm_root / "examples" / "layer_wise_benchmarks", @@ -144,7 +173,6 @@ def test_qwen3_next_gen_tep(llm_root, world_size): **os.environ, "NP": f"{world_size:d}", "PROFILE_DIR": profile_dir, - "TRTLLM_ENABLE_PDL": 
"1", }, ) check_call(