diff --git a/benchmarks/float8/bench_grouped_mm.py b/benchmarks/float8/bench_grouped_mm.py
index b43a9f0574..5b0bea1822 100644
--- a/benchmarks/float8/bench_grouped_mm.py
+++ b/benchmarks/float8/bench_grouped_mm.py
@@ -3,7 +3,6 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-import random
 from typing import Optional
 
 import fire
@@ -11,6 +10,7 @@ import torch
 
 from utils import do_benchmarks, get_name_to_moe_shapes_iter
 
+from torchao.prototype.moe_training.utils import generate_jagged_offs
 from torchao.testing.training.roofline_utils import get_specs
 
 
@@ -146,39 +146,6 @@ def do_scaled_grouped_mm(A, B):
     data_df.to_csv(out_filename)
 
 
-def generate_jagged_offs(E, M, dtype=torch.int32, device="cuda"):
-    """
-    Generates a tensor of length E, containing random values divisible by 16,
-    from 0 to M, in sorted order, and where the final value in the tensor is always M.
-    Args:
-        E (int): The length of the tensor.
-        M (int): The maximum value in the tensor.
-    Returns:
-        torch.Tensor: A tensor of length E with the specified properties.
-    """
-    # Ensure M is divisible by 16
-    if M % 16 != 0:
-        raise ValueError("M must be divisible by 16")
-
-    # Generate a list of possible values
-    possible_values = [i for i in range(0, M + 1, 16)]
-
-    # If E is larger than the number of possible values, raise an error
-    if E > len(possible_values):
-        raise ValueError("E cannot be larger than the number of possible values")
-
-    # Randomly select E - 1 values from the possible values (excluding M)
-    selected_values = torch.tensor(random.sample(possible_values[:-1], E - 1))
-
-    # Append M to the selected values
-    selected_values = torch.cat((selected_values, torch.tensor([M])))
-
-    # Sort the selected values
-    selected_values, _ = torch.sort(selected_values)
-
-    return selected_values.to(dtype).to(device)
-
-
 def main() -> None:
     fire.Fire(run)
 
diff --git a/test/prototype/moe_training/test_scaled_grouped_mm.py b/test/prototype/moe_training/test_scaled_grouped_mm.py
index 3b4d23965b..0049007d27 100644
--- a/test/prototype/moe_training/test_scaled_grouped_mm.py
+++ b/test/prototype/moe_training/test_scaled_grouped_mm.py
@@ -7,6 +7,9 @@
 import pytest
 import torch
 
+pytest.importorskip("triton", reason="Triton required to run this test")
+
+from torchao.prototype.moe_training.utils import generate_jagged_offs
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
 
 # We need to skip before doing any imports which would use triton, since
@@ -25,10 +28,12 @@
 )
 from torchao.float8.float8_linear import matmul_with_hp_or_float8_args
 from torchao.float8.float8_training_tensor import LinearMMConfig
-from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
+from torchao.float8.float8_utils import compute_error, tensor_to_scale, to_fp8_saturated
 from torchao.prototype.moe_training.scaled_grouped_mm import (
     _scaled_grouped_mm,
+    emulated_mxfp8_scaled_grouped_mm,
 )
+from torchao.prototype.mx_formats.mx_tensor import to_mx
 from torchao.testing.utils import skip_if_rocm
 
 
@@ -212,3 +217,35 @@ def compute_reference_forward(
     # Concatenate the outputs and verify the full result is correct.
     output_ref = torch.cat(outputs, dim=0)
     return output_ref
+
+
+@pytest.mark.parametrize("M", (1024, 4096))
+@pytest.mark.parametrize("K", (1024, 4096))
+@pytest.mark.parametrize("N", (1024, 4096))
+@pytest.mark.parametrize("num_experts", (1, 8, 16))
+def test_emulate_mxfp8_grouped_gemm(M, K, N, num_experts):
+    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+    w_t = torch.randn(num_experts, K, N, dtype=torch.bfloat16, device="cuda")
+    offs = generate_jagged_offs(num_experts, M)
+    x_ref, w_t_ref, offs_ref = x.clone(), w_t.clone(), offs.clone()
+
+    # Quantize inputs to mxfp8 for emulated mxfp8 scaled grouped mm
+    block_size = 32
+    x_scale, x_mx = to_mx(x, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
+
+    # To cast w_t per-expert to mxfp8 along dim1, we transpose the experts, cast along dim -1, then untranspose.
+    w_scale, w_mx = to_mx(
+        w_t.transpose(-2, -1).contiguous(),
+        elem_dtype=torch.float8_e4m3fn,
+        block_size=block_size,
+    )
+    w_t_scale, w_t_mx = w_scale.transpose(-2, -1), w_mx.transpose(-2, -1)
+
+    ref_out = torch._grouped_mm(x_ref, w_t_ref, offs=offs_ref, out_dtype=torch.bfloat16)
+    out = emulated_mxfp8_scaled_grouped_mm(
+        x_mx, x_scale, w_t_mx, w_t_scale, offs=offs, out_dtype=torch.bfloat16
+    )
+
+    sqnr = compute_error(ref_out, out)
+    min_sqnr = 27.0
+    assert sqnr >= min_sqnr, f"sqnr {sqnr} is too low, must be >= {min_sqnr}"
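For context on the new test's assertion: `compute_error` from `torchao.float8.float8_utils` returns a signal-to-quantization-noise ratio (SQNR) in decibels, treating the bf16 grouped GEMM output as the reference signal and the emulated mxfp8 output as the noisy approximation. A minimal sketch of that metric (illustrative only, not the torchao implementation):

```python
import torch

def sqnr_db(ref: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
    # SQNR in dB: log-ratio of signal power to error power.
    # Higher is better; the test above requires >= 27 dB.
    signal = torch.linalg.norm(ref.float())
    noise = torch.linalg.norm((ref - actual).float())
    return 20 * torch.log10(signal / noise)
```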
diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py
index d9ccdcba03..b12a3d954f 100644
--- a/torchao/prototype/moe_training/scaled_grouped_mm.py
+++ b/torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -217,7 +217,7 @@ def backward(ctx, grad_output: torch.Tensor):
             use_fast_accum=True,
         )
 
-        # Convert tranpose of grad_output to float8, row-major for left operand of grouped GEMM
+        # Convert transpose of grad_output to float8, row-major for left operand of grouped GEMM
         # needed for grad_B: grad_output_t @ A
         grad_output_t_row_major = grad_output.transpose(-2, -1).contiguous()
 
@@ -266,3 +266,58 @@ def backward(ctx, grad_output: torch.Tensor):
             use_fast_accum=True,
         )
         return grad_A, grad_B.transpose(-2, -1), None, None, None, None
+
+
+def emulated_mxfp8_scaled_grouped_mm(
+    A_mx: torch.Tensor,
+    A_scale: torch.Tensor,
+    B_t_mx: torch.Tensor,
+    B_t_scale: torch.Tensor,
+    offs: Optional[torch.Tensor] = None,
+    out_dtype: Optional[torch.dtype] = torch.bfloat16,
+    block_size: int = 32,
+) -> torch.Tensor:
+    # Dequantize input
+    # A_mx shape: (M, K)
+    # A_scale shape: (M, K//block_size)
+    A_orig_shape = A_mx.shape
+
+    # Reshape to be able to do per-scaling-group multiplication
+    # A_mx shape: (M, K//block_size, block_size)
+    # A_scale shape: (M, K//block_size, 1)
+    A_mx = A_mx.reshape(*A_mx.shape[:-1], A_mx.shape[-1] // block_size, block_size)
+    A_scale = A_scale.unsqueeze(-1)
+
+    # Rescale and cast to bfloat16
+    A = A_mx.to(torch.bfloat16) * A_scale.to(torch.bfloat16)
+
+    # Reshape back to original shape
+    # A shape: (M, K)
+    A = A.reshape(A_orig_shape)
+
+    # Dequantize weights
+    # B_t_mx shape: (E, K, N)
+    # B_t_scale shape: (E, K//block_size, N)
+    E, K, N = B_t_mx.shape
+
+    # Transpose to get block_size on the rightmost dim
+    # B_mx shape: (E, N, K)
+    # B_scale shape: (E, N, K//block_size)
+    B_mx, B_scale = B_t_mx.transpose(-2, -1), B_t_scale.transpose(-2, -1)
+
+    # Reshape to be able to do per-scaling-group multiplication
+    # B_mx shape: (E, N, K//block_size, block_size)
+    # B_scale shape: (E, N, K//block_size, 1)
+    B_mx = B_mx.reshape(*B_mx.shape[:-1], B_mx.shape[-1] // block_size, block_size)
+    B_scale = B_scale.unsqueeze(-1)
+
+    # Rescale and cast to bfloat16
+    B = B_mx.to(torch.bfloat16) * B_scale.to(torch.bfloat16)
+
+    # Reshape and transpose back to original layout
+    # B_t shape: (E, K, N)
+    B_t = B.reshape(E, N, K).transpose(-2, -1)
+
+    # Perform bf16 grouped GEMM.
+    out = torch._grouped_mm(A, B_t, offs=offs, out_dtype=out_dtype)
+    return out
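The reshape-and-unsqueeze pattern in `emulated_mxfp8_scaled_grouped_mm` is plain broadcasting: each scale applies to one contiguous `block_size`-element group along the reduction dim. A self-contained sketch of the same dequantization step on a toy 2D tensor (hypothetical values; bf16 stands in for the fp8 payload and a constant stands in for the per-block scales):

```python
import torch

M, K, block_size = 4, 64, 32
elems = torch.randn(M, K, dtype=torch.bfloat16)  # stand-in for the fp8 element data
scales = torch.full((M, K // block_size), 2.0)   # one scale per 32-element block

# (M, K) -> (M, K//block_size, block_size): group elements under their scale
grouped = elems.reshape(M, K // block_size, block_size)
# (M, K//block_size) -> (M, K//block_size, 1): broadcast each scale over its block
dequant = grouped * scales.unsqueeze(-1).to(torch.bfloat16)
# back to (M, K)
dequant = dequant.reshape(M, K)
assert torch.equal(dequant, elems * 2.0)  # scaling by 2.0 is exact in bf16
```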
diff --git a/torchao/prototype/moe_training/utils.py b/torchao/prototype/moe_training/utils.py
index 038c379d62..225bb1b3f8 100644
--- a/torchao/prototype/moe_training/utils.py
+++ b/torchao/prototype/moe_training/utils.py
@@ -1,3 +1,4 @@
+import random
 from typing import Tuple
 
 import torch
@@ -154,3 +155,38 @@ def _is_column_major(x: torch.Tensor) -> bool:
     """
     assert x.ndim == 2 or x.ndim == 3, "input tensor must be 2D or 3D"
     return x.stride(-2) == 1 and x.stride(-1) > 1
+
+
+def generate_jagged_offs(E, M, dtype=torch.int32, device="cuda"):
+    """
+    Utility function for tests and benchmarks.
+
+    Generates a tensor of length E, containing random values divisible by 16,
+    from 0 to M, in sorted order, and where the final value in the tensor is always M.
+    Args:
+        E (int): The length of the tensor.
+        M (int): The maximum value in the tensor.
+    Returns:
+        torch.Tensor: A tensor of length E with the specified properties.
+    """
+    # Ensure M is divisible by 16
+    if M % 16 != 0:
+        raise ValueError("M must be divisible by 16")
+
+    # Generate a list of possible values
+    possible_values = [i for i in range(0, M + 1, 16)]
+
+    # If E is larger than the number of possible values, raise an error
+    if E > len(possible_values):
+        raise ValueError("E cannot be larger than the number of possible values")
+
+    # Randomly select E - 1 values from the possible values (excluding M)
+    selected_values = torch.tensor(random.sample(possible_values[:-1], E - 1))
+
+    # Append M to the selected values
+    selected_values = torch.cat((selected_values, torch.tensor([M])))
+
+    # Sort the selected values
+    selected_values, _ = torch.sort(selected_values)
+
+    return selected_values.to(dtype).to(device)
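As a usage illustration, the offsets produced here are the end-of-group boundaries that `torch._grouped_mm` uses to split the token dimension across experts. One possible random draw:

```python
>>> generate_jagged_offs(4, 128)
tensor([ 16,  48,  96, 128], device='cuda:0', dtype=torch.int32)
# expert 0 owns rows [0, 16), expert 1 rows [16, 48),
# expert 2 rows [48, 96), expert 3 rows [96, 128)
```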