60 changes: 56 additions & 4 deletions test/prototype/moe_training/test_scaled_grouped_mm.py
@@ -9,7 +9,11 @@

pytest.importorskip("triton", reason="Triton required to run this test")

from torchao.prototype.moe_training.utils import generate_jagged_offs
from torchao.prototype.moe_training.utils import (
_to_mxfp8_per_group_colwise,
_to_mxfp8_per_group_rowwise,
generate_jagged_offs,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

# We need to skip before doing any imports which would use triton, since
@@ -30,8 +34,9 @@
from torchao.float8.float8_training_tensor import LinearMMConfig
from torchao.float8.float8_utils import compute_error, tensor_to_scale, to_fp8_saturated
from torchao.prototype.moe_training.scaled_grouped_mm import (
_emulated_mxfp8_scaled_grouped_mm_2d_2d,
_emulated_mxfp8_scaled_grouped_mm_2d_3d,
_scaled_grouped_mm,
emulated_mxfp8_scaled_grouped_mm,
)
from torchao.prototype.mx_formats.mx_tensor import to_mx
from torchao.testing.utils import skip_if_rocm
@@ -223,7 +228,7 @@ def compute_reference_forward(
@skip_if_rocm("ROCm not supported")
@pytest.mark.parametrize("M,K,N", [(1024, 1024, 1024), (1024, 2048, 4096)])
@pytest.mark.parametrize("num_experts", (1, 8, 16))
def test_emulate_mxfp8_grouped_gemm(M, K, N, num_experts):
def test_emulate_mxfp8_grouped_gemm_2d_3d(M, K, N, num_experts):
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
w_t = torch.randn(num_experts, K, N, dtype=torch.bfloat16, device="cuda")
offs = generate_jagged_offs(num_experts, M)
@@ -242,7 +247,7 @@ def test_emulate_mxfp8_grouped_gemm(M, K, N, num_experts):
w_t_scale, w_t_mx = w_scale.transpose(-2, -1), w_mx.transpose(-2, -1)

ref_out = torch._grouped_mm(x_ref, w_t_ref, offs=offs_ref, out_dtype=torch.bfloat16)
out = emulated_mxfp8_scaled_grouped_mm(
out = _emulated_mxfp8_scaled_grouped_mm_2d_3d(
x_mx, x_scale, w_t_mx, w_t_scale, offs=offs, out_dtype=torch.bfloat16
)

@@ -252,6 +257,53 @@ def test_emulate_mxfp8_grouped_gemm(M, K, N, num_experts):


@skip_if_rocm("ROCm not supported")
@pytest.mark.parametrize("M", (1024, 4096))
@pytest.mark.parametrize("N", (1024, 4096))
@pytest.mark.parametrize("num_experts", (8, 16))
def test_emulate_mxfp8_grouped_gemm_2d_2d(M, N, num_experts):
# Simulate 2d-2d grouped gemm: grad_weight = grad_output_t @ x
block_size = 32
grad_out = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")
grad_out_t = grad_out.t().contiguous()
x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")
offs = generate_jagged_offs(num_experts, M, multiple_of=block_size)
x_ref, grad_out_t_ref, offs_ref = x.clone(), grad_out_t.clone(), offs.clone()

# bf16 reference grouped gemm
ref_out = torch._grouped_mm(
grad_out_t_ref,
x_ref,
offs=offs_ref,
out_dtype=torch.bfloat16,
)

# mxfp8 grouped gemm
x_scale, x_mx = to_mx(x, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
grad_out_t_mx, grad_out_t_scale = _to_mxfp8_per_group_rowwise(
grad_out_t,
offs=offs,
block_size=block_size,
)
x_mx, x_scale = _to_mxfp8_per_group_colwise(
x,
offs=offs,
block_size=block_size,
)
out = _emulated_mxfp8_scaled_grouped_mm_2d_2d(
grad_out_t_mx,
grad_out_t_scale,
x_mx,
x_scale,
offs=offs,
out_dtype=torch.bfloat16,
block_size=block_size,
)

sqnr = compute_error(ref_out, out)
min_sqnr = 27.0
assert sqnr >= min_sqnr, f"sqnr {sqnr} is too low, must be >= {min_sqnr}"


@pytest.mark.parametrize("M,K,N", [(1024, 1024, 1024), (1024, 2048, 4096)])
@pytest.mark.parametrize("num_experts", (1, 8, 16))
def test_mxfp8_grouped_gemm_with_dq_fwd(M, K, N, num_experts):
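For context on the new 2d-2d test: `offs` holds the end offset of each expert's token group along the shared contraction dimension `M`, and `generate_jagged_offs(num_experts, M, multiple_of=block_size)` keeps every group size a multiple of the mxfp8 block size, so each group splits cleanly into 32-element scaling blocks. A minimal sketch of that bookkeeping (illustrative values, not part of the PR):

```python
import torch

# Hypothetical offsets for 4 experts over M = 256 tokens; every boundary is a
# multiple of the mxfp8 block size (32), as generate_jagged_offs enforces.
block_size = 32
offs = torch.tensor([64, 96, 192, 256], dtype=torch.int32)

start = 0
for expert_idx, end in enumerate(offs.tolist()):
    group_size = end - start                 # tokens routed to this expert
    scale_blocks = group_size // block_size  # e8m0 scales per row in this group
    print(f"expert {expert_idx}: tokens [{start}:{end}) -> "
          f"{group_size} tokens, {scale_blocks} scale blocks per row")
    start = end
```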
121 changes: 120 additions & 1 deletion torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -300,6 +300,7 @@ def forward(

# Store what we need for backward.
ctx.save_for_backward(A, B_t, offs)
ctx.block_size = block_size
ctx.out_dtype = out_dtype

# Perform scaled grouped GEMM and return result.
@@ -317,7 +318,7 @@
return out

@staticmethod
def backward(ctx, grad_output: torch.Tensor):
def backward(ctx, grad_out: torch.Tensor):
raise NotImplementedError


@@ -352,6 +353,27 @@ def emulated_mxfp8_scaled_grouped_mm(
offs: Optional[torch.Tensor] = None,
out_dtype: Optional[torch.dtype] = torch.bfloat16,
block_size: int = 32,
) -> torch.Tensor:
if A_mx.ndim == 2 and B_t_mx.ndim == 3:
return _emulated_mxfp8_scaled_grouped_mm_2d_3d(
A_mx, A_scale, B_t_mx, B_t_scale, offs, out_dtype, block_size
)
elif A_mx.ndim == 2 and B_t_mx.ndim == 2:
return _emulated_mxfp8_scaled_grouped_mm_2d_2d(
A_mx, A_scale, B_t_mx, B_t_scale, offs, out_dtype, block_size
)
else:
raise NotImplementedError


def _emulated_mxfp8_scaled_grouped_mm_2d_3d(
A_mx: torch.Tensor,
A_scale: torch.Tensor,
B_t_mx: torch.Tensor,
B_t_scale: torch.Tensor,
offs: Optional[torch.Tensor] = None,
out_dtype: Optional[torch.dtype] = torch.bfloat16,
block_size: int = 32,
) -> torch.Tensor:
# Dequantize input
# A_mx shape: (M, K)
Expand Down Expand Up @@ -397,3 +419,100 @@ def emulated_mxfp8_scaled_grouped_mm(
# Perform bf16 grouped GEMM.
out = torch._grouped_mm(A, B_t, offs=offs, out_dtype=out_dtype)
return out


def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
A_mx: torch.Tensor, # (M, K)
A_scale: torch.Tensor, # (M, K//block_size)
B_mx: torch.Tensor, # (K, N)
B_scale: torch.Tensor, # (K//block_size, N)
offs: torch.Tensor,
out_dtype: Optional[torch.dtype] = torch.bfloat16,
block_size: int = 32,
) -> torch.Tensor:
assert A_mx.ndim == 2, "A must be 2D"
assert B_mx.ndim == 2, "B must be 2D"
A = torch.zeros(
A_mx.shape,
dtype=torch.bfloat16,
device=A_mx.device,
requires_grad=A_mx.requires_grad,
)
B = torch.zeros(
B_mx.shape,
dtype=torch.bfloat16,
device=B_mx.device,
requires_grad=B_mx.requires_grad,
)

# Dequantize input per each scaling group
scales_start_idx = 0
group_start_idx = 0
for group_end_idx in offs.tolist():
group_size = group_end_idx - group_start_idx
scale_group_size = group_size // block_size
if group_size == 0:
group_start_idx = group_end_idx
continue

# -- Dequantize A tensor
# A_group shape: (M, group_size)
# A_scale shape: (M, group_size//block_size)
A_group = A_mx[:, group_start_idx:group_end_idx]
A_group_shape = A_group.shape

# Get scales for this group.
# scales shape: (M, group_size//block_size)
scales = A_scale[:, scales_start_idx : scales_start_idx + scale_group_size]

# Reshape to be able to do per-scaling group multiplication
# A_group shape: (M, group_size//block_size, block_size)
# A_scale shape: (M, group_size//block_size, 1)
A_group = A_group.reshape(
*A_group.shape[:-1], A_group.shape[-1] // block_size, block_size
)
scales = scales.unsqueeze(-1)

# Rescale and cast to bfloat16
A_group = A_group.to(torch.bfloat16) * scales.to(torch.bfloat16)

# Reshape back to original shape and store in dequantized A buffer
# A shape: (M, group_size)
A_group = A_group.reshape(A_group_shape)
A[:, group_start_idx:group_end_idx] = A_group

# -- Dequantize B tensor
# B_group shape is (group_size, N)
B_group = B_mx[group_start_idx:group_end_idx, :]
B_group_shape = B_group.shape

# Scales shape is (group_size//block_size, N)
scales = B_scale[scales_start_idx : scales_start_idx + scale_group_size, :]

# Transpose B so the scaling group is on the rightmost dim, to make things easier
# B_group shape: (N, group_size)
# scales shape: (N, group_size//block_size)
B_group, scales = B_group.transpose(-2, -1), scales.transpose(-2, -1)

# Reshape B to be able to do per-scaling group multiplication
# B_group shape: (N, group_size//block_size, block_size)
# scales shape: (N, group_size//block_size, 1)
B_group = B_group.reshape(
*B_group.shape[:-1], B_group.shape[-1] // block_size, block_size
)
scales = scales.unsqueeze(-1)

# Cast to bf16 and perform scaling
B_group = B_group.to(torch.bfloat16) * scales.to(torch.bfloat16)

# Reshape B_group back to original shape and store in dequantized B buffer
B_group = B_group.reshape(B_group_shape[1], B_group_shape[0]).transpose(-2, -1)
B[group_start_idx:group_end_idx, :] = B_group

# Increment group start and scale start indices
group_start_idx = group_end_idx
scales_start_idx += scale_group_size

# Perform bf16 grouped GEMM using dequantized A and B.
out = torch._grouped_mm(A, B, offs=offs, out_dtype=out_dtype)
return out
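The per-group loop above boils down to one block-wise dequantization pattern applied to both operands: slice out the group, view the scaled dimension as (num_blocks, block_size), broadcast-multiply by the e8m0 scales, and write the bf16 result back. A standalone sketch of that inner step for a single row-major tensor, reusing `to_mx` from torchao.prototype.mx_formats.mx_tensor as the PR does (CUDA/bfloat16 setup assumed, matching the tests; shapes are illustrative):

```python
import torch
from torchao.prototype.mx_formats.mx_tensor import to_mx

block_size = 32
x = torch.randn(8, 128, dtype=torch.bfloat16, device="cuda")

# to_mx returns (e8m0 scales, fp8 data); scales shape is (8, 128 // 32) here.
scales, x_mx = to_mx(x, elem_dtype=torch.float8_e4m3fn, block_size=block_size)

# Same reshape/broadcast pattern as the emulated kernel: group the last dim into
# 32-element blocks, upcast both operands to bf16, and rescale.
x_blocks = x_mx.reshape(*x_mx.shape[:-1], x_mx.shape[-1] // block_size, block_size)
x_dq = (
    x_blocks.to(torch.bfloat16) * scales.unsqueeze(-1).to(torch.bfloat16)
).reshape(x.shape)

assert x_dq.shape == x.shape  # x_dq approximates x up to mxfp8 quantization error
```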
108 changes: 104 additions & 4 deletions torchao/prototype/moe_training/utils.py
@@ -5,8 +5,10 @@

from torchao.float8.config import ScalingGranularity
from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
from torchao.prototype.mx_formats.mx_tensor import to_mx


# --- float8 rowwise scaling ---
def _to_2d_jagged_float8_tensor_colwise(
A_col_major: torch.Tensor,
offs: torch.Tensor,
@@ -143,6 +145,104 @@ def _to_2d_jagged_float8_tensor_rowwise(
return x_fp8, x_scales


# --- mxfp8 scaling ---
def _to_mxfp8_per_group_rowwise(
x: torch.Tensor,
offs: torch.Tensor,
block_size: int = 32,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
This is a reference implementation used for testing correctness; it is not performant.

This function converts the 2D input tensor to an mxfp8 tensor along dim 0 with per-token-group scaling,
where groups are determined based on the offsets.

Args:
x (torch.Tensor): The input tensor to be converted to a jagged mxfp8 tensor.

Returns:
A tuple containing the jagged mxfp8 tensor and the scales used for the conversion.
"""
assert x.ndim == 2, "input tensor must be 2D"
assert offs.numel() > 0, "offs must be non-empty"

x_mx = torch.empty_like(x, dtype=torch.float8_e4m3fn)
x_scales = None

start_idx = 0
for end_idx in offs.tolist():
# Get the subtensor of x for this group: all rows, the next group of columns.
subtensor = x[:, start_idx:end_idx] # (M, local_group_size)

# Perform mxfp8 conversion on logically distinct subtensor.
scales, mx_subtensor = to_mx(
subtensor.contiguous(),
elem_dtype=torch.float8_e4m3fn,
block_size=block_size,
)

# Store this portion of the resulting mxfp8 tensor and scales.
x_mx[:, start_idx:end_idx] = mx_subtensor
if x_scales is None:
x_scales = scales.view(torch.uint8) # Needed to support cat op below
else:
x_scales = torch.cat((x_scales, scales.view(torch.uint8)), dim=1)

# Update start index for next group.
start_idx = end_idx

return x_mx, x_scales.view(torch.float8_e8m0fnu)


def _to_mxfp8_per_group_colwise(
A_col_major: torch.Tensor, # (K, N)
offs: torch.Tensor,
block_size: int = 32,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
This is a reference implementation used for testing correctness; it is not performant.

This function converts the 2D input tensor to an mxfp8 tensor along dim 1 with per-token-group scaling,
where groups are determined based on the offsets.

Args:
A_col_major (torch.Tensor): The input tensor to be converted to an mxfp8 tensor.

Returns:
A tuple containing the mxfp8 tensor and the scales used for the conversion.
"""
assert A_col_major.ndim == 2, "A must be 2D"
assert offs.numel() > 0, "offs must be non-empty"

A_mx = torch.empty_like(A_col_major, dtype=torch.float8_e4m3fn)
A_scales = None

start_idx = 0
for end_idx in offs.tolist():
# Get the subtensor of A for this group: the next group of rows, all columns.
subtensor = A_col_major[start_idx:end_idx, :] # (local_group_size, N)

# Convert to mxfp8 along dim1, by transposing, converting, and transposing back.
scales, mx_subtensor = to_mx(
subtensor.transpose(-2, -1).contiguous(),
elem_dtype=torch.float8_e4m3fn,
block_size=block_size,
)
scales, mx_subtensor = scales.transpose(-2, -1), mx_subtensor.transpose(-2, -1)

# Store this portion of the resulting mxfp8 tensor and scales.
A_mx[start_idx:end_idx, :] = mx_subtensor
if A_scales is None:
A_scales = scales.view(torch.uint8) # Needed to support cat op below
else:
A_scales = torch.cat((A_scales, scales.view(torch.uint8)), dim=0)

# Update start index for next group.
start_idx = end_idx

return A_mx, A_scales.view(torch.float8_e8m0fnu)


def _is_column_major(x: torch.Tensor) -> bool:
"""
This function checks if the input tensor is column-major.
@@ -157,7 +257,7 @@ def _is_column_major(x: torch.Tensor) -> bool:
return x.stride(-2) == 1 and x.stride(-1) > 1


def generate_jagged_offs(E, M, dtype=torch.int32, device="cuda"):
def generate_jagged_offs(E, M, multiple_of=16, dtype=torch.int32, device="cuda"):
"""
Utility function for tests and benchmarks.

@@ -170,11 +270,11 @@ def generate_jagged_offs(E, M, dtype=torch.int32, device="cuda"):
torch.Tensor: A tensor of length E with the specified properties.
"""
# Ensure M is divisible by multiple_of
if M % 16 != 0:
raise ValueError("M must be divisible by 16")
if M % multiple_of != 0:
raise ValueError(f"M must be divisible by {multiple_of}")

# Generate a list of possible values
possible_values = [i for i in range(0, M + 1, 16)]
possible_values = [i for i in range(multiple_of, M + 1, multiple_of)]

# If E is larger than the number of possible values, raise an error
if E > len(possible_values):
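A small shape sketch for the two new per-group conversion helpers, mirroring how the 2d-2d test uses them (CUDA/bfloat16 assumed, as in the tests; sizes are illustrative). Rowwise slices column groups of its operand and concatenates scales along dim 1; colwise slices row groups and concatenates scales along dim 0:

```python
import torch

from torchao.prototype.moe_training.utils import (
    _to_mxfp8_per_group_colwise,
    _to_mxfp8_per_group_rowwise,
    generate_jagged_offs,
)

block_size, num_experts, M, N = 32, 4, 256, 128
offs = generate_jagged_offs(num_experts, M, multiple_of=block_size)

# Rowwise: groups are column ranges of a (N, M) operand, e.g. grad_out_t.
grad_out_t = torch.randn(N, M, dtype=torch.bfloat16, device="cuda")
go_mx, go_scales = _to_mxfp8_per_group_rowwise(
    grad_out_t, offs=offs, block_size=block_size
)
assert go_mx.shape == grad_out_t.shape
print(go_scales.shape)  # expected (N, M // block_size) when offsets cover all M tokens

# Colwise: groups are row ranges of a (M, N) operand, e.g. the activations x.
x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")
x_mx, x_scales = _to_mxfp8_per_group_colwise(x, offs=offs, block_size=block_size)
assert x_mx.shape == x.shape
print(x_scales.shape)  # expected (M // block_size, N) when offsets cover all M tokens
```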