enable compile with mxfp8 and mxfp4 cutlass gemm #1838

Merged
merged 17 commits on Mar 5, 2025
33 changes: 15 additions & 18 deletions test/prototype/mx_formats/test_mx_linear.py
@@ -171,30 +171,32 @@ def test_activation_checkpointing():
@pytest.mark.skipif(
is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
)
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
@pytest.mark.parametrize(
"recipe_name",
["mxfp8_emulated", "mxfp4_emulated", "mxfp8_cutlass", "mxfp4_cutlass"],
)
@pytest.mark.parametrize("bias", [False, True])
# TODO(future PR): figure out why torch.compile does not match eager when
# autocast is on
@pytest.mark.parametrize(
"use_autocast",
[
False,
],
)
def test_linear_compile(elem_dtype, bias, use_autocast):
def test_linear_compile(recipe_name, bias):
"""
Verify that compile does not change numerics of MX linear fw + bw
"""
if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
if recipe_name in ["mxfp8_emulated", "mxfp8_cutlass"]:
if not is_sm_at_least_89():
pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
M, K, N = 4, 8, 6

if bias and recipe_name in ["mxfp8_cutlass", "mxfp4_cutlass"]:
# TODO(future PR): fix this, things are clearly broken with bias=True
pytest.skip("this test is broken for cutlass recipes with bias=True")

M, K, N = 128, 256, 512
input_shape = (M, K)
grad_shape = (M, N)
m_mx = nn.Sequential(
nn.Linear(K, N, bias=bias, device="cuda"),
)
config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
config = MXLinearConfig.from_recipe_name(recipe_name)
swap_linear_with_mx_linear(m_mx, config=config)
m_mx_c = copy.deepcopy(m_mx)
m_mx_c = torch.compile(m_mx_c, fullgraph=True, backend="inductor")
@@ -203,13 +205,8 @@ def test_linear_compile(elem_dtype, bias, use_autocast):
x = copy.deepcopy(x_ref)
g = torch.randn(*grad_shape, device="cuda")

if use_autocast:
with torch.autocast("cuda", dtype=torch.bfloat16):
y_ref = m_mx(x_ref)
y = m_mx_c(x)
else:
y_ref = m_mx(x_ref)
y = m_mx_c(x)
y_ref = m_mx(x_ref)
y = m_mx_c(x)
torch.testing.assert_close(y_ref, y, atol=0, rtol=0)

y_ref.backward(g)
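For context, the recipe-based flow that the updated test exercises can be sketched as below. This is a minimal illustration, not part of the PR: the import path for `swap_linear_with_mx_linear`, the bf16 dtype, and the shapes are assumptions; only `MXLinearConfig.from_recipe_name` and the compile-vs-eager comparison pattern come from the test above.

```python
import copy

import torch
import torch.nn as nn

from torchao.prototype.mx_formats.config import MXLinearConfig
# import path assumed for this sketch
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear

# toy model; the test above uses M, K, N = 128, 256, 512
m = nn.Sequential(nn.Linear(256, 512, bias=False, device="cuda", dtype=torch.bfloat16))

# build the MX config from a named recipe instead of hand-picking elem_dtype/block_size
config = MXLinearConfig.from_recipe_name("mxfp8_cutlass")
swap_linear_with_mx_linear(m, config=config)

# compile a copy and check that compile matches eager bit-for-bit, as the test does
m_c = torch.compile(copy.deepcopy(m), fullgraph=True, backend="inductor")
x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)
torch.testing.assert_close(m(x), m_c(copy.deepcopy(x)), atol=0, rtol=0)
```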
12 changes: 10 additions & 2 deletions torchao/ops.py
@@ -27,8 +27,16 @@
lib.define(
"rowwise_scaled_linear_cutlass_s8s4(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor"
)
lib.define("mx_fp8_bf16(Tensor a, Tensor b, Tensor a_scale, Tensor b_scale) -> Tensor")
lib.define("mx_fp4_bf16(Tensor a, Tensor b, Tensor a_scale, Tensor b_scale) -> Tensor")
# Note: we need to add the `torch._C.Tag.needs_fixed_stride_order` tag in order for inductor
# to honor the layout constraints for `b` in the two ops below.
lib.define(
Review comment (Contributor): Whoops I meant to add this before

Reply: Good catch

"mx_fp8_bf16(Tensor a, Tensor b, Tensor a_scale, Tensor b_scale) -> Tensor",
tags=[torch._C.Tag.needs_fixed_stride_order],
)
lib.define(
"mx_fp4_bf16(Tensor a, Tensor b, Tensor a_scale, Tensor b_scale) -> Tensor",
tags=[torch._C.Tag.needs_fixed_stride_order],
)


def register_custom_op(name):
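Background on the tag, as noted in the comment above: `torch._C.Tag.needs_fixed_stride_order` asks inductor to pass the op its inputs with the strides observed in eager mode instead of choosing its own layout, which is what preserves the column-major requirement on `b` under `torch.compile`. A minimal sketch of the pattern, using a hypothetical namespace and op; only the tag usage mirrors this PR.

```python
import torch

# hypothetical library and op purely for illustration; the real schemas live in torchao/ops.py
_lib = torch.library.Library("my_ext", "DEF")
_lib.define(
    "col_major_mm(Tensor a, Tensor b) -> Tensor",
    tags=[torch._C.Tag.needs_fixed_stride_order],
)


def _col_major_mm_cuda(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # this sketch's kernel requires `b` to be column-major; because of the tag,
    # inductor will not re-layout `b` to row-major contiguous before the call
    assert b.t().is_contiguous()
    return a @ b


_lib.impl("col_major_mm", _col_major_mm_cuda, "CUDA")
```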
3 changes: 2 additions & 1 deletion torchao/prototype/mx_formats/mx_ops.py
@@ -22,7 +22,6 @@
import torch
from torch.utils._pytree import tree_map

# from torchao.ops import mx_fp4_bf16, mx_fp8_bf16
import torchao.ops
from torchao.prototype.mx_formats.config import MXGemmKernelChoice
from torchao.prototype.mx_formats.constants import DTYPE_FP4
@@ -73,7 +72,9 @@ def mx_mm(aten_op, args, kwargs=None):
if a._gemm_kernel_choice in (MXGemmKernelChoice.CUBLAS, MXGemmKernelChoice.CUTLASS):
# real MX gemm backed by torchao's CUTLASS kernels
M, K, N = a.shape[0], a.shape[1], b.shape[1]
assert a._data.is_contiguous()
assert b._data.t().is_contiguous()

# TODO(future PR): use block_size instead of hardcoding 32
a_scale = a._scale_e8m0.view(M, K // 32)
b_scale = b._scale_e8m0.view(N, K // 32)
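For reference, the layout and scale-shape bookkeeping that the new assertions and views encode can be sketched as follows; the shapes and fp8 dtype here are assumptions for illustration, not taken from the kernel.

```python
import torch

M, K, N = 128, 256, 512
block_size = 32  # hardcoded to 32 in the op above, see the TODO

# a must be row-major (M, K); b must be column-major (K, N), i.e. b.t() is contiguous
a_data = torch.randn(M, K).to(torch.float8_e4m3fn)
b_data = torch.randn(N, K).to(torch.float8_e4m3fn).t()

assert a_data.is_contiguous()
assert b_data.t().is_contiguous()

# one e8m0 scale per block of 32 elements along K: a row of scales per row of a,
# and a row of scales per column of b
a_scale_shape = (M, K // block_size)  # (128, 8)
b_scale_shape = (N, K // block_size)  # (512, 8)
print(a_scale_shape, b_scale_shape)
```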