Support MXFP6 packing and fused unpack-dequantise kernel (conflicts resolved) #1810

Open · wants to merge 3 commits into base: main · changes shown from 1 commit
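Note on the packing scheme exercised by the new tests below: pack_uint6 stores four unpacked FP6 values (one byte each) in three bytes, which is why the tests assert a packed element count of 3 * numel // 4. The snippet below is a minimal pure-PyTorch sketch of one plausible bit layout, written for illustration only; the actual layout used by torchao's pack_uint6 and the fused Triton unpack-dequantise kernels may differ.

```python
import torch


def pack_uint6_reference(x: torch.Tensor) -> torch.Tensor:
    # Illustrative only: pack groups of four 6-bit codes (stored one per
    # uint8) into three bytes. Assumes x.shape[-1] is a multiple of 4.
    x = x.reshape(*x.shape[:-1], -1, 4).to(torch.uint8)
    a, b, c, d = x[..., 0], x[..., 1], x[..., 2], x[..., 3]
    byte0 = (a << 2) | (b >> 4)          # aaaaaabb
    byte1 = ((b & 0xF) << 4) | (c >> 2)  # bbbbcccc
    byte2 = ((c & 0x3) << 6) | d         # ccdddddd
    packed = torch.stack([byte0, byte1, byte2], dim=-1)
    return packed.reshape(*x.shape[:-2], -1)
```

For a 2x4 input this yields a 2x3 uint8 tensor, matching the 3 * numel // 4 assertions in test_fp6_e2m3_pack_unpack and test_fp6_e3m2_pack_unpack.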
41 changes: 41 additions & 0 deletions test/prototype/mx_formats/test_custom_cast.py
@@ -26,7 +26,10 @@
f32_to_f6_e3m2_unpacked,
get_bits,
pack_uint4,
pack_uint6,
triton_f4_to_bf16,
triton_f6_e2m3_to_bf16,
triton_f6_e3m2_to_bf16,
unpack_uint4,
)
from torchao.prototype.mx_formats.fp_format_spec import (
@@ -411,3 +414,41 @@ def test_fp6_e3m2_rounding(f32_val, f6_e3m2_enc, device):

f6_e3m2_unpacked = f32_to_f6_e3m2_unpacked(torch.tensor(-f32_val, device=device))
assert f6_e3m2_unpacked.item() == (f6_e3m2_enc | 0b100000)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_4, reason="requires PyTorch >= 2.4")
@pytest.mark.skipif(
is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
)
def test_fp6_e2m3_pack_unpack():
orig_vals = torch.Tensor([[0.0, 0.5, 7.5, -0.0], [-0.875, 1.0, -6.0, 0.125]]).to(
"cuda"
)
orig_vals_f6_unpacked = f32_to_f6_e2m3_unpacked(orig_vals)
orig_vals_f6_packed = pack_uint6(orig_vals_f6_unpacked)
assert orig_vals_f6_packed.numel() == (3 * orig_vals.numel() // 4)
orig_vals_f6_packed_unpacked = triton_f6_e2m3_to_bf16(orig_vals_f6_packed).to(
torch.float32
)
assert torch.all(orig_vals_f6_packed_unpacked == orig_vals)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_4, reason="requires PyTorch >= 2.4")
@pytest.mark.skipif(
is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
)
def test_fp6_e3m2_pack_unpack():
orig_vals = torch.Tensor([[0.0, 5.0, 28.0, -0.0], [-0.25, 0.1875, 0.0625, 8.0]]).to(
"cuda"
)
orig_vals_f6_unpacked = f32_to_f6_e3m2_unpacked(orig_vals)
orig_vals_f6_packed = pack_uint6(orig_vals_f6_unpacked)
assert orig_vals_f6_packed.numel() == (3 * orig_vals.numel() // 4)
orig_vals_f6_packed_unpacked = triton_f6_e3m2_to_bf16(orig_vals_f6_packed).to(
torch.float32
)
assert torch.all(orig_vals_f6_packed_unpacked == orig_vals)
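For reference, unpacking is the inverse of the sketch above: every three bytes yield four 6-bit codes, which the fused triton_f6_e2m3_to_bf16 / triton_f6_e3m2_to_bf16 kernels then decode to bf16 in the same pass. Below is a pure-PyTorch sketch of the unpack step only, assuming the hypothetical layout from the previous snippet; it is not the PR's kernel code.

```python
import torch


def unpack_uint6_reference(packed: torch.Tensor) -> torch.Tensor:
    # Recover four 6-bit codes from every three packed bytes. Decoding the
    # codes to bf16 (the "dequantise" half of the fused kernel) would follow.
    p = packed.reshape(*packed.shape[:-1], -1, 3)
    byte0, byte1, byte2 = p[..., 0], p[..., 1], p[..., 2]
    a = byte0 >> 2
    b = ((byte0 & 0x3) << 4) | (byte1 >> 4)
    c = ((byte1 & 0xF) << 2) | (byte2 >> 6)
    d = byte2 & 0x3F
    codes = torch.stack([a, b, c, d], dim=-1)
    return codes.reshape(*packed.shape[:-1], -1)
```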
26 changes: 13 additions & 13 deletions test/prototype/mx_formats/test_mx_linear.py
@@ -56,14 +56,14 @@ def test_linear_eager(elem_dtype, bias, input_shape):
"""
# elem_dtype is a tuple of (input, weight, gradient) dtypes.
grad_shape = list(input_shape)
grad_shape[-1] = 6
grad_shape[-1] = 8

m = nn.Sequential(
nn.Linear(8, 6, bias=bias, device="cuda"),
nn.Linear(8, 8, bias=bias, device="cuda"),
)
m_mx = copy.deepcopy(m)
config = MXLinearConfig(
block_size=2,
block_size=4,
elem_dtype=elem_dtype[0],
elem_dtype_weight_override=elem_dtype[1],
elem_dtype_grad_output_override=elem_dtype[2],
@@ -141,14 +141,14 @@ def test_linear_eager_emulated_vs_real_gemm(elem_dtype, mkn):
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_activation_checkpointing():
input_shape = (2, 4)
grad_shape = (2, 6)
grad_shape = (2, 8)
elem_dtype = torch.float8_e4m3fn

m = nn.Sequential(
nn.Linear(4, 6, bias=True, device="cuda"),
nn.Linear(6, 6, bias=True, device="cuda"),
nn.Linear(4, 8, bias=True, device="cuda"),
nn.Linear(8, 8, bias=True, device="cuda"),
)
config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
config = MXLinearConfig(block_size=4, elem_dtype=elem_dtype)
swap_linear_with_mx_linear(m, config=config)

x = torch.randn(*input_shape, device="cuda").requires_grad_()
@@ -178,13 +178,13 @@ def test_linear_compile(elem_dtype, bias, use_autocast):
if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
if not is_sm_at_least_89():
pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
M, K, N = 4, 8, 6
M, K, N = 4, 8, 8
input_shape = (M, K)
grad_shape = (M, N)
m_mx = nn.Sequential(
nn.Linear(K, N, bias=bias, device="cuda"),
)
config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
config = MXLinearConfig(block_size=4, elem_dtype=elem_dtype)
swap_linear_with_mx_linear(m_mx, config=config)
m_mx_c = copy.deepcopy(m_mx)
m_mx_c = torch.compile(m_mx_c, fullgraph=True, backend="inductor")
@@ -229,10 +229,10 @@ def test_inference_linear(elem_dtype, bias, input_shape):
"""
Smoke test for inference linear module with mx weight
"""
m = nn.Sequential(nn.Linear(4, 6, bias=bias, dtype=torch.bfloat16))
m = nn.Sequential(nn.Linear(4, 8, bias=bias, dtype=torch.bfloat16))
m = m.cuda()
m_mx = copy.deepcopy(m)
config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
config = MXLinearConfig(block_size=4, elem_dtype=elem_dtype)
swap_linear_with_mx_inference_linear(m_mx, config=config)

x = torch.randn(*input_shape, device="cuda", dtype=torch.bfloat16)
@@ -257,10 +257,10 @@ def test_inference_compile_simple(elem_dtype):
if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
if not is_sm_at_least_89():
pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
m = nn.Sequential(nn.Linear(4, 6, bias=False, dtype=torch.bfloat16))
m = nn.Sequential(nn.Linear(4, 8, bias=False, dtype=torch.bfloat16))
m = m.cuda()
m_mx = copy.deepcopy(m)
config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
config = MXLinearConfig(block_size=4, elem_dtype=elem_dtype)
swap_linear_with_mx_inference_linear(m_mx, config=config)
m_mx = torch.compile(m_mx, fullgraph="true")

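The dimension changes in this file (block_size 2 -> 4, linear widths 6 -> 8) line up with the packing granularity: four FP6 codes share three bytes, so packed tensors and scale blocks need element counts divisible by four (the test_block_sizes update further down skips FP6 when B % 4 != 0 for the same reason). A tiny hypothetical guard illustrating the constraint, not code from this PR:

```python
def packed_fp6_shape_ok(block_size: int, last_dim: int) -> bool:
    # Hypothetical check: packed FP6 stores four 6-bit codes per three bytes,
    # so both the scale block and the innermost dimension must contain a
    # whole number of 4-element groups.
    return block_size % 4 == 0 and last_dim % 4 == 0
```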
81 changes: 57 additions & 24 deletions test/prototype/mx_formats/test_mx_tensor.py
@@ -9,14 +9,15 @@
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck

from torchao.prototype.mx_formats import config
from torchao.prototype.mx_formats.config import MXGemmKernelChoice
from torchao.prototype.mx_formats.constants import (
DTYPE_FP4,
DTYPE_FP6_E2M3,
DTYPE_FP6_E3M2,
SUPPORTED_ELEM_DTYPES,
)
from torchao.prototype.mx_formats.custom_cast import pack_uint4
from torchao.prototype.mx_formats.custom_cast import pack_uint4, pack_uint6
from torchao.prototype.mx_formats.mx_tensor import (
E8M0_EXPONENT_NAN_VAL,
MXTensor,
@@ -75,7 +76,7 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
def test_hello_world(elem_dtype):
data = torch.randn(4, 4, device="cuda", dtype=torch.bfloat16)
block_size = 2
block_size = 4
_test_mx(data, elem_dtype, block_size)


@@ -92,7 +93,7 @@ def test_realistic_numerics(elem_dtype, scale_calculation_mode):
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
def test_all_zeros(elem_dtype):
data = torch.zeros(4, 4, device="cuda", dtype=torch.bfloat16)
block_size = 2
block_size = 4
_test_mx(data, elem_dtype, block_size)


@@ -102,7 +103,7 @@ def test_some_zeros(elem_dtype):
data = torch.randn(4, 4, device="cuda", dtype=torch.bfloat16)
data[0, :] = 0.0
data[:, 2] = 0.0
block_size = 2
block_size = 4
_test_mx(data, elem_dtype, block_size)


@@ -114,9 +115,9 @@ def test_exponent_nan_in(elem_dtype):
value is set to is NaN
"""
tensor_hp = torch.tensor(
[float("nan"), 1, 2, 3, 4, 5], device="cuda", dtype=torch.bfloat16
[float("nan"), 1, 2, 3, 4, 5, 6, 7], device="cuda", dtype=torch.bfloat16
)
block_size = 2
block_size = 4
tensor_mx = MXTensor.to_mx(tensor_hp, elem_dtype, block_size)
assert torch.all(tensor_mx._scale_e8m0[0] == E8M0_EXPONENT_NAN_VAL)
assert not torch.any(tensor_mx._scale_e8m0[1:] == E8M0_EXPONENT_NAN_VAL)
@@ -129,18 +130,30 @@ def test_exponent_nan_out(elem_dtype):
If block exponent value is NaN, the MX tensor block value is NaN
"""
scale_e8m0_bits = torch.tensor(
[E8M0_EXPONENT_NAN_VAL, 23, 42], dtype=torch.uint8, device="cuda"
[E8M0_EXPONENT_NAN_VAL, 23], dtype=torch.uint8, device="cuda"
)

block_size = 4

if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
data_bits = torch.tensor([0, 1, 2, 3, 4, 5], dtype=elem_dtype, device="cuda") # noqa: E501
data_bits = torch.tensor(
[0, 1, 2, 3, 4, 5, 6, 7], dtype=elem_dtype, device="cuda"
) # noqa: E501
elif elem_dtype in (DTYPE_FP6_E2M3, DTYPE_FP6_E3M2):
data_bits = torch.tensor([0, 1, 2, 3, 4, 5], dtype=torch.uint8, device="cuda") # noqa: E501
data_bits = torch.tensor(
[0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device="cuda"
) # noqa: E501
if config.pack_fp6:
data_bits = data_bits.reshape(-1, block_size)
data_bits = pack_uint6(data_bits)
elif elem_dtype == DTYPE_FP4:
data_bits = torch.tensor([0, 1, 2, 3, 4, 5], dtype=torch.uint8, device="cuda") # noqa: E501
data_bits = torch.tensor(
[0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device="cuda"
) # noqa: E501
data_bits = pack_uint4(data_bits)
else:
raise AssertionError("unsupported")
block_size = 2
block_size = 4
use_fp4_custom_triton_dequant_kernel = False
tensor_mx = MXTensor(
scale_e8m0_bits,
@@ -152,8 +165,8 @@ def test_exponent_nan_out(elem_dtype):
MXGemmKernelChoice.EMULATED,
)
tensor_hp = tensor_mx.to_dtype(torch.float)
assert torch.all(torch.isnan(tensor_hp[0:1]))
assert not torch.any(torch.isnan(tensor_hp[2:]))
assert torch.all(torch.isnan(tensor_hp.flatten()[0:4]))
assert not torch.any(torch.isnan(tensor_hp.flatten()[4:]))


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -162,24 +175,26 @@ def test_ranks(elem_dtype):
"""
The reshaping logic works for various ranks
"""
B = 2
shapes = ((B * 4,), (B * 4, 2), (B * 4, 2, 2), (B * 4, 2, 2, 2))
B = 4
shapes = ((B * 4,), (B * 4, 4), (B * 4, 4, 4), (B * 4, 4, 4, 4))
for s in shapes:
tensor_hp = torch.randn(*s, device="cuda", dtype=torch.bfloat16)
_test_mx(tensor_hp, elem_dtype, B)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
def test_block_sizes(elem_dtype):
@pytest.mark.parametrize("B", [1, 4, 32])
def test_block_sizes(elem_dtype, B):
"""
Smoke test for various block sizes
"""
for B in (1, 2, 32):
if B == 1 and elem_dtype == DTYPE_FP4:
pytest.skip("unsupported configuration")
tensor_hp = torch.randn(B, device="cuda", dtype=torch.bfloat16)
_test_mx(tensor_hp, elem_dtype, B)
if B == 1 and elem_dtype == DTYPE_FP4:
pytest.skip("unsupported configuration")
elif B % 4 != 0 and elem_dtype in [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2]:
pytest.skip("unsupported configuration")
tensor_hp = torch.randn(B, device="cuda", dtype=torch.bfloat16)
_test_mx(tensor_hp, elem_dtype, B)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
Expand Down Expand Up @@ -224,14 +239,32 @@ def test_cast_autograd(elem_dtype):
torch.testing.assert_close(grad, x.grad, atol=0, rtol=0)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
def test_view(elem_dtype):
x = torch.randn(1, 2, 4)
block_size = 2
x = torch.randn(1, 2, 4, device="cuda")
block_size = 4
x_mx = MXTensor.to_mx(x, elem_dtype, block_size)
x_mx_2 = x_mx.view(2, 4) # noqa: F841


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("elem_dtype", [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2])
@pytest.mark.parametrize("do_fp6_packing", [False, True])
def test_fp6_packing(elem_dtype, do_fp6_packing):
config.pack_fp6 = do_fp6_packing
x = torch.randn(1, 2, 4, device="cuda")
block_size = 4
x_mx = MXTensor.to_mx(x, elem_dtype, block_size)
if config.pack_fp6:
expected_packed_shape = torch.Size([*x.shape[:-1], 3 * x.shape[-1] // 4])
else:
expected_packed_shape = x.shape
config.pack_fp6 = True

assert x_mx._data.shape == expected_packed_shape


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
@@ -253,7 +286,7 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
x = torch.randn(*shape, dtype=hp_dtype, device="cuda")
else:
x = torch.zeros(*shape, dtype=hp_dtype, device="cuda")
block_size = 2
block_size = 4
to_mx_c = torch.compile(MXTensor.to_mx, fullgraph=True)

x_mx = MXTensor.to_mx(x, elem_dtype, block_size)
1 change: 1 addition & 0 deletions torchao/prototype/mx_formats/config.py
@@ -15,6 +15,7 @@
SUPPORTED_ELEM_DTYPES,
)

pack_fp6 = True

class MXGemmKernelChoice(Enum):
# always available - MX operands are dequantized and a high precision
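The new module-level pack_fp6 flag defaults to True. Below is a minimal usage sketch mirroring test_fp6_packing above (the expected shape follows that test's 3 * last_dim // 4 rule); it assumes a CUDA device and is illustrative rather than documentation of a stable API.

```python
import torch

from torchao.prototype.mx_formats import config
from torchao.prototype.mx_formats.constants import DTYPE_FP6_E2M3
from torchao.prototype.mx_formats.mx_tensor import MXTensor

# Toggle FP6 packing before quantizing: True packs four FP6 codes into
# three bytes, False keeps one byte per element.
config.pack_fp6 = True

x = torch.randn(1, 2, 4, device="cuda")
x_mx = MXTensor.to_mx(x, DTYPE_FP6_E2M3, 4)  # block_size = 4
print(x_mx._data.shape)  # torch.Size([1, 2, 3]) when packed, [1, 2, 4] otherwise
```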