[wip] e2e testing of compile with mxfp8 cutlass gemm
Summary:

Does not work yet - bug in inductor?

Test Plan:

```
pytest test/prototype/mx_formats -s -x -k test_linear_compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 1e94dd3c894a3d975f37286bfb5f5cb142e7d30f
ghstack-comment-id: 2701081457
Pull Request resolved: #1838
vkuzo committed Mar 5, 2025
1 parent 23c3162 commit c81137a
Showing 3 changed files with 29 additions and 23 deletions.
test/prototype/mx_formats/test_mx_linear.py (17 additions & 20 deletions)
```diff
@@ -168,33 +168,35 @@ def test_activation_checkpointing():


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(
-    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
+# @pytest.mark.skipif(
+#     is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
+# )
+@pytest.mark.parametrize(
+    "recipe_name",
+    ["mxfp8_emulated", "mxfp4_emulated", "mxfp8_cutlass", "mxfp4_cutlass"],
 )
-@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 @pytest.mark.parametrize("bias", [False, True])
-# TODO(future PR): figure out why torch.compile does not match eager when
-# autocast is on
-@pytest.mark.parametrize(
-    "use_autocast",
-    [
-        False,
-    ],
-)
-def test_linear_compile(elem_dtype, bias, use_autocast):
+def test_linear_compile(recipe_name, bias):
     """
     Verify that compile does not change numerics of MX linear fw + bw
     """
-    if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+    if recipe_name in ["mxfp8_emulated", "mxfp8_cutlass"]:
         if not is_sm_at_least_89():
             pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
-    M, K, N = 4, 8, 6
+
+    if bias and recipe_name in ["mxfp8_cutlass", "mxfp4_cutlass"]:
+        # TODO(future PR): fix this, things are clearly broken with bias=True
+        pytest.skip("this test is broken for cutlass recipes with bias=True")
+
+    M, K, N = 128, 256, 512
     input_shape = (M, K)
     grad_shape = (M, N)
     m_mx = nn.Sequential(
         nn.Linear(K, N, bias=bias, device="cuda"),
     )
-    config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
+    config = MXLinearConfig.from_recipe_name(recipe_name)
     swap_linear_with_mx_linear(m_mx, config=config)
     m_mx_c = copy.deepcopy(m_mx)
     m_mx_c = torch.compile(m_mx_c, fullgraph=True, backend="inductor")
@@ -203,13 +205,8 @@ def test_linear_compile(elem_dtype, bias, use_autocast):
     x = copy.deepcopy(x_ref)
     g = torch.randn(*grad_shape, device="cuda")

-    if use_autocast:
-        with torch.autocast("cuda", dtype=torch.bfloat16):
-            y_ref = m_mx(x_ref)
-            y = m_mx_c(x)
-    else:
-        y_ref = m_mx(x_ref)
-        y = m_mx_c(x)
+    y_ref = m_mx(x_ref)
+    y = m_mx_c(x)
     torch.testing.assert_close(y_ref, y, atol=0, rtol=0)

     y_ref.backward(g)
```
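For context, the eager-vs-compile numerics check the reworked test performs boils down to the sketch below. It needs a CUDA device that passes the capability checks above; the import paths and the bfloat16 input dtype are assumptions based on the repository layout at this commit, not confirmed by the diff.

```python
import copy

import torch
import torch.nn as nn

# assumed import paths for the prototype APIs used in the test above
from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear

M, K, N = 128, 256, 512
m = nn.Sequential(nn.Linear(K, N, bias=False, device="cuda", dtype=torch.bfloat16))
config = MXLinearConfig.from_recipe_name("mxfp8_cutlass")
swap_linear_with_mx_linear(m, config=config)  # swaps nn.Linear for MX linear in place

m_c = torch.compile(copy.deepcopy(m), fullgraph=True, backend="inductor")

x_ref = torch.randn(M, K, device="cuda", dtype=torch.bfloat16).requires_grad_()
x = copy.deepcopy(x_ref)

# forward results must match bitwise between eager and compile
torch.testing.assert_close(m(x_ref), m_c(x), atol=0, rtol=0)
```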
torchao/ops.py (10 additions & 2 deletions)
```diff
@@ -27,8 +27,16 @@
 lib.define(
     "rowwise_scaled_linear_cutlass_s8s4(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor"
 )
-lib.define("mx_fp8_bf16(Tensor a, Tensor b, Tensor a_scale, Tensor b_scale) -> Tensor")
-lib.define("mx_fp4_bf16(Tensor a, Tensor b, Tensor a_scale, Tensor b_scale) -> Tensor")
+# Note: we need to add the `torch._C.Tag.needs_fixed_stride_order` tag in order for inductor
+# to honor the layout constraints for `b` in the two ops below.
+lib.define(
+    "mx_fp8_bf16(Tensor a, Tensor b, Tensor a_scale, Tensor b_scale) -> Tensor",
+    tags=[torch._C.Tag.needs_fixed_stride_order],
+)
+lib.define(
+    "mx_fp4_bf16(Tensor a, Tensor b, Tensor a_scale, Tensor b_scale) -> Tensor",
+    tags=[torch._C.Tag.needs_fixed_stride_order],
+)


 def register_custom_op(name):
```
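The new note explains why the tag matters: without it, inductor is free to re-layout `b` before calling the op. As a standalone sketch of the same declaration pattern (the `my_ops` namespace and `dense_gemm` schema are hypothetical, not part of torchao):

```python
import torch

# hypothetical namespace/op, for illustration only
my_lib = torch.library.Library("my_ops", "DEF")
my_lib.define(
    "dense_gemm(Tensor a, Tensor b) -> Tensor",
    # ask inductor to pass inputs with the strides seen in eager mode,
    # instead of picking its own layout for them
    tags=[torch._C.Tag.needs_fixed_stride_order],
)

def _dense_gemm_impl(a, b):
    # placeholder math; a real kernel would rely on b keeping its layout
    return a @ b

my_lib.impl("dense_gemm", _dense_gemm_impl, "CompositeExplicitAutograd")
```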
torchao/prototype/mx_formats/mx_ops.py (2 additions & 1 deletion)
```diff
@@ -22,7 +22,6 @@
 import torch
 from torch.utils._pytree import tree_map

-# from torchao.ops import mx_fp4_bf16, mx_fp8_bf16
 import torchao.ops
 from torchao.prototype.mx_formats.config import MXGemmKernelChoice
 from torchao.prototype.mx_formats.constants import DTYPE_FP4
@@ -73,7 +72,9 @@ def mx_mm(aten_op, args, kwargs=None):
     if a._gemm_kernel_choice in (MXGemmKernelChoice.CUBLAS, MXGemmKernelChoice.CUTLASS):
         # real MX gemm backed by torchao's CUTLASS kernels
         M, K, N = a.shape[0], a.shape[1], b.shape[1]
+        assert a._data.is_contiguous()
+        assert b._data.t().is_contiguous()

         # TODO(future PR): use block_size instead of hardcoding 32
         a_scale = a._scale_e8m0.view(M, K // 32)
         b_scale = b._scale_e8m0.view(N, K // 32)
```
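The two new asserts pin down the operand layout the CUTLASS-backed gemm expects: `a` row-major and `b` with a contiguous transpose (i.e. column-major storage). A minimal illustration with plain tensors, reusing the M, K, N naming from the code above:

```python
import torch

M, K, N = 128, 256, 512
a = torch.randn(M, K)      # row-major, so a.is_contiguous() is True
b = torch.randn(N, K).t()  # shape (K, N), but stored column-major
assert a.is_contiguous()
assert b.t().is_contiguous()  # the same check mx_mm performs on b._data
```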
