Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 28 additions & 6 deletions benchmarks/flashinfer_benchmark.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
import argparse
import sys

from routines.attention import parse_attention_args, run_attention_test
# Only import utilities at module level - routine modules are imported lazily
# to avoid loading unnecessary dependencies (e.g., mpi4py for non-MPI benchmarks)
from routines.flashinfer_benchmark_utils import (
benchmark_apis,
full_output_columns,
output_column_dict,
)
from routines.gemm import parse_gemm_args, run_gemm_test
from routines.moe import parse_moe_args, run_moe_test
from routines.moe_comm import parse_moe_comm_args, run_moe_comm_test
from routines.norm import parse_norm_args, run_norm_test
from routines.quantization import parse_quantization_args, run_quantization_test


def run_test(args):
Expand All @@ -23,17 +19,30 @@ def run_test(args):
"""

## Depending on routine type, route to corresponding test routine
## Imports are done lazily to avoid loading unnecessary dependencies
if args.routine in benchmark_apis["attention"]:
from routines.attention import run_attention_test

res = run_attention_test(args)
elif args.routine in benchmark_apis["gemm"]:
from routines.gemm import run_gemm_test

res = run_gemm_test(args)
elif args.routine in benchmark_apis["moe"]:
from routines.moe import run_moe_test

res = run_moe_test(args)
elif args.routine in benchmark_apis["moe_comm"]:
from routines.moe_comm import run_moe_comm_test

res = run_moe_comm_test(args)
elif args.routine in benchmark_apis["norm"]:
from routines.norm import run_norm_test

res = run_norm_test(args)
elif args.routine in benchmark_apis["quantization"]:
from routines.quantization import run_quantization_test

res = run_quantization_test(args)
else:
raise ValueError(f"Unsupported routine: {args.routine}")
Expand Down Expand Up @@ -165,17 +174,30 @@ def parse_args(line=sys.argv[1:]):
)

## Check routine and pass on to routine-specific argument parser
## Imports are done lazily to avoid loading unnecessary dependencies
if args.routine in benchmark_apis["attention"]:
from routines.attention import parse_attention_args

args = parse_attention_args(line, parser)
elif args.routine in benchmark_apis["gemm"]:
from routines.gemm import parse_gemm_args

args = parse_gemm_args(line, parser)
elif args.routine in benchmark_apis["moe"]:
from routines.moe import parse_moe_args

args = parse_moe_args(line, parser)
elif args.routine in benchmark_apis["moe_comm"]:
from routines.moe_comm import parse_moe_comm_args

args = parse_moe_comm_args(line, parser)
elif args.routine in benchmark_apis["norm"]:
from routines.norm import parse_norm_args

args = parse_norm_args(line, parser)
elif args.routine in benchmark_apis["quantization"]:
from routines.quantization import parse_quantization_args

args = parse_quantization_args(line, parser)
else:
raise ValueError(f"Unsupported routine: {args.routine}")
Expand Down
Loading