[wip] e2e testing of compile with mxfp8 cutlass gemm
Summary:

Does not work yet - bug in inductor?

Test Plan:

```
pytest test/prototype/mx_formats -s -x -k test_linear_compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 9ea94a266e47e80760623fd6e3945e58a4ae02f8
ghstack-comment-id: 2701081457
Pull Request resolved: #1838
vkuzo committed Mar 5, 2025
1 parent 9803094 commit 358e6d7
Showing 3 changed files with 33 additions and 23 deletions.
2 changes: 1 addition & 1 deletion benchmarks/float8/profile_lowp_training.py
@@ -373,7 +373,7 @@ def main(
     # get gradient shape
     with torch.no_grad():
         _ = m_ref(input_tensor)
-        grad_output = torch.ones_like(_)
+        grad_output = torch.ones_like(_).contiguous()
 
     m_lowp = copy.deepcopy(m_ref)
     if mx_recipe_name is None:
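The `.contiguous()` call matters because `torch.ones_like` defaults to `memory_format=torch.preserve_format`, so the synthetic gradient inherits the reference output's strides. A minimal sketch of the effect, using illustrative shapes not taken from this PR:

```python
import torch

# Illustrative only: torch.ones_like preserves the input's memory format by
# default, so a non row-major reference output yields a non-contiguous grad.
out = torch.randn(8, 4).t()                   # transposed view, shape (4, 8), strides (1, 4)

g = torch.ones_like(out)                      # inherits strides (1, 4)
g_contig = torch.ones_like(out).contiguous()  # row-major strides (8, 1)

print(g.is_contiguous())         # False
print(g_contig.is_contiguous())  # True
```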
35 changes: 15 additions & 20 deletions test/prototype/mx_formats/test_mx_linear.py
@@ -168,33 +168,31 @@ def test_activation_checkpointing():
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(
-    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
-)
-@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
-@pytest.mark.parametrize("bias", [False, True])
+# @pytest.mark.skipif(
+#     is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
+# )
+# @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
+# @pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn])
+@pytest.mark.parametrize("recipe_name", ["mxfp8_cutlass"])
+# @pytest.mark.parametrize("bias", [False, True])
+@pytest.mark.parametrize("bias", [False])
 # TODO(future PR): figure out why torch.compile does not match eager when
 # autocast is on
-@pytest.mark.parametrize(
-    "use_autocast",
-    [
-        False,
-    ],
-)
-def test_linear_compile(elem_dtype, bias, use_autocast):
+def test_linear_compile(recipe_name, bias):
     """
     Verify that compile does not change numerics of MX linear fw + bw
     """
+    elem_dtype = torch.float8_e4m3fn
     if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
         if not is_sm_at_least_89():
             pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
-    M, K, N = 4, 8, 6
+    M, K, N = 128, 256, 512
     input_shape = (M, K)
     grad_shape = (M, N)
     m_mx = nn.Sequential(
         nn.Linear(K, N, bias=bias, device="cuda"),
     )
-    config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
+    config = MXLinearConfig.from_recipe_name(recipe_name)
     swap_linear_with_mx_linear(m_mx, config=config)
     m_mx_c = copy.deepcopy(m_mx)
     m_mx_c = torch.compile(m_mx_c, fullgraph=True, backend="inductor")
@@ -203,14 +201,11 @@ def test_linear_compile(elem_dtype, bias, use_autocast):
     x = copy.deepcopy(x_ref)
     g = torch.randn(*grad_shape, device="cuda")
 
-    if use_autocast:
-        with torch.autocast("cuda", dtype=torch.bfloat16):
-            y_ref = m_mx(x_ref)
-            y = m_mx_c(x)
-    else:
+    with torch.no_grad():
         y_ref = m_mx(x_ref)
         y = m_mx_c(x)
-    torch.testing.assert_close(y_ref, y, atol=0, rtol=0)
+        torch.testing.assert_close(y_ref, y, atol=0, rtol=0)
+        return
 
     y_ref.backward(g)
     y.backward(g)
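For reference, the forward-only eager-vs-compile comparison pattern this test now uses can be reproduced with stock PyTorch. The sketch below is an approximation, not a copy of the test: a plain `nn.Linear` stands in for the MX-swapped module, the shapes are illustrative, and default tolerances replace the test's bitwise check.

```python
import copy

import torch
import torch.nn as nn


def check_compile_matches_eager(M=128, K=256, N=512):
    # plain nn.Linear stands in for the MX-swapped module from the test above
    m = nn.Sequential(nn.Linear(K, N, bias=False, device="cuda"))
    m_c = torch.compile(copy.deepcopy(m), fullgraph=True, backend="inductor")

    x_ref = torch.randn(M, K, device="cuda", requires_grad=True)
    x = copy.deepcopy(x_ref)

    # forward-only comparison; the real test additionally exercises backward
    with torch.no_grad():
        y_ref = m(x_ref)
        y = m_c(x)
        # the MX test asserts bitwise equality (atol=0, rtol=0); default
        # tolerances are used here because a generic compiled linear is not
        # guaranteed to be bit-identical to eager
        torch.testing.assert_close(y_ref, y)
```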
19 changes: 17 additions & 2 deletions torchao/prototype/mx_formats/mx_ops.py
@@ -31,6 +31,7 @@
     tensor_size_hp_to_fp4x2,
 )
 from torchao.prototype.mx_formats.utils import to_blocked
+from torchao.float8.float8_utils import is_row_major
 
 aten = torch.ops.aten
 
@@ -73,7 +74,21 @@ def mx_mm(aten_op, args, kwargs=None):
     if a._gemm_kernel_choice in (MXGemmKernelChoice.CUBLAS, MXGemmKernelChoice.CUTLASS):
         # real MX gemm backed by torchao's CUTLASS kernels
         M, K, N = a.shape[0], a.shape[1], b.shape[1]
+        assert a._data.is_contiguous()
+        assert b._data.t().is_contiguous()
 
+        a_data = a._data
+        b_data = b._data
+
+        # note: below does not fix the bug, leaving here for now for completeness during
+        # further debugging but should be deleted
+        if False:
+            # triton not respecting memory layout, fix it manually
+            if not is_row_major(a_data.stride()):
+                a_data = a_data.contiguous()
+            if is_row_major(b_data.stride()):
+                b_data = b_data.t().contiguous().t()
+
         # TODO(future PR): use block_size instead of hardcoding 32
         a_scale = a._scale_e8m0.view(M, K // 32)
         b_scale = b._scale_e8m0.view(N, K // 32)
@@ -83,8 +98,8 @@
         assert b._elem_dtype == torch.float8_e4m3fn
         if a._gemm_kernel_choice is MXGemmKernelChoice.CUBLAS:
             res = torch._scaled_mm(
-                a._data,
-                b._data,
+                a_data,
+                b_data,
                 a_scale_block.view(torch.float8_e8m0fnu),
                 b_scale_block.view(torch.float8_e8m0fnu),
                 out_dtype=torch.bfloat16,
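The new asserts and the disabled workaround are both about operand memory layout: they check that `a` is row-major and `b` is column-major (i.e. `b.t()` is contiguous) before dispatching to the gemm. A small layout sketch, with `is_row_major_strides` as a stand-in for `torchao.float8.float8_utils.is_row_major`:

```python
import torch


def is_row_major_strides(stride):
    # stand-in for torchao.float8.float8_utils.is_row_major: for a 2D tensor,
    # row-major means the last dim is the fastest-moving one (stride 1)
    assert len(stride) == 2
    return stride[0] > stride[1] and stride[1] == 1


a = torch.randn(128, 256)  # row-major by construction
b = torch.randn(256, 512)  # row-major by construction

# force b into a column-major layout without changing its logical shape:
# transpose -> materialize the transpose contiguously -> transpose back
b_col_major = b.t().contiguous().t()

assert a.is_contiguous()                                # a is row-major
assert b_col_major.t().is_contiguous()                  # b is column-major
assert is_row_major_strides(a.stride())                 # strides (256, 1)
assert not is_row_major_strides(b_col_major.stride())   # strides (1, 256)
```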
