diff --git a/benchmarks/float8/bench_grouped_mm.py b/benchmarks/float8/bench_grouped_mm.py index 5b0bea1822..1bded14c44 100644 --- a/benchmarks/float8/bench_grouped_mm.py +++ b/benchmarks/float8/bench_grouped_mm.py @@ -64,7 +64,7 @@ def run( # Run bf16 torch._grouped_mm baseline. A = torch.randn(M, K, device=device, dtype=dtype) - B = torch.randn(E, K, N, device=device, dtype=dtype) + B = torch.randn(E, N, K, device=device, dtype=dtype) offs = generate_jagged_offs(E, M) print(f"offs: {offs}") ref_time_sec, ref_tops_sec, ref_pct_top_peak = do_benchmarks( @@ -73,7 +73,7 @@ def run( use_gpu_kernel_time, torch._grouped_mm, A, - B, + B.transpose(-2, -1), offs, ) print( @@ -84,12 +84,7 @@ def run( # Run scaled_grouped_mm. A_hp = torch.randn(M, K, device=device) - B_hp_t = ( - torch.randn(E, K, N, device=device) - .transpose(-2, -1) - .contiguous() - .transpose(-2, -1) - ) + B_hp_t = torch.randn(E, N, K, device=device).transpose(-2, -1) if recipe == "rowwise": # TODO: add e5m2 diff --git a/benchmarks/float8/utils.py b/benchmarks/float8/utils.py index d4cdfeef20..744bbcad0d 100644 --- a/benchmarks/float8/utils.py +++ b/benchmarks/float8/utils.py @@ -219,7 +219,7 @@ def get_name_to_moe_shapes_iter( N: Optional[int] = None, E: Optional[int] = None, ): - M = 8192 if M is None else M + M = 16640 if M is None else M if shape_gen_name == "llama4_17bx16e": # num_experts=16, dim=5120 names_to_shapes = { @@ -232,8 +232,8 @@ def get_name_to_moe_shapes_iter( # num_experts=128, dim=5120 names_to_shapes = { # M, K, N, E - "moe.experts.w1": (M, 5120, 8192, 128), - "moe.experts.w2": (M, 8192, 5120, 128), + "moe.experts.w1": (M, 5120, 4 * 5120, 128), + "moe.experts.w2": (M, 4 * 5120, 5120, 128), } return names_to_shapes.items() elif shape_gen_name == "custom": diff --git a/benchmarks/prototype/moe_training/benchmark_kernels.py b/benchmarks/prototype/moe_training/benchmark_kernels.py deleted file mode 100644 index d9e79c6cf3..0000000000 --- a/benchmarks/prototype/moe_training/benchmark_kernels.py +++ /dev/null @@ -1,190 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. 
-# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py - -import itertools -from dataclasses import dataclass -from typing import List - -import torch -from tabulate import tabulate -from tqdm import tqdm -from triton.testing import do_bench - -from torchao.prototype.moe_training.kernels.jagged_float8_scales import ( - triton_fp8_col_major_jagged_colwise_scales, - triton_fp8_row_major_jagged_rowwise_scales, -) -from torchao.prototype.moe_training.utils import ( - torch_to_float8_per_group_colwise, - torch_to_float8_per_group_rowwise, -) - -device = torch.device("cuda") - -# Needed since changing args to function causes recompiles -torch._dynamo.config.cache_size_limit = 1000 - - -@dataclass(frozen=True) -class ExperimentConfig: - high_precision_dtype: torch.dtype - input_shape: tuple[int] - n_groups: int - - -@dataclass(frozen=True) -class ExperimentResult: - torch_time_us: float - triton_time_us: float - - -@dataclass(frozen=True) -class Experiment: - config: ExperimentConfig - result: ExperimentResult - - -def get_configs() -> List[ExperimentConfig]: - input_shapes = [(2**8, 4096), (2**12, 4096), (2**16, 4096)] - n_groups_list = [4, 8, 16] - high_precision_dtypes = [torch.bfloat16] - configs = [] - for input_shape, n_groups, high_precision_dtype in itertools.product( - input_shapes, n_groups_list, high_precision_dtypes - ): - configs.append( - ExperimentConfig( - input_shape=input_shape, - n_groups=n_groups, - high_precision_dtype=high_precision_dtype, - ) - ) - return configs - - -def run_experiment(config: ExperimentConfig) -> ExperimentResult: - # define test inputs - input_tensor = torch.randn( - *config.input_shape, - dtype=config.high_precision_dtype, - device=device, - ) - input_row_major = input_tensor.clone().detach() - input_col_major = input_tensor.clone().detach().t() - - # - configure input to be row-major with groups divided along the column dimension, - # representing the left operand of grad_weight = grad_output_t @ input - # that occurs in the backward pass of the differentiable scaled grouped mm. - # - the transposed tensor in col-major format with groups along the row dimension, - # which represents the right operand. 
- group_size = input_row_major.shape[1] // config.n_groups - n_groups = config.n_groups - offs = torch.arange( - group_size, - group_size * n_groups + 1, - group_size, - device=device, - dtype=torch.int32, - ) - - def warmup(func, *args, **kwargs): - for _ in range(10): - func(*args, **kwargs) - - def run_torch( - input_row_major: torch.Tensor, input_col_major: torch.Tensor, offs: torch.Tensor - ): - _ = torch_to_float8_per_group_rowwise( - input_row_major, - offs, - target_dtype=torch.float8_e4m3fn, - round_scales_to_power_of_2=True, - ) - _ = torch_to_float8_per_group_colwise( - input_col_major, - offs, - target_dtype=torch.float8_e4m3fn, - round_scales_to_power_of_2=True, - ) - - def run_triton( - input_row_major: torch.Tensor, input_col_major: torch.Tensor, offs: torch.Tensor - ): - _ = triton_fp8_row_major_jagged_rowwise_scales( - input_row_major, - offs, - output_dtype=torch.float8_e4m3fn, - round_scales_to_power_of_2=True, - ) - _ = triton_fp8_col_major_jagged_colwise_scales( - input_col_major, - offs, - output_dtype=torch.float8_e4m3fn, - round_scales_to_power_of_2=True, - ) - - # bench torch - compiled_run_torch = torch.compile(run_torch) - torch_time_us = benchmark_cuda_function_in_microseconds( - compiled_run_torch, input_row_major, input_col_major, offs - ) - - # bench triton - warmup(run_triton, input_row_major, input_col_major, offs) - triton_time_us = benchmark_cuda_function_in_microseconds( - run_triton, input_row_major, input_col_major, offs - ) - - return ExperimentResult( - torch_time_us=torch_time_us, - triton_time_us=triton_time_us, - ) - - -def print_results(experiments: List[Experiment]): - headers = [ - "input_shape", - "n_groups", - "high_precision_dtype", - "torch_time_us", - "triton_time_us", - ] - rows = [] - for experiment in experiments: - input_shape = ( - f"({experiment.config.input_shape[0]}, {experiment.config.input_shape[1]})" - ) - rows.append( - [ - input_shape, - experiment.config.n_groups, - experiment.config.high_precision_dtype, - experiment.result.torch_time_us, - experiment.result.triton_time_us, - ] - ) - print(tabulate(rows, headers=headers)) - - -def benchmark_cuda_function_in_microseconds(f, *args): - return do_bench(lambda: f(*args), return_mode="median") * 1e3 - - -def main(): - torch.random.manual_seed(123) - configs = get_configs() - results = [] - for config in tqdm(configs): - result = run_experiment(config) - results.append(Experiment(config=config, result=result)) - - # Use Tabulate to print results - print_results(results) - - -if __name__ == "__main__": - main() diff --git a/benchmarks/prototype/moe_training/benchmark_moe_layer.py b/benchmarks/prototype/moe_training/benchmark_moe_layer.py index 549aae5a5e..d18c6dc176 100644 --- a/benchmarks/prototype/moe_training/benchmark_moe_layer.py +++ b/benchmarks/prototype/moe_training/benchmark_moe_layer.py @@ -30,16 +30,18 @@ "CUDA not available or compute capability < 8.9", allow_module_level=True ) -from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig +from torchao.prototype.moe_training.conversion_utils import ( + MoEScalingType, + MoETrainingConfig, +) from torchao.quantization.quant_api import quantize_ -# this test requires torchtitan +# this benchmark requires torchtitan try: - from torchtitan.experiments.llama4.infra.expert_parallel import ( + from torchtitan.distributed.expert_parallel import ( set_token_group_alignment_size_m, ) - from torchtitan.experiments.llama4.model.args import TransformerModelArgs - from torchtitan.experiments.llama4.model.moe import 
MoE + from torchtitan.models.moe import MoE, MoEArgs except ImportError: pytest.skip( "torchtitan not installed, skipping MoE tests.", allow_module_level=True @@ -54,16 +56,15 @@ def bench_moe_float8_training_fsdp(enable_profile=False): # define model args target_fqns = ["experts"] - model_args = TransformerModelArgs( - moe_enabled=True, + model_args = MoEArgs( num_experts=16, - dim=5120, ) init_std = 0.02 device = torch.device("cuda") # reference bf16 MoE - ref_model = MoE(model_args).to(torch.bfloat16).cuda() + dim, hidden_dim = 5120, 4 * 5120 + ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda() torch.manual_seed(42) ref_model.init_weights(init_std, device) @@ -82,7 +83,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: return False # quantize test model - config = MoETrainingConfig() + config = MoETrainingConfig(scaling_type=MoEScalingType.FP8_ROWWISE) quantize_(model, config=config, filter_fn=moe_module_filter_fn) # FSDP2 @@ -90,12 +91,19 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: fully_shard(ref_model) # inputs (llama4 shapes) - batch, seq, dim = 1, 8192, 5120 + batch, seq = 1, 8192 ref_x = torch.randn( batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device ) x = ref_x.detach().clone().requires_grad_(True) + def warmup(model, input): + for _ in range(3): + out = model(input) + loss = F.mse_loss(out, torch.ones_like(out)) + loss.backward() + torch.cuda.synchronize() + def bench_fn_microseconds(model, input): labels = torch.ones_like(input) times = [] @@ -142,6 +150,7 @@ def profile_fn(model, input, profile_name="profile"): model = torch.compile(model, fullgraph=False) print("Benchmarking MoE with FSDP2 using bf16 training") + warmup(ref_model, ref_x) bf16_us = bench_fn_microseconds(ref_model, ref_x) print(f"bf16 time: {bf16_us} us") if enable_profile: @@ -152,6 +161,7 @@ def profile_fn(model, input, profile_name="profile"): set_token_group_alignment_size_m(16) print("Benchmarking MoE with FSDP2 using fp8 rowwise training") + warmup(model, x) fp8_us = bench_fn_microseconds(model, x) print(f"fp8 time: {fp8_us} us") if enable_profile: diff --git a/benchmarks/prototype/moe_training/benchmark_per_group_scaling_kernels.py b/benchmarks/prototype/moe_training/benchmark_per_group_scaling_kernels.py index d9e79c6cf3..f180bb15ac 100644 --- a/benchmarks/prototype/moe_training/benchmark_per_group_scaling_kernels.py +++ b/benchmarks/prototype/moe_training/benchmark_per_group_scaling_kernels.py @@ -15,8 +15,8 @@ from triton.testing import do_bench from torchao.prototype.moe_training.kernels.jagged_float8_scales import ( - triton_fp8_col_major_jagged_colwise_scales, - triton_fp8_row_major_jagged_rowwise_scales, + triton_fp8_per_group_colwise_scales, + triton_fp8_per_group_rowwise_scales, ) from torchao.prototype.moe_training.utils import ( torch_to_float8_per_group_colwise, @@ -49,8 +49,8 @@ class Experiment: def get_configs() -> List[ExperimentConfig]: - input_shapes = [(2**8, 4096), (2**12, 4096), (2**16, 4096)] - n_groups_list = [4, 8, 16] + input_shapes = [(16640, 5120)] # (Mg, K) + n_groups_list = [16, 128] high_precision_dtypes = [torch.bfloat16] configs = [] for input_shape, n_groups, high_precision_dtype in itertools.product( @@ -114,13 +114,13 @@ def run_torch( def run_triton( input_row_major: torch.Tensor, input_col_major: torch.Tensor, offs: torch.Tensor ): - _ = triton_fp8_row_major_jagged_rowwise_scales( + _ = triton_fp8_per_group_rowwise_scales( input_row_major, offs, output_dtype=torch.float8_e4m3fn, 
round_scales_to_power_of_2=True, ) - _ = triton_fp8_col_major_jagged_colwise_scales( + _ = triton_fp8_per_group_colwise_scales( input_col_major, offs, output_dtype=torch.float8_e4m3fn, @@ -129,6 +129,7 @@ def run_triton( # bench torch compiled_run_torch = torch.compile(run_torch) + warmup(compiled_run_torch, input_row_major, input_col_major, offs) torch_time_us = benchmark_cuda_function_in_microseconds( compiled_run_torch, input_row_major, input_col_major, offs ) @@ -152,6 +153,7 @@ def print_results(experiments: List[Experiment]): "high_precision_dtype", "torch_time_us", "triton_time_us", + "triton_speedup", ] rows = [] for experiment in experiments: @@ -165,6 +167,7 @@ def print_results(experiments: List[Experiment]): experiment.config.high_precision_dtype, experiment.result.torch_time_us, experiment.result.triton_time_us, + f"{experiment.result.torch_time_us / experiment.result.triton_time_us:.2f}x", ] ) print(tabulate(rows, headers=headers)) diff --git a/benchmarks/prototype/moe_training/benchmark_rowwise_3d_quant_kernels.py b/benchmarks/prototype/moe_training/benchmark_rowwise_3d_quant_kernels.py index 0cdb1c4957..53518ba491 100644 --- a/benchmarks/prototype/moe_training/benchmark_rowwise_3d_quant_kernels.py +++ b/benchmarks/prototype/moe_training/benchmark_rowwise_3d_quant_kernels.py @@ -46,8 +46,11 @@ class Experiment: def get_configs() -> List[ExperimentConfig]: - # Llama4 and DeepSeekV3 shapes - input_shapes = [(8, 4096, 1024), (16, 5120 * 4, 5120)] + # Llama4 shapes + input_shapes = [ + (16, 8192, 5120), # w1, w3 + (16, 5120, 8192), # w2 + ] high_precision_dtypes = [torch.bfloat16] configs = [] for input_shape, high_precision_dtype in itertools.product( @@ -84,12 +87,13 @@ def run_torch(input_tensor: torch.Tensor): return out def run_triton(input_tensor: torch.Tensor): - _ = triton_fp8_rowwise_3d_transpose_rhs( + out = triton_fp8_rowwise_3d_transpose_rhs( input_tensor, output_dtype=torch.float8_e4m3fn, round_scales_to_power_of_2=True, ) torch.cuda.synchronize() + return out # bench torch compiled_run_torch = torch.compile(run_torch) @@ -117,6 +121,7 @@ def print_results(experiments: List[Experiment]): "input_shape", "torch_time_us", "triton_time_us", + "triton_speedup", ] rows = [] for experiment in experiments: @@ -126,6 +131,7 @@ def print_results(experiments: List[Experiment]): input_shape, experiment.result.torch_time_us, experiment.result.triton_time_us, + f"{experiment.result.torch_time_us / experiment.result.triton_time_us:.2f}x", ] ) print(tabulate(rows, headers=headers)) diff --git a/benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm.py b/benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm.py index c229eaeb71..120a859355 100644 --- a/benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm.py +++ b/benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm.py @@ -6,15 +6,16 @@ # this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py import argparse import itertools -import time from dataclasses import dataclass from typing import List import torch from tabulate import tabulate from tqdm import tqdm +from utils import bench_fwd_bwd_microseconds from torchao.prototype.moe_training import _scaled_grouped_mm +from torchao.prototype.moe_training.conversion_utils import MoEScalingType device = torch.device("cuda") @@ -27,11 +28,14 @@ class ExperimentConfig: high_precision_dtype: torch.dtype A_shape: tuple[int] B_shape: tuple[int] + 
recipe: MoEScalingType @dataclass(frozen=True) class ExperimentResult: - time_us: float + bf16_us: float + fp8_us: float + fp8_speedup: float @dataclass(frozen=True) @@ -41,19 +45,22 @@ class Experiment: def get_configs() -> List[ExperimentConfig]: - A_shapes = [(2**8, 8192), (2**12, 8192), (2**16, 8192)] - B_shapes = [(4, 8192, 8192), (8, 8192, 8192), (16, 8192, 8192)] + A_shapes = [(16640, 5120)] + B_shapes = [(16, 8192, 5120)] + recipes = [MoEScalingType.FP8_ROWWISE] high_precision_dtypes = [torch.bfloat16] configs = [] - for A_shape, B_shape, high_precision_dtype in itertools.product( + for A_shape, B_shape, recipe, high_precision_dtype in itertools.product( A_shapes, B_shapes, + recipes, high_precision_dtypes, ): configs.append( ExperimentConfig( A_shape=A_shape, B_shape=B_shape, + recipe=recipe, high_precision_dtype=high_precision_dtype, ) ) @@ -92,30 +99,35 @@ def run_experiment( dtype=torch.int32, ) - def warmup(func, *args, **kwargs): - for _ in range(10): - func(*args, **kwargs) + labels = torch.ones( + (A.shape[0], B_t.shape[-1]), device=device, dtype=torch.bfloat16 + ) - def forward_backward(A, B_t, offs): - out = _scaled_grouped_mm( - A, - B_t, - offs=offs, - out_dtype=torch.bfloat16, - ) - out.sum().backward() - torch.cuda.synchronize() + # benchmark bf16 grouped mm + bf16_us = bench_fwd_bwd_microseconds( + torch._grouped_mm, + A, + B_t, + offs, + labels=labels, + use_compile=args.compile, + ) - # benchmark torch - torch_func = torch.compile(forward_backward) if args.compile else forward_backward - warmup(torch_func, A, B_t, offs) - start_time_ns = time.perf_counter_ns() - torch_func(A, B_t, offs) - torch_time_ns = time.perf_counter_ns() - start_time_ns - time_us = torch_time_ns / 1e3 + # benchmark scaled grouped mm with dynamic fp8 rowwise quant + fp8_us = bench_fwd_bwd_microseconds( + _scaled_grouped_mm, + A, + B_t, + offs, + scaling_type=config.recipe, + labels=labels, + use_compile=args.compile, + ) return ExperimentResult( - time_us=round(time_us, 3), + bf16_us=round(bf16_us, 3), + fp8_us=round(fp8_us, 3), + fp8_speedup=round(bf16_us / fp8_us, 3), ) @@ -123,7 +135,9 @@ def print_results(experiments: List[Experiment]): headers = [ "A_shape", "B_shape", - "time_us", + "bf16_time_us", + "fp8_time_us", + "fp8_speedup", ] rows = [] for experiment in experiments: @@ -133,7 +147,9 @@ def print_results(experiments: List[Experiment]): [ A_shape, B_shape, - experiment.result.time_us, + experiment.result.bf16_us, + experiment.result.fp8_us, + f"{experiment.result.fp8_speedup}x", ] ) print(tabulate(rows, headers=headers)) diff --git a/benchmarks/prototype/moe_training/utils.py b/benchmarks/prototype/moe_training/utils.py new file mode 100644 index 0000000000..d6c5e7e82f --- /dev/null +++ b/benchmarks/prototype/moe_training/utils.py @@ -0,0 +1,21 @@ +import statistics +from time import perf_counter_ns + +import torch +from torch.nn import functional as F + + +def bench_fwd_bwd_microseconds(fn, *args, labels=None, use_compile=False, **kwargs): + assert labels is not None + fn = torch.compile(fn, fullgraph=False) if use_compile else fn + times = [] + for _ in range(10): + start_ns = perf_counter_ns() + out = fn(*args, **kwargs) + loss = F.mse_loss(out, labels) + loss.backward() + torch.cuda.synchronize() + end_ns = perf_counter_ns() + duration_us = (end_ns - start_ns) / 1000 + times.append(duration_us) + return statistics.median(times) diff --git a/test/prototype/moe_training/test_fsdp.py b/test/prototype/moe_training/test_fsdp.py index 69c15e2253..b205675527 100644 --- 
a/test/prototype/moe_training/test_fsdp.py
+++ b/test/prototype/moe_training/test_fsdp.py
@@ -35,8 +35,10 @@
 
 # this test requires torchtitan
 try:
-    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
-    from torchtitan.experiments.llama4.model.moe import MoE
+    from torchtitan.distributed.expert_parallel import (
+        set_token_group_alignment_size_m,
+    )
+    from torchtitan.models.moe import MoE, MoEArgs
 except ImportError:
     pytest.skip(
         "torchtitan not installed, skipping MoE tests.", allow_module_level=True
@@ -49,18 +51,20 @@ def test_moe_float8_training_fsdp():
     # setup distributed for fsdp
     setup_distributed()
 
+    # token group alignment size must be 16 for fp8
+    set_token_group_alignment_size_m(16)
+
     # define model args
     target_fqns = ["experts"]
-    model_args = TransformerModelArgs(
-        moe_enabled=True,
+    model_args = MoEArgs(
         num_experts=8,
-        dim=256,
     )
     init_std = 0.02
     device = torch.device("cuda")
 
     # reference bf16 MoE
-    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
+    dim, hidden_dim = 5120, 4 * 5120
+    ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda()
     torch.manual_seed(42)
     ref_model.init_weights(init_std, device)
 
@@ -93,7 +97,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
     fully_shard(ref_model)
 
     # inputs
-    batch, seq, dim = 8, 2048, 256
+    batch, seq = 8, 2048
     ref_x = torch.randn(
         batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
     )
@@ -105,7 +109,10 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
 
     # validate output
     out_sqnr = compute_error(out, ref_out)
-    assert out_sqnr.item() >= 30.0, f"SQNR must be >= 30.0, got {out_sqnr.item()}."
+    min_out_sqnr = 29.0
+    assert out_sqnr.item() >= min_out_sqnr, (
+        f"SQNR must be >= {min_out_sqnr}, got {out_sqnr.item()}."
+    )
 
     # compute loss
     labels = torch.ones_like(ref_out)
@@ -118,15 +125,17 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
 
     # validate input gradient
     input_grad_sqnr = compute_error(x.grad, ref_x.grad)
-    assert input_grad_sqnr.item() >= 30.0, (
-        f"SQNR must be >= 30.0, got {input_grad_sqnr.item()}."
+    min_input_grad_sqnr = 29.0
+    assert input_grad_sqnr.item() >= min_input_grad_sqnr, (
+        f"SQNR must be >= {min_input_grad_sqnr}, got {input_grad_sqnr.item()}."
     )
 
     # validate param gradients
+    min_param_grad_sqnr = 23.0
     for param1, param2 in zip(model.parameters(), ref_model.parameters()):
         param_grad_sqnr = compute_error(param1.grad, param2.grad)
-        assert param_grad_sqnr.item() >= 25.0, (
-            f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}."
+        assert param_grad_sqnr.item() >= min_param_grad_sqnr, (
+            f"SQNR must be >= {min_param_grad_sqnr}, got {param_grad_sqnr.item()}."
        )
 
     dist.destroy_process_group()
diff --git a/test/prototype/moe_training/test_fsdp_tp.py b/test/prototype/moe_training/test_fsdp_tp.py
index 083d9de1b9..4a7c1356c0 100644
--- a/test/prototype/moe_training/test_fsdp_tp.py
+++ b/test/prototype/moe_training/test_fsdp_tp.py
@@ -49,14 +49,14 @@
 
 # this test requires torchtitan
 try:
-    from torchtitan.experiments.llama4.infra.expert_parallel import (
+    from torchtitan.distributed.expert_parallel import (
         ExpertParallel,
         ExpertTensorParallel,
         NoParallel,
         TensorParallel,
+        set_token_group_alignment_size_m,
     )
-    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
-    from torchtitan.experiments.llama4.model.moe import MoE
+    from torchtitan.models.moe import MoE, MoEArgs
 except ImportError:
     pytest.skip(
         "torchtitan not installed, skipping MoE tests.", allow_module_level=True
@@ -74,21 +74,22 @@
 def test_moe_float8_training_fsdp_tp(target_fqns: list[str]):
     assert torch.cuda.is_available()
 
+    # token group alignment size must be 16 for fp8
+    set_token_group_alignment_size_m(16)
+
     # setup distributed for tp
     mesh = setup_distributed()
 
     # define model args
-    model_args = TransformerModelArgs(
-        moe_enabled=True,
+    model_args = MoEArgs(
         num_experts=8,
-        dim=256,
-        vocab_size=1024,
     )
+    dim, hidden_dim = 5120, 4 * 5120
     init_std = 0.02
     device = torch.device("cuda")
 
     # reference bf16 MoE
-    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
+    ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda()
     torch.manual_seed(1)
     ref_model.init_weights(init_std, device)
 
@@ -146,7 +147,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
     )
 
     # inputs
-    batch, seq, dim = 8, 2048, 256
+    batch, seq = 8, 2048
     ref_x = torch.randn(
         batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
     )
@@ -158,7 +159,10 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
 
     # validate output
     out_sqnr = compute_error(out, ref_out)
-    assert out_sqnr.item() >= 30.0, f"SQNR must be >= 30.0, got {out_sqnr.item()}."
+    min_out_sqnr = 30.0
+    assert out_sqnr.item() >= min_out_sqnr, (
+        f"SQNR must be >= {min_out_sqnr}, got {out_sqnr.item()}."
+    )
 
     # compute loss
     labels = torch.ones_like(ref_out)
@@ -171,15 +175,17 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
 
     # validate input gradient
     input_grad_sqnr = compute_error(x.grad, ref_x.grad)
-    assert input_grad_sqnr.item() >= 28.0, (
-        f"SQNR must be >= 28.0, got {input_grad_sqnr.item()}."
+    min_input_grad_sqnr = 28.0
+    assert input_grad_sqnr.item() >= min_input_grad_sqnr, (
+        f"SQNR must be >= {min_input_grad_sqnr}, got {input_grad_sqnr.item()}."
     )
 
     # validate param gradients
+    min_param_grad_sqnr = 23.0
     for param1, param2 in zip(model.parameters(), ref_model.parameters()):
         param_grad_sqnr = compute_error(param1.grad, param2.grad)
-        assert param_grad_sqnr.item() >= 25.0, (
-            f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}."
+        assert param_grad_sqnr.item() >= min_param_grad_sqnr, (
+            f"SQNR must be >= {min_param_grad_sqnr}, got {param_grad_sqnr.item()}."
        )
 
     dist.destroy_process_group()
diff --git a/test/prototype/moe_training/test_kernels.py b/test/prototype/moe_training/test_kernels.py
index b24b61be8c..ea4afa5c90 100644
--- a/test/prototype/moe_training/test_kernels.py
+++ b/test/prototype/moe_training/test_kernels.py
@@ -23,8 +23,8 @@
     triton_fp8_rowwise_3d_transpose_rhs,
 )
 from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
-    triton_fp8_col_major_jagged_colwise_scales,
-    triton_fp8_row_major_jagged_rowwise_scales,
+    triton_fp8_per_group_colwise_scales,
+    triton_fp8_per_group_rowwise_scales,
 )
 from torchao.prototype.moe_training.utils import (
     _is_column_major,
@@ -52,7 +52,7 @@ def test_row_major_with_jagged_rowwise_scales(round_scales_to_power_of_2: bool):
         target_dtype=torch.float8_e4m3fn,
         round_scales_to_power_of_2=round_scales_to_power_of_2,
     )
-    kernel_fp8_data, kernel_scales = triton_fp8_row_major_jagged_rowwise_scales(
+    kernel_fp8_data, kernel_scales = triton_fp8_per_group_rowwise_scales(
         x,
         colwise_offs,
         output_dtype=torch.float8_e4m3fn,
@@ -80,7 +80,7 @@ def test_column_major_with_jagged_colwise_scales(round_scales_to_power_of_2: boo
         target_dtype=torch.float8_e4m3fn,
         round_scales_to_power_of_2=round_scales_to_power_of_2,
     )
-    kernel_fp8_data, kernel_scales = triton_fp8_col_major_jagged_colwise_scales(
+    kernel_fp8_data, kernel_scales = triton_fp8_per_group_colwise_scales(
         x,
         rowwise_offs,
         output_dtype=torch.float8_e4m3fn,
diff --git a/test/prototype/moe_training/test_tp.py b/test/prototype/moe_training/test_tp.py
index 46ba544791..bf913a69b3 100644
--- a/test/prototype/moe_training/test_tp.py
+++ b/test/prototype/moe_training/test_tp.py
@@ -49,14 +49,14 @@
 
 # this test requires torchtitan
 try:
-    from torchtitan.experiments.llama4.infra.expert_parallel import (
+    from torchtitan.distributed.expert_parallel import (
         ExpertParallel,
         ExpertTensorParallel,
         NoParallel,
         TensorParallel,
+        set_token_group_alignment_size_m,
     )
-    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
-    from torchtitan.experiments.llama4.model.moe import MoE
+    from torchtitan.models.moe import MoE, MoEArgs
 except ImportError:
     pytest.skip(
         "torchtitan not installed, skipping MoE tests.", allow_module_level=True
@@ -74,21 +74,22 @@
 def test_moe_float8_training_tp(target_fqns: list[str]):
     assert torch.cuda.is_available()
 
+    # token group alignment size must be 16 for fp8
+    set_token_group_alignment_size_m(16)
+
     # setup distributed for tp
     mesh = setup_distributed()
 
     # define model args
-    model_args = TransformerModelArgs(
-        moe_enabled=True,
+    model_args = MoEArgs(
         num_experts=8,
-        dim=256,
-        vocab_size=1024,
     )
+    dim, hidden_dim = 5120, 4 * 5120
     init_std = 0.02
     device = torch.device("cuda")
 
     # reference bf16 MoE
-    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
+    ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda()
     torch.manual_seed(1)
     ref_model.init_weights(init_std, device)
 
@@ -141,7 +142,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
     )
 
     # inputs
-    batch, seq, dim = 8, 2048, 256
+    batch, seq = 8, 2048
     ref_x = torch.randn(
         batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
     )
@@ -153,7 +154,10 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
 
     # validate output
     out_sqnr = compute_error(out, ref_out)
-    assert out_sqnr.item() >= 30.0, f"SQNR must be >= 30.0, got {out_sqnr.item()}."
+    min_out_sqnr = 29.0
+    assert out_sqnr.item() >= min_out_sqnr, (
+        f"SQNR must be >= {min_out_sqnr}, got {out_sqnr.item()}."
+ ) # compute loss labels = torch.ones_like(ref_out) @@ -166,15 +170,17 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: # validate input gradient input_grad_sqnr = compute_error(x.grad, ref_x.grad) - assert input_grad_sqnr.item() >= 28.0, ( - f"SQNR must be >= 28.0, got {input_grad_sqnr.item()}." + min_input_grad_sqnr = 28.0 + assert input_grad_sqnr.item() >= min_input_grad_sqnr, ( + f"SQNR must be >= {min_input_grad_sqnr}, got {input_grad_sqnr.item()}." ) # validate param gradients + min_param_grad_sqnr = 23.0 for param1, param2 in zip(model.parameters(), ref_model.parameters()): param_grad_sqnr = compute_error(param1.grad, param2.grad) - assert param_grad_sqnr.item() >= 25.0, ( - f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}." + assert param_grad_sqnr.item() >= min_param_grad_sqnr, ( + f"SQNR must be >= {min_param_grad_sqnr}, got {param_grad_sqnr.item()}." ) dist.destroy_process_group() @@ -203,7 +209,7 @@ def apply_moe_ep_tp( moe_layer_plan = { # input / output sharding on the seqlen dim # all-gather for input, reduce-scatter for output - "moe": PrepareModuleInputOutput( + "": PrepareModuleInputOutput( input_layouts=(Shard(1),), desired_input_layouts=(Replicate(),), use_local_input=True, @@ -211,9 +217,9 @@ def apply_moe_ep_tp( desired_output_layouts=(Shard(1),), ), # replicate computation for the router - "moe.router.gate": NoParallel(), + "router.gate": NoParallel(), # input Replicate, output Partial - "moe.shared_expert": TensorParallel(), + "shared_expert": TensorParallel(), } parallelize_module( module=model, diff --git a/test/prototype/moe_training/test_training.py b/test/prototype/moe_training/test_training.py index d08f218842..98f9fb266a 100644 --- a/test/prototype/moe_training/test_training.py +++ b/test/prototype/moe_training/test_training.py @@ -22,11 +22,10 @@ # this test requires torchtitan try: - from torchtitan.experiments.llama4.infra.expert_parallel import ( + from torchtitan.distributed.expert_parallel import ( set_token_group_alignment_size_m, ) - from torchtitan.experiments.llama4.model.args import TransformerModelArgs - from torchtitan.experiments.llama4.model.moe import MoE + from torchtitan.models.moe import MoE, MoEArgs except ImportError: pytest.skip( "torchtitan not installed, skipping MoE tests.", allow_module_level=True @@ -47,16 +46,15 @@ def test_moe_float8_training(target_fqns: list[str], compile: bool): # has the contraction dim be divisible by 16. 16 byte alignment is required # for the slowest moving dim (stride 1), so 16 bytes / 1 byte per element in fp8 = 16 elements. 
set_token_group_alignment_size_m(16) - model_args = TransformerModelArgs( - moe_enabled=True, + model_args = MoEArgs( num_experts=8, - dim=256, ) init_std = 0.02 device = torch.device("cuda") # reference bf16 MoE - ref_model = MoE(model_args).to(torch.bfloat16).cuda() + dim, hidden_dim = 5120, 4 * 5120 + ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda() torch.manual_seed(42) ref_model.init_weights(init_std, device) @@ -75,7 +73,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: return False # quantize test model - config = MoETrainingConfig(scaling_type=MoEScalingType.FP8_ROWWISE) + config = MoETrainingConfig() quantize_(model, config=config, filter_fn=moe_module_filter_fn) # validate that only the experts were converted @@ -83,14 +81,13 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: model, target_fqns=target_fqns, ) - if compile: # TODO: compile with fullgraph=True when torchtitan llama4 moe supports it model = torch.compile(model, fullgraph=False) ref_model = torch.compile(ref_model, fullgraph=False) # inputs - batch, seq, dim = 8, 2048, 256 + batch, seq = 8, 2048 ref_x = torch.randn( batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device ) @@ -124,7 +121,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: ) # validate param gradients - min_param_grad_sqnr = 25.0 + min_param_grad_sqnr = 23.0 for param1, param2 in zip(model.parameters(), ref_model.parameters()): param_grad_sqnr = compute_error(param1.grad, param2.grad) assert param_grad_sqnr.item() >= min_param_grad_sqnr, ( @@ -145,18 +142,15 @@ def test_moe_mxfp8_training(target_fqns: list[str]): # Token groups must be divisible by 32 for mxfp8 set_token_group_alignment_size_m(block_size) - model_args = TransformerModelArgs( - moe_enabled=True, + model_args = MoEArgs( num_experts=8, - dim=256, - multiple_of=block_size, - ffn_dim_multiplier=1.0, ) init_std = 0.02 device = torch.device("cuda") # reference bf16 MoE - ref_model = MoE(model_args).to(torch.bfloat16).cuda() + dim, hidden_dim = 256, 4 * 256 + ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda() torch.manual_seed(42) ref_model.init_weights(init_std, device) @@ -185,7 +179,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: ) # inputs - batch, seq, dim = 8, 2048, 256 + batch, seq = 8, 2048 ref_x = torch.randn( batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device ) diff --git a/torchao/prototype/moe_training/kernels/__init__.py b/torchao/prototype/moe_training/kernels/__init__.py index 8fb16579e5..0b88cc08a2 100644 --- a/torchao/prototype/moe_training/kernels/__init__.py +++ b/torchao/prototype/moe_training/kernels/__init__.py @@ -2,8 +2,8 @@ triton_fp8_rowwise_3d_transpose_rhs as triton_fp8_rowwise_3d_transpose_rhs, ) from torchao.prototype.moe_training.kernels.jagged_float8_scales import ( - triton_fp8_col_major_jagged_colwise_scales as triton_fp8_col_major_jagged_colwise_scales, + triton_fp8_per_group_colwise_scales as triton_fp8_per_group_colwise_scales, ) from torchao.prototype.moe_training.kernels.jagged_float8_scales import ( - triton_fp8_row_major_jagged_rowwise_scales as triton_fp8_row_major_jagged_rowwise_scales, + triton_fp8_per_group_rowwise_scales as triton_fp8_per_group_rowwise_scales, ) diff --git a/torchao/prototype/moe_training/kernels/float8_rowwise.py b/torchao/prototype/moe_training/kernels/float8_rowwise.py index 9d7a7768d4..3449b89336 100644 --- a/torchao/prototype/moe_training/kernels/float8_rowwise.py +++ 
b/torchao/prototype/moe_training/kernels/float8_rowwise.py @@ -29,7 +29,7 @@ block_sizes_n = [32, 128, 512] # large dim (output_features) block_sizes_k = [32, 128, 512] # small dim (input_features) num_warps = [8] -num_stages = [2, 3] +num_stages = [2, 4] kernel_configs_2D = [ triton.Config( {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k}, @@ -42,10 +42,8 @@ for stages in num_stages ] -from torch.library import triton_op, wrap_triton - -@triton_op("torchao::triton_fp8_rowwise_transpose_rhs", mutates_args={}) +@torch.library.custom_op("torchao::triton_fp8_rowwise_transpose_rhs", mutates_args={}) def triton_fp8_rowwise_3d_transpose_rhs( hp_tensor: torch.Tensor, # (E, K, N) output_dtype: torch.dtype = torch.float8_e4m3fn, @@ -80,7 +78,7 @@ def triton_fp8_rowwise_3d_transpose_rhs( ) # compute scales - wrap_triton(_triton_fp8_rowwise_3d_transpose_scales_rhs_kernel)[grid]( + _triton_fp8_rowwise_3d_transpose_scales_rhs_kernel[grid]( hp_tensor, hp_tensor.stride(0), hp_tensor.stride(1), @@ -100,7 +98,7 @@ def triton_fp8_rowwise_3d_transpose_rhs( ) # perform casting - wrap_triton(_triton_fp8_rowwise_3d_transpose_cast_rhs_kernel)[grid]( + _triton_fp8_rowwise_3d_transpose_cast_rhs_kernel[grid]( hp_tensor, hp_tensor.stride(0), hp_tensor.stride(1), @@ -124,6 +122,22 @@ def triton_fp8_rowwise_3d_transpose_rhs( return output_buffer, scales_buffer +@triton_fp8_rowwise_3d_transpose_rhs.register_fake +def _fake_triton_fp8_rowwise_3d_transpose_rhs( + hp_tensor: torch.Tensor, # (E, K, N) + output_dtype: torch.dtype = torch.float8_e4m3fn, + round_scales_to_power_of_2: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert hp_tensor.ndim == 3, "input tensor must be 3D" + e, k, n = hp_tensor.shape + output_buffer = torch.empty( + (e, n, k), dtype=output_dtype, device=hp_tensor.device + ).as_strided((e, n, k), (n * k, 1, n)) + + scales_buffer = torch.empty((e, k), dtype=torch.float32, device=hp_tensor.device) + return output_buffer, scales_buffer + + @triton.autotune(configs=kernel_configs_2D, key=["num_elements"]) @triton.jit def _triton_fp8_rowwise_3d_transpose_scales_rhs_kernel( diff --git a/torchao/prototype/moe_training/kernels/jagged_float8_scales.py b/torchao/prototype/moe_training/kernels/jagged_float8_scales.py index ff0b11acba..16f4bf87f4 100644 --- a/torchao/prototype/moe_training/kernels/jagged_float8_scales.py +++ b/torchao/prototype/moe_training/kernels/jagged_float8_scales.py @@ -32,9 +32,9 @@ } block_sizes = [1, 16, 32, 64] -block_sizes_iter = [32, 64, 128, 256] -num_warps = [1, 4] -num_stages = [2, 3] +block_sizes_iter = [64, 128, 256] +num_warps = [4] +num_stages = [3] kernel_configs_2D = [ triton.Config( {"BLOCK_SIZE": block_size, "BLOCK_SIZE_ITER": block_size_iter}, @@ -47,11 +47,11 @@ for stages in num_stages ] -from torch.library import triton_op, wrap_triton - -@triton_op("torchao::triton_fp8_row_major_jagged_rowwise_scales", mutates_args={}) -def triton_fp8_row_major_jagged_rowwise_scales( +@torch.library.custom_op( + "torchao::triton_fp8_per_group_rowwise_scales", mutates_args={} +) +def triton_fp8_per_group_rowwise_scales( hp_tensor: torch.Tensor, offsets: torch.Tensor, output_dtype: torch.dtype = torch.float8_e4m3fn, @@ -95,7 +95,7 @@ def triton_fp8_row_major_jagged_rowwise_scales( triton.cdiv(m, meta["BLOCK_SIZE"]), offsets.numel(), ) - wrap_triton(_triton_fp8_row_major_jagged_rowwise_scales)[grid]( + _triton_fp8_per_group_rowwise_scales_kernel[grid]( hp_tensor, offsets, output_buffer, @@ -117,6 +117,24 @@ def triton_fp8_row_major_jagged_rowwise_scales( return 
output_buffer, scales_buffer +@triton_fp8_per_group_rowwise_scales.register_fake +def _fake_triton_fp8_per_group_rowwise_scales_kernel( + hp_tensor: torch.Tensor, + offsets: torch.Tensor, + output_dtype: torch.dtype = torch.float8_e4m3fn, + round_scales_to_power_of_2: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert hp_tensor.ndim == 2, "input tensor must be 2D" + m, k = hp_tensor.shape + n_groups = offsets.numel() + output = torch.empty_like(hp_tensor, dtype=output_dtype).as_strided( + (m, k), # shape + (k, 1), # stride + ) + scales = torch.empty((m * n_groups), dtype=torch.float32, device=hp_tensor.device) + return output, scales + + # This kernel is used on grad_output.t() which has shape (K, M), # before the calculation `grad_B = grad_output_t @ input`. # However, in this code, we use the conventional dim names (M, K) @@ -125,7 +143,7 @@ def triton_fp8_row_major_jagged_rowwise_scales( # to recompile on `token` dim (K, in this case) changes. @triton.autotune(configs=kernel_configs_2D, key=["M"]) @triton.jit -def _triton_fp8_row_major_jagged_rowwise_scales( +def _triton_fp8_per_group_rowwise_scales_kernel( input_ptr, offsets_ptr, out_ptr, @@ -215,8 +233,10 @@ def _triton_fp8_row_major_jagged_rowwise_scales( tl.store(out_ptr + out_offs, fp8_data, mask=block_mask) -@triton_op("torchao::triton_fp8_col_major_jagged_colwise_scales", mutates_args={}) -def triton_fp8_col_major_jagged_colwise_scales( +@torch.library.custom_op( + "torchao::triton_fp8_per_group_colwise_scales", mutates_args={} +) +def triton_fp8_per_group_colwise_scales( hp_tensor: torch.Tensor, offsets: torch.Tensor, output_dtype: torch.dtype = torch.float8_e4m3fn, @@ -263,7 +283,7 @@ def triton_fp8_col_major_jagged_colwise_scales( triton.cdiv(n, meta["BLOCK_SIZE"]), offsets.numel(), ) - wrap_triton(_triton_fp8_col_major_jagged_colwise_scales)[grid]( + _triton_fp8_per_group_colwise_scales_kernel[grid]( hp_tensor, offsets, output_buffer, @@ -285,13 +305,33 @@ def triton_fp8_col_major_jagged_colwise_scales( return output_buffer, scales_buffer +@triton_fp8_per_group_colwise_scales.register_fake +def _fake_triton_fp8_per_group_colwise_scales( + hp_tensor: torch.Tensor, + offsets: torch.Tensor, + output_dtype: torch.dtype = torch.float8_e4m3fn, + round_scales_to_power_of_2: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert hp_tensor.ndim == 2, "input tensor must be 2D" + k, n = hp_tensor.shape + n_groups = offsets.numel() + output_buffer = torch.empty_like( + hp_tensor, dtype=output_dtype, device=hp_tensor.device + ).as_strided(hp_tensor.size(), (1, k)) + + scales_buffer = torch.empty( + (n * n_groups), dtype=torch.float32, device=hp_tensor.device + ) + return output_buffer, scales_buffer + + # This kernel is used on `input` which has shape (M, K), # before the calculation `grad_B = grad_output_t @ input`. # The tokens per expert will vary per iteration, so don't want # to recompile on `token` dim (M) changes. 
@triton.autotune(configs=kernel_configs_2D, key=["K"]) @triton.jit -def _triton_fp8_col_major_jagged_colwise_scales( +def _triton_fp8_per_group_colwise_scales_kernel( input_ptr, offsets_ptr, out_ptr, diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py index 7dc246e251..0ee72ea35b 100644 --- a/torchao/prototype/moe_training/scaled_grouped_mm.py +++ b/torchao/prototype/moe_training/scaled_grouped_mm.py @@ -13,8 +13,8 @@ from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated from torchao.prototype.moe_training.conversion_utils import MoEScalingType from torchao.prototype.moe_training.kernels import ( - triton_fp8_col_major_jagged_colwise_scales, - triton_fp8_row_major_jagged_rowwise_scales, + triton_fp8_per_group_colwise_scales, + triton_fp8_per_group_rowwise_scales, triton_fp8_rowwise_3d_transpose_rhs, ) from torchao.prototype.moe_training.utils import ( @@ -48,7 +48,7 @@ def _scaled_grouped_mm( """ # TODO: Remove logging once prototype is more mature. This is currently very useful for development and debugging. if scaling_type == MoEScalingType.FP8_ROWWISE: - logger.info("Using fp8 rowwise scaled_grouped_mm") + # print("Using fp8 rowwise scaled_grouped_mm") return _Float8GroupedMM.apply( A, B_t, @@ -56,7 +56,7 @@ def _scaled_grouped_mm( out_dtype, ) elif scaling_type == MoEScalingType.MXFP8: - logger.info("Using mxfp8 scaled_grouped_mm") + print("Using mxfp8 scaled_grouped_mm") block_size = 32 # TODO: should we make this configurable? plumb it through in a config somehow? return _MXFP8GroupedMM.apply( A, @@ -140,17 +140,8 @@ def forward( B_t_scaled = B_t.to(torch.float32) * B_t_scales B_t_fp8_col_major = to_fp8_saturated(B_t_scaled, torch.float8_e4m3fn) - # Precompute non-transposed B column-major for backward, to save memory by storing the - # low precision B tensor instead of the high precision B tensor. - # In the backward this is needed for grad_A: grad_output @ B. - B_fp8_col_major, B_scales = triton_fp8_rowwise_3d_transpose_rhs( - B_t, - output_dtype=torch.float8_e4m3fn, - round_scales_to_power_of_2=True, - ) - # Store what we need for backward. - ctx.save_for_backward(A, B_fp8_col_major, B_scales, offs) + ctx.save_for_backward(A, B_t, offs) ctx.out_dtype = out_dtype # Perform scaled grouped GEMM and return result. @@ -179,7 +170,7 @@ def forward( @staticmethod def backward(ctx, grad_output: torch.Tensor): - A, B_fp8_col_major, B_scales, offs = ctx.saved_tensors + A, B_t, offs = ctx.saved_tensors out_dtype = ctx.out_dtype # Convert grad_output to float8, row-major for left operand of grouped GEMM @@ -199,6 +190,14 @@ def backward(ctx, grad_output: torch.Tensor): grad_output_scaled, torch.float8_e4m3fn ) + # Compute B fp8 column-major for right operand of grouped GEMM: + # grad_A = grad_output @ B. + B_fp8_col_major, B_scales = triton_fp8_rowwise_3d_transpose_rhs( + B_t._data if hasattr(B_t, "_data") else B_t, + output_dtype=torch.float8_e4m3fn, + round_scales_to_power_of_2=True, + ) + # Compute grad_A. 
# grad_A = grad_output @ B # grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K) @@ -217,8 +216,8 @@ def backward(ctx, grad_output: torch.Tensor): grad_A = torch._scaled_grouped_mm( grad_output_fp8_row_major, B_fp8_col_major, - grad_output_scales.squeeze().reciprocal(), - B_scales.squeeze().reciprocal(), + grad_output_scales.reciprocal(), + B_scales.reciprocal(), offs, out_dtype=out_dtype, use_fast_accum=True, @@ -230,7 +229,7 @@ def backward(ctx, grad_output: torch.Tensor): # Convert transpose of grad_output to float8, row-major for left operand of grouped GEMM # needed for grad_B: grad_output_t @ A grad_output_t_fp8_row_major, grad_output_t_scales = ( - triton_fp8_row_major_jagged_rowwise_scales( + triton_fp8_per_group_rowwise_scales( grad_output.transpose(-2, -1), offs, torch.float8_e4m3fn, @@ -238,7 +237,7 @@ def backward(ctx, grad_output: torch.Tensor): ) ) - A_fp8_col_major, A_scales = triton_fp8_col_major_jagged_colwise_scales( + A_fp8_col_major, A_scales = triton_fp8_per_group_colwise_scales( A, offs, torch.float8_e4m3fn, diff --git a/torchao/prototype/moe_training/tensor.py b/torchao/prototype/moe_training/tensor.py index 1ddd098675..a861aa6533 100644 --- a/torchao/prototype/moe_training/tensor.py +++ b/torchao/prototype/moe_training/tensor.py @@ -97,9 +97,12 @@ def __torch_function__(cls, func, types, args, kwargs={}): A_is_2d = A.dim() == 2 B_is_3d = B.dim() == 3 has_offs = kwargs.get(cls.offs_arg_name) is not None + other_args = args[2:] if A_is_2d and B_is_3d and has_offs: return _scaled_grouped_mm( - *args, + A, + B, + *other_args, scaling_type=scaling_type, **kwargs, ) @@ -111,16 +114,25 @@ def __torch_function__(cls, func, types, args, kwargs={}): @classmethod def __torch_dispatch__(cls, func, types, args, kwargs={}): - # detach is special case - scaling_type = args[0].scaling_type - if func == torch.ops.aten.detach.default: - return ScaledGroupedMMTensor(args[0]._data, scaling_type) + # unwrap args/kwargs and extract scaling_type + scaling_type = None + + def unwrap(t): + nonlocal scaling_type + if scaling_type is None: + scaling_type = t.scaling_type + else: + assert t.scaling_type == scaling_type + return t._data - # unwrap args/kwargs - unwrap = lambda x: x._data if isinstance(x, ScaledGroupedMMTensor) else x args, kwargs = pytree.tree_map_only( ScaledGroupedMMTensor, unwrap, (args, kwargs or {}) ) + assert scaling_type is not None + + # detach is special case + if func == torch.ops.aten.detach.default: + return ScaledGroupedMMTensor(args[0], scaling_type) # perform op out = func(*args, **kwargs)