35 changes: 1 addition & 34 deletions benchmarks/float8/bench_grouped_mm.py
@@ -3,14 +3,14 @@
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import random
from typing import Optional

import fire
import pandas as pd
import torch
from utils import do_benchmarks, get_name_to_moe_shapes_iter

from torchao.prototype.moe_training.utils import generate_jagged_offs
from torchao.testing.training.roofline_utils import get_specs


@@ -146,39 +146,6 @@ def do_scaled_grouped_mm(A, B):
data_df.to_csv(out_filename)


def generate_jagged_offs(E, M, dtype=torch.int32, device="cuda"):
"""
Generates a tensor of length E, containing random values divisible by 16,
from 0 to M, in sorted order, and where the final value in the tensor is always M.
Args:
E (int): The length of the tensor.
M (int): The maximum value in the tensor.
Returns:
torch.Tensor: A tensor of length E with the specified properties.
"""
# Ensure M is divisible by 16
if M % 16 != 0:
raise ValueError("M must be divisible by 16")

# Generate a list of possible values
possible_values = [i for i in range(0, M + 1, 16)]

# If E is larger than the number of possible values, raise an error
if E > len(possible_values):
raise ValueError("E cannot be larger than the number of possible values")

# Randomly select E - 1 values from the possible values (excluding M)
selected_values = torch.tensor(random.sample(possible_values[:-1], E - 1))

# Append M to the selected values
selected_values = torch.cat((selected_values, torch.tensor([M])))

# Sort the selected values
selected_values, _ = torch.sort(selected_values)

return selected_values.to(dtype).to(device)


def main() -> None:
fire.Fire(run)

39 changes: 38 additions & 1 deletion test/prototype/moe_training/test_scaled_grouped_mm.py
@@ -7,6 +7,9 @@
import pytest
import torch

pytest.importorskip("triton", reason="Triton required to run this test")

from torchao.prototype.moe_training.utils import generate_jagged_offs
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

# We need to skip before doing any imports which would use triton, since
@@ -25,10 +28,12 @@
)
from torchao.float8.float8_linear import matmul_with_hp_or_float8_args
from torchao.float8.float8_training_tensor import LinearMMConfig
from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
from torchao.float8.float8_utils import compute_error, tensor_to_scale, to_fp8_saturated
from torchao.prototype.moe_training.scaled_grouped_mm import (
_scaled_grouped_mm,
emulated_mxfp8_scaled_grouped_mm,
)
from torchao.prototype.mx_formats.mx_tensor import to_mx
from torchao.testing.utils import skip_if_rocm


@@ -212,3 +217,35 @@ def compute_reference_forward(
# Concatenate the outputs and verify the full result is correct.
output_ref = torch.cat(outputs, dim=0)
return output_ref


@pytest.mark.parametrize("M", (1024, 4096))
@pytest.mark.parametrize("K", (1024, 4096))
@pytest.mark.parametrize("N", (1024, 4096))
@pytest.mark.parametrize("num_experts", (1, 8, 16))
def test_emulate_mxfp8_grouped_gemm(M, K, N, num_experts):
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
w_t = torch.randn(num_experts, K, N, dtype=torch.bfloat16, device="cuda")
offs = generate_jagged_offs(num_experts, M)
x_ref, w_t_ref, offs_ref = x.clone(), w_t.clone(), offs.clone()

    # Quantize inputs to mxfp8 for emulated mxfp8 scaled grouped mm
block_size = 32
x_scale, x_mx = to_mx(x, elem_dtype=torch.float8_e4m3fn, block_size=block_size)

# To cast B_t per-expert to mxfp8 across dim1, we transpose the experts, cast along dim -1, then untranspose.
w_scale, w_mx = to_mx(
w_t.transpose(-2, -1).contiguous(),
elem_dtype=torch.float8_e4m3fn,
block_size=block_size,
)
w_t_scale, w_t_mx = w_scale.transpose(-2, -1), w_mx.transpose(-2, -1)

ref_out = torch._grouped_mm(x_ref, w_t_ref, offs=offs_ref, out_dtype=torch.bfloat16)
out = emulated_mxfp8_scaled_grouped_mm(
x_mx, x_scale, w_t_mx, w_t_scale, offs=offs, out_dtype=torch.bfloat16
)

sqnr = compute_error(ref_out, out)
min_sqnr = 27.0
assert sqnr >= min_sqnr, f"sqnr {sqnr} is too low, must be >= {min_sqnr}"
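Note on the accuracy check: compute_error returns an SQNR in dB, so higher values mean the emulated mxfp8 output is closer to the bf16 reference. Its exact definition lives in torchao.float8.float8_utils; the snippet below is a self-contained sketch of the same kind of metric (an assumption about its form, not a copy of the library code):

import torch

def sqnr(ref: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB; higher means `out` tracks `ref` more closely.
    signal = torch.linalg.norm(ref.float())
    noise = torch.linalg.norm((ref - out).float())
    return 20 * torch.log10(signal / noise)

ref = torch.randn(64, 64)
noisy = ref + 1e-3 * torch.randn(64, 64)
print(sqnr(ref, noisy))  # roughly 60 dB for this noise level

Against that scale, the 27 dB threshold used above is a fairly loose bound that still catches gross scaling or layout bugs.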
57 changes: 56 additions & 1 deletion torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -217,7 +217,7 @@ def backward(ctx, grad_output: torch.Tensor):
use_fast_accum=True,
)

# Convert tranpose of grad_output to float8, row-major for left operand of grouped GEMM
# Convert transpose of grad_output to float8, row-major for left operand of grouped GEMM
# needed for grad_B: grad_output_t @ A
grad_output_t_row_major = grad_output.transpose(-2, -1).contiguous()

@@ -266,3 +266,58 @@ def backward(ctx, grad_output: torch.Tensor):
use_fast_accum=True,
)
return grad_A, grad_B.transpose(-2, -1), None, None, None, None


def emulated_mxfp8_scaled_grouped_mm(
A_mx: torch.Tensor,
A_scale: torch.Tensor,
B_t_mx: torch.Tensor,
B_t_scale: torch.Tensor,
offs: Optional[torch.Tensor] = None,
out_dtype: Optional[torch.dtype] = torch.bfloat16,
block_size: int = 32,
) -> torch.Tensor:
# Dequantize input
Contributor (review comment on this line):
nit: claude might have gotten a little wordy w/ this one on the comments

Contributor Author @danielvegamyhre (Jul 30, 2025):
I hand wrote all of this actually, including the comments, lol. Maybe I am old fashioned, but I still hand write everything. Maybe it's because I only use the free tiers, but I get pretty terrible results asking AI tools to help with this kind of work. I do use them for debugging assistance, though, and the AI autocomplete is pretty good; it makes beautiful docstrings.

# A_mx shape: (M, K)
# A_scale shape: (M, K//block_size)
A_orig_shape = A_mx.shape

# Reshape to be able to do per-scaling group multiplication
# A_mx shape: (M, K//block_size, block_size)
# A_scale shape: (M, K//block_size, 1)
A_mx = A_mx.reshape(*A_mx.shape[:-1], A_mx.shape[-1] // block_size, block_size)
A_scale = A_scale.unsqueeze(-1)

# Rescale and cast to bfloat16
A = A_mx.to(torch.bfloat16) * A_scale.to(torch.bfloat16)

# Reshape back to original shape
# A shape: (M, K)
A = A.reshape(A_orig_shape)

# Dequantize weights
# B_t_mx shape: (E, K, N)
# B_t_scale shape: (E, K//block_size, N)
E, K, N = B_t_mx.shape

    # Transpose to get block_size on rightmost dim
# B_mx shape: (E, N, K)
# B_scale shape: (E, N, K//block_size)
B_mx, B_scale = B_t_mx.transpose(-2, -1), B_t_scale.transpose(-2, -1)

# Reshape to be able to do per-scaling group multiplication
# B_mx shape: (E, N, K//block_size, block_size)
# B_scale shape: (E, N, K//block_size, 1)
B_mx = B_mx.reshape(*B_mx.shape[:-1], B_mx.shape[-1] // block_size, block_size)
B_scale = B_scale.unsqueeze(-1)

# Rescale and cast to bfloat16
B = B_mx.to(torch.bfloat16) * B_scale.to(torch.bfloat16)

    # Reshape back and untranspose to the original shape
    # B_t shape: (E, K, N)
    B_t = B.reshape(E, N, K).transpose(-2, -1)

# Perform bf16 grouped GEMM.
out = torch._grouped_mm(A, B_t, offs=offs, out_dtype=out_dtype)
return out
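For reference, the new helper's dequantization of A (and, after the transpose, of B) is just "reshape so each group of block_size elements lines up with its scale, broadcast-multiply, reshape back". The following standalone sketch shows that pattern with dummy all-ones scales so it runs on CPU without any MX quantization; the shapes and names mirror the A path above but are otherwise illustrative:

import torch

M, K, block_size = 8, 64, 32
A_mx = torch.randn(M, K, dtype=torch.bfloat16)   # stand-in for the fp8 element data
A_scale = torch.ones(M, K // block_size)          # one scale per block_size-element group

# (M, K) -> (M, K//block_size, block_size): each row of block_size elements shares
# one scale, so broadcasting the unsqueezed scale rescales per block.
A_blocks = A_mx.reshape(M, K // block_size, block_size)
A = (A_blocks.to(torch.bfloat16) * A_scale.unsqueeze(-1).to(torch.bfloat16)).reshape(M, K)
assert A.shape == (M, K)

With real to_mx outputs, A_scale holds the per-block shared exponents, so the same broadcast recovers a bf16 approximation of the original activations before the plain torch._grouped_mm call.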
36 changes: 36 additions & 0 deletions torchao/prototype/moe_training/utils.py
@@ -1,3 +1,4 @@
import random
from typing import Tuple

import torch
@@ -154,3 +155,38 @@ def _is_column_major(x: torch.Tensor) -> bool:
"""
assert x.ndim == 2 or x.ndim == 3, "input tensor must be 2D or 3D"
return x.stride(-2) == 1 and x.stride(-1) > 1


def generate_jagged_offs(E, M, dtype=torch.int32, device="cuda"):
"""
Utility function for tests and benchmarks.

Generates a tensor of length E, containing random values divisible by 16,
from 0 to M, in sorted order, and where the final value in the tensor is always M.
Args:
E (int): The length of the tensor.
M (int): The maximum value in the tensor.
Returns:
torch.Tensor: A tensor of length E with the specified properties.
"""
# Ensure M is divisible by 16
if M % 16 != 0:
raise ValueError("M must be divisible by 16")

# Generate a list of possible values
possible_values = [i for i in range(0, M + 1, 16)]

# If E is larger than the number of possible values, raise an error
if E > len(possible_values):
raise ValueError("E cannot be larger than the number of possible values")

# Randomly select E - 1 values from the possible values (excluding M)
selected_values = torch.tensor(random.sample(possible_values[:-1], E - 1))

# Append M to the selected values
selected_values = torch.cat((selected_values, torch.tensor([M])))

# Sort the selected values
selected_values, _ = torch.sort(selected_values)

return selected_values.to(dtype).to(device)
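

As a usage note, the values below are illustrative only, not actual output of the function:

# With 4 experts and M=1024 tokens, generate_jagged_offs(4, 1024) might return
# something like tensor([  96,  432,  768, 1024], device='cuda:0', dtype=torch.int32):
# sorted, 16-aligned, and always ending at M.
offs = generate_jagged_offs(4, 1024)
# Each entry marks where one expert's rows end in the (M, K) activation matrix
# (expert 0 gets rows 0:offs[0], expert 1 gets rows offs[0]:offs[1], and so on),
# which is how the tests and benchmarks in this PR pass `offs` to torch._grouped_mm.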