diff --git a/benchmarks/prototype/moe_training/benchmark_moe_layer.py b/benchmarks/prototype/moe_training/benchmark_moe_fsdp.py
similarity index 58%
rename from benchmarks/prototype/moe_training/benchmark_moe_layer.py
rename to benchmarks/prototype/moe_training/benchmark_moe_fsdp.py
index d18c6dc176..84453fa242 100644
--- a/benchmarks/prototype/moe_training/benchmark_moe_layer.py
+++ b/benchmarks/prototype/moe_training/benchmark_moe_fsdp.py
@@ -14,8 +14,6 @@
 import argparse
 import copy
 import os
-import statistics
-from time import perf_counter_ns
 
 import pytest
 import torch
@@ -24,6 +22,11 @@
 from torch.distributed._composable.fsdp import fully_shard
 from torch.nn import functional as F
 
+from benchmarks.prototype.moe_training.utils import (
+    bench_fwd_bwd_microseconds,
+    profile_fn,
+)
+
 # this feature requires CUDA and SM89+
 if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9):
     pytest.skip(
@@ -48,8 +51,12 @@
 )
 
 
-def bench_moe_float8_training_fsdp(enable_profile=False):
+def bench_moe_float8_training_fsdp(
+    recipe_name: str, enable_profile: bool, use_compile: bool
+):
     assert torch.cuda.is_available()
+    assert recipe_name in ["fp8_rowwise", "mxfp8"]
+    recipe = MoEScalingType[recipe_name.upper()]
 
     # setup distributed for fsdp
     setup_distributed()
@@ -62,8 +69,8 @@ def bench_moe_float8_training_fsdp(enable_profile=False):
     init_std = 0.02
     device = torch.device("cuda")
 
-    # reference bf16 MoE
-    dim, hidden_dim = 5120, 4 * 5120
+    # reference bf16 MoE using llama4 shapes
+    dim, hidden_dim = 5120, 8192
     ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda()
     torch.manual_seed(42)
     ref_model.init_weights(init_std, device)
@@ -71,6 +78,10 @@ def bench_moe_float8_training_fsdp(enable_profile=False):
     # target MoE for testing conversion
     model = copy.deepcopy(ref_model)
 
+    # Token group alignment size must be 32 for mxfp8 and 16 for fp8 rowwise training
+    alignment_size = 32 if recipe == MoEScalingType.MXFP8 else 16
+    set_token_group_alignment_size_m(alignment_size)
+
     # assert starting params are identical for both models
     for param1, param2 in zip(model.parameters(), ref_model.parameters()):
         assert torch.equal(param1, param2)
@@ -83,7 +94,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         return False
 
     # quantize test model
-    config = MoETrainingConfig(scaling_type=MoEScalingType.FP8_ROWWISE)
+    config = MoETrainingConfig(scaling_type=recipe)
     quantize_(model, config=config, filter_fn=moe_module_filter_fn)
 
     # FSDP2
@@ -91,7 +102,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
     fully_shard(ref_model)
 
     # inputs (llama4 shapes)
-    batch, seq = 1, 8192
+    batch, seq = 1, 16640
     ref_x = torch.randn(
         batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
     )
@@ -104,70 +115,34 @@ def warmup(model, input):
         loss.backward()
         torch.cuda.synchronize()
 
-    def bench_fn_microseconds(model, input):
-        labels = torch.ones_like(input)
-        times = []
-        for _ in range(10):
-            start_ns = perf_counter_ns()
-            out = model(input)
-            loss = F.mse_loss(out, labels)
-            loss.backward()
-            torch.cuda.synchronize()
-            end_ns = perf_counter_ns()
-            duration_us = (end_ns - start_ns) / 1000
-            times.append(duration_us)
-        return statistics.median(times)
-
-    def profile_fn(model, input, profile_name="profile"):
-        # Only profile on rank 0
-        if torch.distributed.get_rank() == 0:
-            labels = torch.ones_like(input)
-            wait, warmup, active = 1, 3, 1
-            total_steps = wait + warmup + active
-            with torch.profiler.profile(
-                activities=[
-                    torch.profiler.ProfilerActivity.CPU,
-                    torch.profiler.ProfilerActivity.CUDA,
-                ],
-                schedule=torch.profiler.schedule(
-                    wait=wait, warmup=warmup, active=active, repeat=0
-                ),
-                record_shapes=True,
-                with_stack=True,
-            ) as prof:
-                for _ in range(total_steps):
-                    out = model(input)
-                    loss = F.mse_loss(out, labels)
-                    loss.backward()
-                    prof.step()
-
-            # Save profiler results
-            prof.export_chrome_trace(f"{profile_name}.json")
-            print(f"Saved: {profile_name}.json")
-
-    # Compile models
-    ref_model = torch.compile(ref_model, fullgraph=False)
-    model = torch.compile(model, fullgraph=False)
-
-    print("Benchmarking MoE with FSDP2 using bf16 training")
-    warmup(ref_model, ref_x)
-    bf16_us = bench_fn_microseconds(ref_model, ref_x)
-    print(f"bf16 time: {bf16_us} us")
-    if enable_profile:
-        print("Profiling bf16 model")
-        profile_fn(ref_model, ref_x, profile_name="bf16_profile")
+    labels = torch.ones_like(x)
 
-    # Token group alignment size must be 16 for fp8 rowwise training
-    set_token_group_alignment_size_m(16)
-
-    print("Benchmarking MoE with FSDP2 using fp8 rowwise training")
-    warmup(model, x)
-    fp8_us = bench_fn_microseconds(model, x)
-    print(f"fp8 time: {fp8_us} us")
+    # TODO: bench with fullgraph=True if/when it is supported
+    bf16_us = bench_fwd_bwd_microseconds(
+        ref_model,
+        ref_x,
+        labels=labels,
+        use_compile=use_compile,
+        fullgraph=False,
+    )
+    print(f"BF16 time: {bf16_us} us")
+    if enable_profile:
+        print("Profiling bf16 training")
+        profile_fn(ref_model, ref_x, labels=labels, profile_name="bf16_profile")
+
+    scaled_us = bench_fwd_bwd_microseconds(
+        model,
+        x,
+        labels=labels,
+        use_compile=use_compile,
+        fullgraph=False,
+    )
+    print(f"Scaled time: {scaled_us} us")
     if enable_profile:
-        print("Profiling fp8 model")
-        profile_fn(model, x, profile_name="fp8_profile")
+        print("Profiling quantized training")
+        profile_fn(model, x, labels=labels, profile_name=f"{recipe_name}_profile")
+    print(f"Speedup: {bf16_us / scaled_us:.3f}x")
 
     dist.destroy_process_group()
 
 
@@ -185,5 +160,15 @@ def setup_distributed():
         action="store_true",
         help="Enable PyTorch profiling and save results to file",
     )
+    parser.add_argument("--recipe", type=str, help="[fp8_rowwise, mxfp8]")
+    parser.add_argument(
+        "--compile",
+        action="store_true",
+        help="use torch.compile",
+    )
     args = parser.parse_args()
-    bench_moe_float8_training_fsdp(enable_profile=args.profile)
+    bench_moe_float8_training_fsdp(
+        recipe_name=args.recipe,
+        enable_profile=args.profile,
+        use_compile=args.compile,
+    )
diff --git a/benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm.py b/benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py
similarity index 87%
rename from benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm.py
rename to benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py
index 9b615e5b8d..e95f4293be 100644
--- a/benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm.py
+++ b/benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py
@@ -12,7 +12,7 @@
 import torch
 from tabulate import tabulate
 from tqdm import tqdm
-from utils import bench_fwd_bwd_microseconds
+from utils import bench_fwd_bwd_microseconds, profile_fn
 
 from torchao.prototype.moe_training import _scaled_grouped_mm
 from torchao.prototype.moe_training.conversion_utils import MoEScalingType
@@ -47,7 +47,7 @@ class Experiment:
 
 def get_configs() -> List[ExperimentConfig]:
     A_shapes = [(16640, 5120)]
-    B_shapes = [(16, 8192, 5120), (128, 8192, 5120)]
+    B_shapes = [(16, 8192, 5120)]
     recipes = [MoEScalingType.FP8_ROWWISE]
     high_precision_dtypes = [torch.bfloat16]
     configs = []
@@ -106,6 +106,16 @@ def run_experiment(
         labels=labels,
         use_compile=args.compile,
     )
+    if args.profile:
+        profile_fn(
+            torch._grouped_mm,
+            A,
+            B_t,
+            offs,
+            labels=labels,
+            use_compile=args.compile,
+            profile_name="bf16_profile",
+        )
 
     # benchmark scaled grouped mm with dynamic fp8 rowwise quant
     fp8_us = bench_fwd_bwd_microseconds(
@@ -117,6 +127,17 @@ def run_experiment(
         labels=labels,
         use_compile=args.compile,
     )
+    if args.profile:
+        profile_fn(
+            _scaled_grouped_mm,
+            A,
+            B_t,
+            offs,
+            scaling_type=config.recipe,
+            labels=labels,
+            use_compile=args.compile,
+            profile_name="scaled_profile",
+        )
 
     return ExperimentResult(
         bf16_us=round(bf16_us, 3),
@@ -164,5 +185,6 @@ def main(args: argparse.Namespace):
 if __name__ == "__main__":
     arg_parser = argparse.ArgumentParser()
     arg_parser.add_argument("--compile", action="store_true")
+    arg_parser.add_argument("--profile", action="store_true")
     args = arg_parser.parse_args()
     main(args)
diff --git a/benchmarks/prototype/moe_training/utils.py b/benchmarks/prototype/moe_training/utils.py
index d6c5e7e82f..13f0dc9c6e 100644
--- a/benchmarks/prototype/moe_training/utils.py
+++ b/benchmarks/prototype/moe_training/utils.py
@@ -5,9 +5,11 @@
 from torch.nn import functional as F
 
 
-def bench_fwd_bwd_microseconds(fn, *args, labels=None, use_compile=False, **kwargs):
+def bench_fwd_bwd_microseconds(
+    fn, *args, labels=None, use_compile=False, fullgraph=True, **kwargs
+):
     assert labels is not None
-    fn = torch.compile(fn, fullgraph=False) if use_compile else fn
+    fn = torch.compile(fn, fullgraph=fullgraph) if use_compile else fn
     times = []
     for _ in range(10):
         start_ns = perf_counter_ns()
@@ -19,3 +21,38 @@ def bench_fwd_bwd_microseconds(fn, *args, labels=None, use_compile=False, **kwar
         duration_us = (end_ns - start_ns) / 1000
         times.append(duration_us)
     return statistics.median(times)
+
+
+def profile_fn(
+    fn,
+    *args,
+    labels=None,
+    use_compile=False,
+    fullgraph=True,
+    profile_name="profile",
+    **kwargs,
+):
+    assert labels is not None
+    fn = torch.compile(fn, fullgraph=fullgraph) if use_compile else fn
+    wait, warmup, active = 1, 3, 1
+    total_steps = wait + warmup + active
+    with torch.profiler.profile(
+        activities=[
+            torch.profiler.ProfilerActivity.CPU,
+            torch.profiler.ProfilerActivity.CUDA,
+        ],
+        schedule=torch.profiler.schedule(
+            wait=wait, warmup=warmup, active=active, repeat=0
+        ),
+        record_shapes=True,
+        with_stack=True,
+    ) as prof:
+        for _ in range(total_steps):
+            out = fn(*args, **kwargs)
+            loss = F.mse_loss(out, labels)
+            loss.backward()
+            prof.step()
+
+    # Save profiler results
+    prof.export_chrome_trace(f"{profile_name}.json")
+    print(f"Saved: {profile_name}.json")
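For reference, a minimal sketch of how the shared helpers added in `benchmarks/prototype/moe_training/utils.py` can be driven standalone. The toy `nn.Linear` model and tensor shapes below are illustrative assumptions, not part of this PR; only the helper signatures (`bench_fwd_bwd_microseconds`, `profile_fn`) are taken from the diff above:

```python
# Hypothetical usage sketch of the new benchmark helpers.
# The toy model, shapes, and CUDA device are assumptions for illustration.
import torch
from torch import nn

from benchmarks.prototype.moe_training.utils import (
    bench_fwd_bwd_microseconds,
    profile_fn,
)

model = nn.Linear(512, 512, dtype=torch.bfloat16, device="cuda")
x = torch.randn(64, 512, dtype=torch.bfloat16, device="cuda", requires_grad=True)
labels = torch.ones_like(x)

# Median forward+backward latency in microseconds over 10 iterations
us = bench_fwd_bwd_microseconds(model, x, labels=labels, use_compile=False)
print(f"median fwd+bwd: {us} us")

# Exports a chrome trace to toy_profile.json via torch.profiler
profile_fn(model, x, labels=labels, profile_name="toy_profile")
```

The FSDP benchmark itself presumably runs under a distributed launcher, e.g. `torchrun --nproc_per_node=<N> benchmarks/prototype/moe_training/benchmark_moe_fsdp.py --recipe mxfp8 --compile --profile`; the flag names come from the argparse setup in the diff, while the launcher invocation is an assumption.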