Support MXFP6 packing and fused unpack-dequantise kernel (conflicts resolved) #1810

Open · wants to merge 3 commits into base: main
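The heart of this PR is that an MXFP6 element occupies 6 bits, so four element codes pack exactly into three uint8 containers (24 bits), and the fused triton kernel unpacks and dequantises those groups in a single pass. The sketch below only illustrates that packing arithmetic; the function names and the exact bit layout are assumptions of mine, not the layout used by torchao's pack_uint6 or the triton_f6_* kernels.

```python
import torch


def pack_4x6bit_reference(codes: torch.Tensor) -> torch.Tensor:
    # Reference sketch: pack groups of four 6-bit codes (one per uint8) into 3 bytes.
    # The bit layout here is assumed and may differ from torchao's pack_uint6.
    v = codes.reshape(-1, 4).to(torch.int32)
    b0 = (v[:, 0] << 2) | (v[:, 1] >> 4)          # all 6 bits of v0, top 2 bits of v1
    b1 = ((v[:, 1] & 0xF) << 4) | (v[:, 2] >> 2)  # low 4 bits of v1, top 4 bits of v2
    b2 = ((v[:, 2] & 0x3) << 6) | v[:, 3]         # low 2 bits of v2, all 6 bits of v3
    return torch.stack([b0, b1, b2], dim=-1).to(torch.uint8).reshape(*codes.shape[:-1], -1)


def unpack_3bytes_reference(packed: torch.Tensor) -> torch.Tensor:
    # Inverse of the sketch above; returns a flat tensor of 6-bit codes.
    p = packed.reshape(-1, 3).to(torch.int32)
    v0 = p[:, 0] >> 2
    v1 = ((p[:, 0] & 0x3) << 4) | (p[:, 1] >> 4)
    v2 = ((p[:, 1] & 0xF) << 2) | (p[:, 2] >> 6)
    v3 = p[:, 2] & 0x3F
    return torch.stack([v0, v1, v2, v3], dim=-1).to(torch.uint8).reshape(-1)


codes = torch.randint(0, 64, (8,), dtype=torch.uint8)
packed = pack_4x6bit_reference(codes)
assert packed.numel() == 3 * codes.numel() // 4           # same ratio the new tests assert
assert torch.equal(unpack_3bytes_reference(packed), codes)
```

In torchao the unpack step is fused with dequantisation (triton_f6_e2m3_to_bf16 / triton_f6_e3m2_to_bf16 below), so the intermediate unpacked uint8 tensor never materialises.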
52 changes: 46 additions & 6 deletions test/prototype/mx_formats/test_custom_cast.py
Expand Up @@ -8,7 +8,6 @@
import torch
from torch.utils._triton import has_triton

import torchao.prototype.mx_formats.config as config
from torchao.prototype.mx_formats.constants import (
DTYPE_FP4,
DTYPE_FP6_E2M3,
@@ -26,7 +25,10 @@
f32_to_f6_e3m2_unpacked,
get_bits,
pack_uint4,
pack_uint6,
triton_f4_to_bf16,
triton_f6_e2m3_to_bf16,
triton_f6_e3m2_to_bf16,
unpack_uint4,
)
from torchao.prototype.mx_formats.fp_format_spec import (
@@ -329,12 +331,12 @@ def test_fp4_triton_unscaled_cast():
def test_fp4_triton_scaled_cast():
size = (256,)
orig_vals = torch.randn(size, dtype=torch.float, device="cuda") * 100
mxtensor = MXTensor.to_mx(orig_vals, block_size=32, elem_dtype=DTYPE_FP4)
mxtensor_ref = MXTensor.to_mx(orig_vals, block_size=32, elem_dtype=DTYPE_FP4)
mxtensor_triton = MXTensor.to_mx(orig_vals, block_size=32, elem_dtype=DTYPE_FP4,
use_fp4_custom_triton_dequant_kernel=True)

f32_ref = mxtensor.to_dtype(torch.float)
config.use_fp4_custom_triton_dequant_kernel = True
f32_triton = mxtensor.to_dtype(torch.float)
config.use_fp4_custom_triton_dequant_kernel = False
f32_ref = mxtensor_ref.to_dtype(torch.float)
f32_triton = mxtensor_triton.to_dtype(torch.float)
assert torch.all(torch.eq(f32_ref, f32_triton))


@@ -411,3 +413,41 @@ def test_fp6_e3m2_rounding(f32_val, f6_e3m2_enc, device):

f6_e3m2_unpacked = f32_to_f6_e3m2_unpacked(torch.tensor(-f32_val, device=device))
assert f6_e3m2_unpacked.item() == (f6_e3m2_enc | 0b100000)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_4, reason="requires PyTorch >= 2.4")
@pytest.mark.skipif(
is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
)
def test_fp6_e2m3_pack_unpack():
orig_vals = torch.Tensor([[0.0, 0.5, 7.5, -0.0], [-0.875, 1.0, -6.0, 0.125]]).to(
"cuda"
)
orig_vals_f6_unpacked = f32_to_f6_e2m3_unpacked(orig_vals)
orig_vals_f6_packed = pack_uint6(orig_vals_f6_unpacked)
assert orig_vals_f6_packed.numel() == (3 * orig_vals.numel() // 4)
orig_vals_f6_packed_unpacked = triton_f6_e2m3_to_bf16(orig_vals_f6_packed).to(
torch.float32
)
assert torch.all(orig_vals_f6_packed_unpacked == orig_vals)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_4, reason="requires PyTorch >= 2.4")
@pytest.mark.skipif(
is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
)
def test_fp6_e3m2_pack_unpack():
orig_vals = torch.Tensor([[0.0, 5.0, 28.0, -0.0], [-0.25, 0.1875, 0.0625, 8.0]]).to(
"cuda"
)
orig_vals_f6_unpacked = f32_to_f6_e3m2_unpacked(orig_vals)
orig_vals_f6_packed = pack_uint6(orig_vals_f6_unpacked)
assert orig_vals_f6_packed.numel() == (3 * orig_vals.numel() // 4)
orig_vals_f6_packed_unpacked = triton_f6_e3m2_to_bf16(orig_vals_f6_packed).to(
torch.float32
)
assert torch.all(orig_vals_f6_packed_unpacked == orig_vals)
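Both round-trip tests above use inputs that are exactly representable in the target FP6 format, which is why exact equality holds after pack_uint6 followed by the fused unpack/dequantise kernel. As a sanity check, this sketch enumerates the non-negative E2M3 values directly from the format definition (2 exponent bits, 3 mantissa bits, exponent bias 1); the enumeration is my own reconstruction rather than PR code, and swapping in 3 exponent bits, 2 mantissa bits and bias 3 reproduces E3M2 values such as 0.0625, 0.1875 and 28.0 used above.

```python
import torch


def fp6_e2m3_values() -> torch.Tensor:
    # Non-negative FP6 E2M3 values: 2 exponent bits, 3 mantissa bits, bias 1 (assumed spec).
    vals = []
    for e in range(4):        # exponent field
        for m in range(8):    # mantissa field
            if e == 0:
                vals.append(m / 8)                       # subnormals: 0.0, 0.125, ..., 0.875
            else:
                vals.append((1 + m / 8) * 2 ** (e - 1))  # normals, up to 1.875 * 4 = 7.5
    return torch.tensor(sorted(vals))


print(fp6_e2m3_values())  # contains 0.0, 0.125, 0.5, 0.875, 1.0, 6.0 and 7.5 from the test above
```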
26 changes: 13 additions & 13 deletions test/prototype/mx_formats/test_mx_linear.py
@@ -56,14 +56,14 @@ def test_linear_eager(elem_dtype, bias, input_shape):
"""
# elem_dtype is a tuple of (input, weight, gradient) dtypes.
grad_shape = list(input_shape)
grad_shape[-1] = 6
grad_shape[-1] = 8

m = nn.Sequential(
nn.Linear(8, 6, bias=bias, device="cuda"),
nn.Linear(8, 8, bias=bias, device="cuda"),
)
m_mx = copy.deepcopy(m)
config = MXLinearConfig(
block_size=2,
block_size=4,
elem_dtype=elem_dtype[0],
elem_dtype_weight_override=elem_dtype[1],
elem_dtype_grad_output_override=elem_dtype[2],
@@ -141,14 +141,14 @@ def test_linear_eager_emulated_vs_real_gemm(elem_dtype, mkn):
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_activation_checkpointing():
input_shape = (2, 4)
grad_shape = (2, 6)
grad_shape = (2, 8)
elem_dtype = torch.float8_e4m3fn

m = nn.Sequential(
nn.Linear(4, 6, bias=True, device="cuda"),
nn.Linear(6, 6, bias=True, device="cuda"),
nn.Linear(4, 8, bias=True, device="cuda"),
nn.Linear(8, 8, bias=True, device="cuda"),
)
config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
config = MXLinearConfig(block_size=4, elem_dtype=elem_dtype)
swap_linear_with_mx_linear(m, config=config)

x = torch.randn(*input_shape, device="cuda").requires_grad_()
@@ -178,13 +178,13 @@ def test_linear_compile(elem_dtype, bias, use_autocast):
if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
if not is_sm_at_least_89():
pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
M, K, N = 4, 8, 6
M, K, N = 4, 8, 8
input_shape = (M, K)
grad_shape = (M, N)
m_mx = nn.Sequential(
nn.Linear(K, N, bias=bias, device="cuda"),
)
config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
config = MXLinearConfig(block_size=4, elem_dtype=elem_dtype)
swap_linear_with_mx_linear(m_mx, config=config)
m_mx_c = copy.deepcopy(m_mx)
m_mx_c = torch.compile(m_mx_c, fullgraph=True, backend="inductor")
@@ -229,10 +229,10 @@ def test_inference_linear(elem_dtype, bias, input_shape):
"""
Smoke test for inference linear module with mx weight
"""
m = nn.Sequential(nn.Linear(4, 6, bias=bias, dtype=torch.bfloat16))
m = nn.Sequential(nn.Linear(4, 8, bias=bias, dtype=torch.bfloat16))
m = m.cuda()
m_mx = copy.deepcopy(m)
config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
config = MXLinearConfig(block_size=4, elem_dtype=elem_dtype)
swap_linear_with_mx_inference_linear(m_mx, config=config)

x = torch.randn(*input_shape, device="cuda", dtype=torch.bfloat16)
@@ -257,10 +257,10 @@ def test_inference_compile_simple(elem_dtype):
if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
if not is_sm_at_least_89():
pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
m = nn.Sequential(nn.Linear(4, 6, bias=False, dtype=torch.bfloat16))
m = nn.Sequential(nn.Linear(4, 8, bias=False, dtype=torch.bfloat16))
m = m.cuda()
m_mx = copy.deepcopy(m)
config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
config = MXLinearConfig(block_size=4, elem_dtype=elem_dtype)
swap_linear_with_mx_inference_linear(m_mx, config=config)
m_mx = torch.compile(m_mx, fullgraph="true")

85 changes: 60 additions & 25 deletions test/prototype/mx_formats/test_mx_tensor.py
@@ -16,7 +16,7 @@
DTYPE_FP6_E3M2,
SUPPORTED_ELEM_DTYPES,
)
from torchao.prototype.mx_formats.custom_cast import pack_uint4
from torchao.prototype.mx_formats.custom_cast import pack_uint4, pack_uint6
from torchao.prototype.mx_formats.mx_tensor import (
E8M0_EXPONENT_NAN_VAL,
MXTensor,
@@ -75,7 +75,7 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
def test_hello_world(elem_dtype):
data = torch.randn(4, 4, device="cuda", dtype=torch.bfloat16)
block_size = 2
block_size = 4
_test_mx(data, elem_dtype, block_size)


@@ -92,7 +92,7 @@ def test_realistic_numerics(elem_dtype, scale_calculation_mode):
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
def test_all_zeros(elem_dtype):
data = torch.zeros(4, 4, device="cuda", dtype=torch.bfloat16)
block_size = 2
block_size = 4
_test_mx(data, elem_dtype, block_size)


@@ -102,7 +102,7 @@ def test_some_zeros(elem_dtype):
data = torch.randn(4, 4, device="cuda", dtype=torch.bfloat16)
data[0, :] = 0.0
data[:, 2] = 0.0
block_size = 2
block_size = 4
_test_mx(data, elem_dtype, block_size)


@@ -114,33 +114,46 @@ def test_exponent_nan_in(elem_dtype):
value is set to NaN
"""
tensor_hp = torch.tensor(
[float("nan"), 1, 2, 3, 4, 5], device="cuda", dtype=torch.bfloat16
[float("nan"), 1, 2, 3, 4, 5, 6, 7], device="cuda", dtype=torch.bfloat16
)
block_size = 2
block_size = 4
tensor_mx = MXTensor.to_mx(tensor_hp, elem_dtype, block_size)
assert torch.all(tensor_mx._scale_e8m0[0] == E8M0_EXPONENT_NAN_VAL)
assert not torch.any(tensor_mx._scale_e8m0[1:] == E8M0_EXPONENT_NAN_VAL)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
def test_exponent_nan_out(elem_dtype):
@pytest.mark.parametrize("pack_fp6", [False, True])
def test_exponent_nan_out(elem_dtype, pack_fp6):
"""
If block exponent value is NaN, the MX tensor block value is NaN
"""
scale_e8m0_bits = torch.tensor(
[E8M0_EXPONENT_NAN_VAL, 23, 42], dtype=torch.uint8, device="cuda"
[E8M0_EXPONENT_NAN_VAL, 23], dtype=torch.uint8, device="cuda"
)

block_size = 4

if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
data_bits = torch.tensor([0, 1, 2, 3, 4, 5], dtype=elem_dtype, device="cuda") # noqa: E501
data_bits = torch.tensor(
[0, 1, 2, 3, 4, 5, 6, 7], dtype=elem_dtype, device="cuda"
) # noqa: E501
elif elem_dtype in (DTYPE_FP6_E2M3, DTYPE_FP6_E3M2):
data_bits = torch.tensor([0, 1, 2, 3, 4, 5], dtype=torch.uint8, device="cuda") # noqa: E501
data_bits = torch.tensor(
[0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device="cuda"
) # noqa: E501
if pack_fp6:
data_bits = data_bits.reshape(-1, block_size)
data_bits = pack_uint6(data_bits)
elif elem_dtype == DTYPE_FP4:
data_bits = torch.tensor([0, 1, 2, 3, 4, 5], dtype=torch.uint8, device="cuda") # noqa: E501
data_bits = torch.tensor(
[0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device="cuda"
) # noqa: E501
data_bits = pack_uint4(data_bits)
else:
raise AssertionError("unsupported")
block_size = 2
block_size = 4
use_fp4_custom_triton_dequant_kernel = False
tensor_mx = MXTensor(
scale_e8m0_bits,
Expand All @@ -150,10 +163,11 @@ def test_exponent_nan_out(elem_dtype):
torch.float,
use_fp4_custom_triton_dequant_kernel,
MXGemmKernelChoice.EMULATED,
pack_fp6,
)
tensor_hp = tensor_mx.to_dtype(torch.float)
assert torch.all(torch.isnan(tensor_hp[0:1]))
assert not torch.any(torch.isnan(tensor_hp[2:]))
assert torch.all(torch.isnan(tensor_hp.flatten()[0:4]))
assert not torch.any(torch.isnan(tensor_hp.flatten()[4:]))


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -162,24 +176,26 @@ def test_ranks(elem_dtype):
"""
The reshaping logic works for various ranks
"""
B = 2
shapes = ((B * 4,), (B * 4, 2), (B * 4, 2, 2), (B * 4, 2, 2, 2))
B = 4
shapes = ((B * 4,), (B * 4, 4), (B * 4, 4, 4), (B * 4, 4, 4, 4))
for s in shapes:
tensor_hp = torch.randn(*s, device="cuda", dtype=torch.bfloat16)
_test_mx(tensor_hp, elem_dtype, B)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
def test_block_sizes(elem_dtype):
@pytest.mark.parametrize("B", [1, 4, 32])
def test_block_sizes(elem_dtype, B):
"""
Smoke test for various block sizes
"""
for B in (1, 2, 32):
if B == 1 and elem_dtype == DTYPE_FP4:
pytest.skip("unsupported configuration")
tensor_hp = torch.randn(B, device="cuda", dtype=torch.bfloat16)
_test_mx(tensor_hp, elem_dtype, B)
if B == 1 and elem_dtype == DTYPE_FP4:
pytest.skip("unsupported configuration")
elif B % 4 != 0 and elem_dtype in [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2]:
pytest.skip("unsupported configuration")
tensor_hp = torch.randn(B, device="cuda", dtype=torch.bfloat16)
_test_mx(tensor_hp, elem_dtype, B)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -224,14 +240,30 @@ def test_cast_autograd(elem_dtype):
torch.testing.assert_close(grad, x.grad, atol=0, rtol=0)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
def test_view(elem_dtype):
x = torch.randn(1, 2, 4)
block_size = 2
x = torch.randn(1, 2, 4, device="cuda")
block_size = 4
x_mx = MXTensor.to_mx(x, elem_dtype, block_size)
x_mx_2 = x_mx.view(2, 4) # noqa: F841


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("elem_dtype", [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2])
@pytest.mark.parametrize("pack_fp6", [False, True])
def test_fp6_packing(elem_dtype, pack_fp6):
x = torch.randn(1, 2, 4, device="cuda")
block_size = 4
x_mx = MXTensor.to_mx(x, elem_dtype, block_size, pack_fp6=pack_fp6)
if pack_fp6:
expected_packed_shape = torch.Size([*x.shape[:-1], 3 * x.shape[-1] // 4])
else:
expected_packed_shape = x.shape

assert x_mx._data.shape == expected_packed_shape


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
@@ -253,7 +285,7 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
x = torch.randn(*shape, dtype=hp_dtype, device="cuda")
else:
x = torch.zeros(*shape, dtype=hp_dtype, device="cuda")
block_size = 2
block_size = 4
to_mx_c = torch.compile(MXTensor.to_mx, fullgraph=True)

x_mx = MXTensor.to_mx(x, elem_dtype, block_size)
@@ -269,13 +301,15 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
to_dtype_c = torch.compile(to_dtype, fullgraph=True)

use_fp4_custom_triton_dequant_kernel = False
pack_fp6 = False
x_mx_dq = to_dtype(
x_mx._data,
x_mx._scale_e8m0,
x_mx._elem_dtype,
x_mx._block_size,
hp_dtype, # noqa: E501
use_fp4_custom_triton_dequant_kernel,
pack_fp6,
)
x_mx_c_dq = to_dtype_c(
x_mx_c._data,
@@ -284,6 +318,7 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
x_mx_c._block_size,
hp_dtype,
use_fp4_custom_triton_dequant_kernel,
pack_fp6,
)
torch.testing.assert_close(x_mx_dq, x_mx_c_dq, atol=0, rtol=0)

3 changes: 3 additions & 0 deletions torchao/prototype/mx_formats/config.py
@@ -57,6 +57,9 @@ class MXLinearConfig:
# If True, uses a custom triton kernel for fp4 dequantize
use_fp4_custom_triton_dequant_kernel: bool = False

# If True, packs 4 x FP6 values into 3 x uint8 containers and uses the custom
# triton kernels for fused unpack/dequantize
pack_fp6: bool = False

def __post_init__(self):
# validate elem_dtype and its overrides
assert (
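Putting the new flag together with the existing config surface, a minimal usage sketch follows. The MXLinearConfig and swap_linear_with_mx_linear call signatures match the tests above, but the import path for swap_linear_with_mx_linear and the assumption that the linear swap path honours pack_fp6 are mine, not guaranteed by this diff.

```python
import torch
import torch.nn as nn

from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.constants import DTYPE_FP6_E2M3
# Assumed import path; adjust to wherever swap_linear_with_mx_linear lives in your checkout.
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear

# Sketch only: swap an nn.Linear for its MX counterpart and opt in to FP6 packing.
m = nn.Sequential(nn.Linear(32, 32, bias=False, device="cuda"))
config = MXLinearConfig(block_size=32, elem_dtype=DTYPE_FP6_E2M3)
config.pack_fp6 = True  # flag added by this PR; set as an attribute to stay constructor-agnostic
swap_linear_with_mx_linear(m, config=config)

x = torch.randn(2, 32, device="cuda")
y = m(x)  # forward pass through the MX-swapped module with packed FP6 weights
```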