diff --git a/test/prototype/moe_training/test_scaled_grouped_mm.py b/test/prototype/moe_training/test_scaled_grouped_mm.py
index 0049007d27..43cf5ecb0a 100644
--- a/test/prototype/moe_training/test_scaled_grouped_mm.py
+++ b/test/prototype/moe_training/test_scaled_grouped_mm.py
@@ -219,9 +219,7 @@ def compute_reference_forward(
     return output_ref
 
 
-@pytest.mark.parametrize("M", (1024, 4096))
-@pytest.mark.parametrize("K", (1024, 4096))
-@pytest.mark.parametrize("N", (1024, 4096))
+@pytest.mark.parametrize("M,K,N", [(1024, 1024, 1024), (1024, 2048, 4096)])
 @pytest.mark.parametrize("num_experts", (1, 8, 16))
 def test_emulate_mxfp8_grouped_gemm(M, K, N, num_experts):
     x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
@@ -249,3 +247,23 @@ def test_emulate_mxfp8_grouped_gemm(M, K, N, num_experts):
     sqnr = compute_error(ref_out, out)
     min_sqnr = 27.0
     assert sqnr >= min_sqnr, f"sqnr {sqnr} is too low, must be >= {min_sqnr}"
+
+
+@pytest.mark.parametrize("M,K,N", [(1024, 1024, 1024), (1024, 2048, 4096)])
+@pytest.mark.parametrize("num_experts", (1, 8, 16))
+def test_mxfp8_grouped_gemm_with_dq_fwd(M, K, N, num_experts):
+    from torchao.prototype.moe_training.scaled_grouped_mm import (
+        _MXFP8GroupedMM,
+    )
+
+    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+    w_t = torch.randn(num_experts, K, N, dtype=torch.bfloat16, device="cuda")
+    offs = generate_jagged_offs(num_experts, M)
+    x_ref, w_t_ref, offs_ref = x.clone(), w_t.clone(), offs.clone()
+    block_size = 32
+
+    out = _MXFP8GroupedMM.apply(x, w_t, offs, block_size, torch.bfloat16)
+    ref_out = torch._grouped_mm(x_ref, w_t_ref, offs=offs_ref, out_dtype=torch.bfloat16)
+    sqnr = compute_error(ref_out, out)
+    min_sqnr = 27.0
+    assert sqnr >= min_sqnr, f"sqnr {sqnr} is too low, must be >= {min_sqnr}"
diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py
index b12a3d954f..66afecc9cb 100644
--- a/torchao/prototype/moe_training/scaled_grouped_mm.py
+++ b/torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 import logging
-from typing import Optional
+from typing import Optional, Tuple
 
 import torch
 
@@ -18,6 +18,7 @@ from torchao.prototype.moe_training.utils import (
     _is_column_major,
 )
+from torchao.prototype.mx_formats.mx_tensor import to_mx
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -268,6 +269,81 @@ def backward(ctx, grad_output: torch.Tensor):
         return grad_A, grad_B.transpose(-2, -1), None, None, None, None
 
 
+class _MXFP8GroupedMM(torch.autograd.Function):
+    """Differentiable implementation of grouped GEMM with dynamic mxfp8 quantization."""
+
+    @staticmethod
+    def forward(
+        ctx,
+        A: torch.Tensor,
+        B_t: torch.Tensor,
+        offs: Optional[torch.Tensor] = None,
+        block_size: int = 32,
+        out_dtype: Optional[torch.dtype] = torch.bfloat16,
+        emulated: bool = True,
+    ) -> torch.Tensor:
+        # torchao _scaled_grouped_mm only supports A=2D and B=3D.
+        assert A.ndim == 2, "A must be 2D"
+        assert B_t.ndim == 3, "B must be 3D"
+        assert block_size == 32, "Only block_size=32 is supported"
+        assert emulated, "Only emulated mxfp8 grouped gemm is supported"
+
+        # Cast A to mxfp8 across dim -1.
+        # A_mx shape: (M, K)
+        # A_scale shape: (M, K//block_size)
+        A_scale, A_mx = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
+
+        # Cast B_t per-expert to mxfp8 across dim1.
+        # B_t_mx shape: (E, K, N)
+        # B_t_scale shape: (E, K//block_size, N)
+        B_t_scale, B_t_mx = _to_mxfp8_3d_expert_weights_dim1(B_t, block_size=block_size)
+
+        # Store what we need for backward.
+        ctx.save_for_backward(A, B_t, offs)
+        ctx.out_dtype = out_dtype
+
+        # Perform scaled grouped GEMM and return result.
+        # output = input @ weight.T
+        # output shape: (M, N)
+        out = emulated_mxfp8_scaled_grouped_mm(
+            A_mx,
+            A_scale,
+            B_t_mx,
+            B_t_scale,
+            offs=offs,
+            block_size=block_size,
+            out_dtype=out_dtype,
+        )
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor):
+        raise NotImplementedError
+
+
+def _to_mxfp8_3d_expert_weights_dim1(
+    w_t: torch.Tensor,  # (num_experts, K, N)
+    block_size: int = 32,
+    elem_dtype: torch.dtype = torch.float8_e4m3fn,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Convert a 3D tensor of shape (experts, K, N) to MXFP8 format along dim1.
+    Args:
+        w_t (torch.Tensor): Input tensor to be converted.
+        block_size (int): Block size for MXFP8 quantization.
+        elem_dtype (torch.dtype): Element dtype for MXFP8 quantization.
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: Scale tensor and converted tensor.
+        - scale shape: (experts, K // block_size, N)
+        - output shape: (experts, K, N)
+    """
+    # To cast w_t per-expert to mxfp8 across dim1, we transpose the experts, cast along dim -1, then untranspose.
+    w_scale, w_mx = to_mx(
+        w_t.transpose(-2, -1).contiguous(), elem_dtype=elem_dtype, block_size=block_size
+    )
+    w_t_scale, w_t_mx = w_scale.transpose(-2, -1), w_mx.transpose(-2, -1)
+    return w_t_scale, w_t_mx
+
+
 def emulated_mxfp8_scaled_grouped_mm(
     A_mx: torch.Tensor,
     A_scale: torch.Tensor,
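
For reviewers, a small standalone sketch (not part of the patch) of the shape bookkeeping the new forward relies on: to_mx quantizes along the last dim and returns (scale, data), so the per-expert dim1 cast is done by transposing, casting, and transposing back. The toy sizes and the CUDA device are arbitrary choices mirroring the tests, and the assertions simply restate the shape comments in forward():

    import torch
    from torchao.prototype.mx_formats.mx_tensor import to_mx

    M, E, K, N, block_size = 16, 2, 64, 128, 32
    A = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
    B_t = torch.randn(E, K, N, dtype=torch.bfloat16, device="cuda")

    # Cast A along dim -1: one e8m0 scale per block of 32 elements in each row.
    A_scale, A_mx = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
    assert A_mx.shape == (M, K)
    assert A_scale.shape == (M, K // block_size)

    # Cast B_t along dim1: transpose so dim1 is last, cast, then untranspose,
    # the same trick used by _to_mxfp8_3d_expert_weights_dim1 above.
    B_scale, B_mx = to_mx(
        B_t.transpose(-2, -1).contiguous(),
        elem_dtype=torch.float8_e4m3fn,
        block_size=block_size,
    )
    B_t_scale, B_t_mx = B_scale.transpose(-2, -1), B_mx.transpose(-2, -1)
    assert B_t_mx.shape == (E, K, N)
    assert B_t_scale.shape == (E, K // block_size, N)

The transpose-then-cast approach keeps a single quantization primitive (to_mx, last-dim only) while still producing the (E, K//block_size, N) scale layout the grouped GEMM expects for the expert weights.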
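
The new test follows the existing sqnr-vs-torch._grouped_mm pattern and reuses generate_jagged_offs, so it requires a CUDA device like the rest of the file. Assuming the repo's standard pytest setup, it can be selected on its own with pytest's -k filter, e.g. pytest test/prototype/moe_training/test_scaled_grouped_mm.py -k test_mxfp8_grouped_gemm_with_dq_fwd.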