diff --git a/benchmarks/flashinfer_benchmark.py b/benchmarks/flashinfer_benchmark.py index 9cfd9ef55c..fdbc54098c 100644 --- a/benchmarks/flashinfer_benchmark.py +++ b/benchmarks/flashinfer_benchmark.py @@ -1,17 +1,13 @@ import argparse import sys -from routines.attention import parse_attention_args, run_attention_test +# Only import utilities at module level - routine modules are imported lazily +# to avoid loading unnecessary dependencies (e.g., mpi4py for non-MPI benchmarks) from routines.flashinfer_benchmark_utils import ( benchmark_apis, full_output_columns, output_column_dict, ) -from routines.gemm import parse_gemm_args, run_gemm_test -from routines.moe import parse_moe_args, run_moe_test -from routines.moe_comm import parse_moe_comm_args, run_moe_comm_test -from routines.norm import parse_norm_args, run_norm_test -from routines.quantization import parse_quantization_args, run_quantization_test def run_test(args): @@ -23,17 +19,30 @@ def run_test(args): """ ## Depending on routine type, route to corresponding test routine + ## Imports are done lazily to avoid loading unnecessary dependencies if args.routine in benchmark_apis["attention"]: + from routines.attention import run_attention_test + res = run_attention_test(args) elif args.routine in benchmark_apis["gemm"]: + from routines.gemm import run_gemm_test + res = run_gemm_test(args) elif args.routine in benchmark_apis["moe"]: + from routines.moe import run_moe_test + res = run_moe_test(args) elif args.routine in benchmark_apis["moe_comm"]: + from routines.moe_comm import run_moe_comm_test + res = run_moe_comm_test(args) elif args.routine in benchmark_apis["norm"]: + from routines.norm import run_norm_test + res = run_norm_test(args) elif args.routine in benchmark_apis["quantization"]: + from routines.quantization import run_quantization_test + res = run_quantization_test(args) else: raise ValueError(f"Unsupported routine: {args.routine}") @@ -165,17 +174,30 @@ def parse_args(line=sys.argv[1:]): ) ## Check routine and pass on to routine-specific argument parser + ## Imports are done lazily to avoid loading unnecessary dependencies if args.routine in benchmark_apis["attention"]: + from routines.attention import parse_attention_args + args = parse_attention_args(line, parser) elif args.routine in benchmark_apis["gemm"]: + from routines.gemm import parse_gemm_args + args = parse_gemm_args(line, parser) elif args.routine in benchmark_apis["moe"]: + from routines.moe import parse_moe_args + args = parse_moe_args(line, parser) elif args.routine in benchmark_apis["moe_comm"]: + from routines.moe_comm import parse_moe_comm_args + args = parse_moe_comm_args(line, parser) elif args.routine in benchmark_apis["norm"]: + from routines.norm import parse_norm_args + args = parse_norm_args(line, parser) elif args.routine in benchmark_apis["quantization"]: + from routines.quantization import parse_quantization_args + args = parse_quantization_args(line, parser) else: raise ValueError(f"Unsupported routine: {args.routine}")