enable torch.compile for mxfp8_cublas recipe #1841

Status: Open
Wants to merge 24 commits into base: main
test/prototype/mx_formats/test_mx_linear.py (69 additions, 6 deletions)
@@ -11,6 +11,7 @@
import torch
import torch.nn as nn

from torchao.float8.float8_utils import is_row_major
from torchao.prototype.mx_formats.config import (
MXLinearConfig,
MXLinearRecipeName,
@@ -24,14 +25,14 @@
)
from torchao.quantization.utils import compute_error
from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_4,
+    TORCH_VERSION_AT_LEAST_2_5,
is_sm_at_least_89,
is_sm_at_least_100,
)

torch.manual_seed(2)

-if not TORCH_VERSION_AT_LEAST_2_4:
+if not TORCH_VERSION_AT_LEAST_2_5:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)


@@ -169,11 +170,18 @@ def test_activation_checkpointing():

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
-    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
+    is_sm_at_least_100(),
+    reason="triton does not work yet on CUDA capability 10.0",
)
@pytest.mark.parametrize(
"recipe_name",
["mxfp8_emulated", "mxfp4_emulated", "mxfp8_cutlass", "mxfp4_cutlass"],
[
"mxfp8_emulated",
"mxfp4_emulated",
"mxfp8_cublas",
"mxfp8_cutlass",
"mxfp4_cutlass",
],
)
@pytest.mark.parametrize("bias", [False, True])
# TODO(future PR): figure out why torch.compile does not match eager when
@@ -190,9 +198,9 @@ def test_linear_compile(recipe_name, bias):
if not is_sm_at_least_100():
pytest.skip("CUDA capability >= 10.0 required for MX gemms")

-    if bias and recipe_name in ["mxfp8_cutlass", "mxfp4_cutlass"]:
+    if bias and recipe_name in ["mxfp8_cublas", "mxfp8_cutlass", "mxfp4_cutlass"]:
# TODO(future PR): fix this, things are clearly broken with bias=True
pytest.skip("this test is broken for cutlass recipes with bias=True")
pytest.skip("this test is broken for non-emulated recipes with bias=True")

M, K, N = 128, 256, 512
input_shape = (M, K)
@@ -285,6 +293,61 @@ def test_inference_compile_simple(elem_dtype):
assert sqnr >= 13.5


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
is_sm_at_least_100(),
reason="triton does not work yet on CUDA capability 10.0",
)
@pytest.mark.skipif(
not is_sm_at_least_100(),
reason="MX gemms require CUDA capability 10.0",
)
def test_scaled_mm_wrapper():
# today, e8m0 isn't supported in torchinductor or triton
# for now, work around this by creating a wrapper around torch._scaled_mm
# which takes uint8 scales, and reinterprets them as e8m0 inside the wrapper
from torchao.prototype.mx_formats.mx_ops import _scaled_mm_with_uint8_scales

M, K, N = 128, 256, 512
BLOCK_SIZE = 32
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)

a_scale = torch.ones(M, K // BLOCK_SIZE, device="cuda", dtype=torch.float8_e8m0fnu)
b_scale = torch.ones(N, K // BLOCK_SIZE, device="cuda", dtype=torch.float8_e8m0fnu)

out = torch._scaled_mm(a, b.t(), a_scale, b_scale, out_dtype=torch.bfloat16)

def wrapped(a, b, a_scale, b_scale, out_dtype):
if is_row_major(b.stride()):
b = b.t().contiguous().t()
res = _scaled_mm_with_uint8_scales(a, b, a_scale, b_scale, out_dtype=out_dtype)
return res

wrapped = torch.compile(wrapped)

# correct memory format of `b`
out2 = wrapped(
a,
b.t(),
a_scale.view(torch.uint8),
b_scale.view(torch.uint8),
out_dtype=torch.bfloat16,
)
torch.testing.assert_close(out, out2, atol=0, rtol=0)

# incorrect memory format of `b`
b_col_major = b.t().contiguous().t()
out3 = wrapped(
a,
b_col_major.t(),
a_scale.view(torch.uint8),
b_scale.view(torch.uint8),
out_dtype=torch.bfloat16,
)
torch.testing.assert_close(out, out3, atol=0, rtol=0)


def test_filter_fn():
m1 = nn.Sequential(
nn.Linear(32, 32),
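Two conventions in the new `test_scaled_mm_wrapper` test are easy to miss in the flattened diff: the scales cross the compiled boundary as `uint8` views because torchinductor cannot yet represent `torch.float8_e8m0fnu` (pytorch/pytorch#147873), and the second operand of `torch._scaled_mm` is expected to be column-major, which is what the `b.t().contiguous().t()` idiom inside `wrapped` enforces. A minimal standalone sketch of that layout normalization; the helper name is illustrative and not part of this PR:

```python
import torch

def ensure_col_major(x: torch.Tensor) -> torch.Tensor:
    # Illustrative helper (not in the PR): return `x` with column-major strides,
    # copying only when `x` is currently row-major. `x.t().contiguous().t()`
    # materializes the transpose and views it back, so the result keeps the
    # original shape but gains column-major memory layout.
    if x.stride(0) > x.stride(1):  # 2D row-major input
        return x.t().contiguous().t()
    return x

# Example: a (K, N) row-major operand becomes column-major before the gemm call.
b = torch.randn(256, 512)
b_cm = ensure_col_major(b)
assert b_cm.shape == b.shape and b_cm.stride(0) == 1
```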
torchao/prototype/mx_formats/mx_ops.py (33 additions, 3 deletions)
@@ -30,11 +30,41 @@
tensor_size_hp_to_fp4x2,
)
from torchao.prototype.mx_formats.utils import to_blocked
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

aten = torch.ops.aten

MX_OPS_TABLE: Dict[Any, Any] = {}

if TORCH_VERSION_AT_LEAST_2_5:

@torch.library.custom_op("mylib::_scaled_mm_with_uint8_scales", mutates_args=())
def _scaled_mm_with_uint8_scales(
a: torch.Tensor,
b: torch.Tensor,
a_scale: torch.Tensor,
b_scale: torch.Tensor,
out_dtype: torch.dtype,
) -> torch.Tensor:
"""
Until https://github.com/pytorch/pytorch/issues/147873 is done, we need to
work around the lack of support for `torch.float8_e8m0fnu` in
torchinductor. We do so by hiding the cast of scales to e8m0 inside a
custom op.
"""
# cast back to e8m0 where torchinductor can't see it
a_scale = a_scale.view(torch.float8_e8m0fnu)
b_scale = b_scale.view(torch.float8_e8m0fnu)
res = torch._scaled_mm(a, b, a_scale, b_scale, out_dtype=out_dtype)
return res

@_scaled_mm_with_uint8_scales.register_fake
def _(a, b, a_scale, b_scale, out_dtype):
m, k = a.shape
k2, n = b.shape
res = torch.empty(m, n, dtype=out_dtype, device=a.device)
return res


def implements(aten_ops):
"""Register aten ops to the mx op table"""
@@ -83,11 +113,11 @@ def mx_mm(aten_op, args, kwargs=None):
if a._elem_dtype == torch.float8_e4m3fn:
assert b._elem_dtype == torch.float8_e4m3fn
if a._gemm_kernel_choice is MXGemmKernelChoice.CUBLAS:
-            res = torch._scaled_mm(
+            res = _scaled_mm_with_uint8_scales(
a._data,
b._data,
-                a_scale_block.view(torch.float8_e8m0fnu),
-                b_scale_block.view(torch.float8_e8m0fnu),
+                a_scale_block,
+                b_scale_block,
out_dtype=torch.bfloat16,
)
else:
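For readers unfamiliar with the pattern used in `mx_ops.py` above: wrapping `torch._scaled_mm` in a `torch.library.custom_op` keeps the e8m0 reinterpretation opaque to torch.compile, and the registered fake kernel tells the compiler only the output shape and dtype, so tracing never sees a `float8_e8m0fnu` tensor. A stripped-down sketch of the same recipe, under an illustrative `demo::` namespace rather than the PR's `mylib::` op:

```python
import torch

# Opaque to torch.compile: inductor sees only uint8 scale tensors at the boundary;
# the reinterpretation to float8_e8m0fnu happens inside the custom op.
@torch.library.custom_op("demo::scaled_mm_uint8_scales", mutates_args=())
def scaled_mm_uint8_scales(
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale_u8: torch.Tensor,
    b_scale_u8: torch.Tensor,
    out_dtype: torch.dtype,
) -> torch.Tensor:
    a_scale = a_scale_u8.view(torch.float8_e8m0fnu)
    b_scale = b_scale_u8.view(torch.float8_e8m0fnu)
    return torch._scaled_mm(a, b, a_scale, b_scale, out_dtype=out_dtype)

# Fake (meta) kernel: lets the compiler infer the output shape and dtype
# without dispatching to cuBLAS.
@scaled_mm_uint8_scales.register_fake
def _(a, b, a_scale_u8, b_scale_u8, out_dtype):
    m, _ = a.shape
    _, n = b.shape
    return torch.empty(m, n, dtype=out_dtype, device=a.device)
```

Callers then pass `scale.view(torch.uint8)` at the call site, exactly as `test_scaled_mm_wrapper` and the updated `mx_mm` path do, and the wrapper can be compiled with `torch.compile` like any other op.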