Skip to content

Commit

Permalink
enable torch.compile for mxfp8_cublas recipe
Browse files Browse the repository at this point in the history
Summary:

This PR enables `MXLinear` with `mxfp8_cublas` recipe to use
torch.compile.

The current approach is a short term workaround until
pytorch/pytorch#148461 is done. Since we can't
use e8m0 in torchinductor or triton yet, we create a custom op wrapper
around `torch._scaled_mm` which takes `uint8` scales and does the cast to
e8m0 inside the wrapper, where torchinductor can't see it.

Test Plan:

```
// this now works (although performance is not ideal due to #1788)
python benchmarks/float8/profile_lowp_training.py ~/local/tmp/20250305_test --mx_recipe_name mxfp8_cublas

// we can also uncomment the hardware check and run the unit test
pytest test/prototype/mx_formats -s -k test_linear_compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 033d817549f80d7d0d8cf549f748411cc1f3ac6a
ghstack-comment-id: 2701679811
Pull Request resolved: #1841
  • Loading branch information
vkuzo committed Mar 5, 2025
1 parent a62c9af commit 3e332e2
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 6 deletions.
64 changes: 61 additions & 3 deletions test/prototype/mx_formats/test_mx_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import torch
import torch.nn as nn

from torchao.float8.float8_utils import is_row_major
from torchao.prototype.mx_formats.config import (
MXLinearConfig,
MXLinearRecipeName,
Expand All @@ -22,6 +23,7 @@
swap_linear_with_mx_inference_linear,
swap_linear_with_mx_linear,
)
from torchao.prototype.mx_formats.mx_ops import _scaled_mm_with_uint8_scales
from torchao.quantization.utils import compute_error
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_4,
Expand Down Expand Up @@ -169,13 +171,23 @@ def test_activation_checkpointing():

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
is_sm_at_least_100() and False,
reason="triton does not work yet on CUDA capability 10.0",
)
@pytest.mark.parametrize(
"recipe_name",
["mxfp8_emulated", "mxfp4_emulated", "mxfp8_cutlass", "mxfp4_cutlass"],
# ["mxfp8_emulated", "mxfp4_emulated", "mxfp8_cutlass", "mxfp4_cutlass"],
[
"mxfp8_cublas",
],
)
# @pytest.mark.parametrize("bias", [False, True])
@pytest.mark.parametrize(
"bias",
[
False,
],
)
@pytest.mark.parametrize("bias", [False, True])
# TODO(future PR): figure out why torch.compile does not match eager when
# autocast is on
def test_linear_compile(recipe_name, bias):
Expand Down Expand Up @@ -281,6 +293,52 @@ def test_inference_compile_simple(elem_dtype):
assert sqnr >= 13.5


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_scaled_mm_wrapper():
# today, e8m0 isn't supported in torchinductor or triton
# for now, work around this by creating a wrapper around torch._scaled_mm
# which takes uint8 scales, and reinterprets them as e8m0 inside the wrapper

M, K, N = 128, 256, 512
BLOCK_SIZE = 32
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)

a_scale = torch.ones(M, K // BLOCK_SIZE, device="cuda", dtype=torch.float8_e8m0fnu)
b_scale = torch.ones(N, K // BLOCK_SIZE, device="cuda", dtype=torch.float8_e8m0fnu)

out = torch._scaled_mm(a, b.t(), a_scale, b_scale, out_dtype=torch.bfloat16)

def wrapped(a, b, a_scale, b_scale, out_dtype):
if is_row_major(b.stride()):
b = b.t().contiguous().t()
res = _scaled_mm_with_uint8_scales(a, b, a_scale, b_scale, out_dtype=out_dtype)
return res

wrapped = torch.compile(wrapped)

# correct memory format of `b`
out2 = wrapped(
a,
b.t(),
a_scale.view(torch.uint8),
b_scale.view(torch.uint8),
out_dtype=torch.bfloat16,
)
torch.testing.assert_close(out, out2, atol=0, rtol=0)

# incorrect memory format of `b`
b_col_major = b.t().contiguous().t()
out3 = wrapped(
a,
b_col_major.t(),
a_scale.view(torch.uint8),
b_scale.view(torch.uint8),
out_dtype=torch.bfloat16,
)
torch.testing.assert_close(out, out3, atol=0, rtol=0)


def test_filter_fn():
m1 = nn.Sequential(
nn.Linear(32, 32),
Expand Down
33 changes: 30 additions & 3 deletions torchao/prototype/mx_formats/mx_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,33 @@
MX_OPS_TABLE: Dict[Any, Any] = {}


@torch.library.custom_op("mylib::_scaled_mm_with_uint8_scales", mutates_args=())
def _scaled_mm_with_uint8_scales(
a: torch.Tensor,
b: torch.Tensor,
a_scale: torch.Tensor,
b_scale: torch.Tensor,
out_dtype: torch.dtype,
) -> torch.Tensor:
"""
TODO write me
"""
# cast back to e8m0 where torchinductor can't see it
a_scale = a_scale.view(torch.float8_e8m0fnu)
b_scale = b_scale.view(torch.float8_e8m0fnu)

res = torch._scaled_mm(a, b, a_scale, b_scale, out_dtype=out_dtype)
return res


@_scaled_mm_with_uint8_scales.register_fake
def _(a, b, a_scale, b_scale, out_dtype):
m, k = a.shape
k2, n = b.shape
res = torch.empty(m, n, dtype=out_dtype, device=a.device)
return res


def implements(aten_ops):
"""Register aten ops to the mx op table"""

Expand Down Expand Up @@ -83,11 +110,11 @@ def mx_mm(aten_op, args, kwargs=None):
if a._elem_dtype == torch.float8_e4m3fn:
assert b._elem_dtype == torch.float8_e4m3fn
if a._gemm_kernel_choice is MXGemmKernelChoice.CUBLAS:
res = torch._scaled_mm(
res = _scaled_mm_with_uint8_scales(
a._data,
b._data,
a_scale_block.view(torch.float8_e8m0fnu),
b_scale_block.view(torch.float8_e8m0fnu),
a_scale_block,
b_scale_block,
out_dtype=torch.bfloat16,
)
else:
Expand Down

0 comments on commit 3e332e2

Please sign in to comment.