Commit 5f2701b

Clean up FP6-LLM (pytorch#304)
* override load from state dict
* fix prefix
* migrate to mx primitive
* remove unneeded code
* comment out test
* remove
* add rounding test for f6_e3m2
* update tests
* remove openmp flag
* update benchmark script
* test negative number
* remove qtorch dep
* fix type casting
* add view
* fix strange pytest behavior
* only skip tests requiring PyTorch 2.4
* remove weight loading magic
1 parent d36b3c4 commit 5f2701b

17 files changed: +87 -1190 lines changed


dev-requirements.txt (-3)

@@ -14,6 +14,3 @@ tabulate # QOL for printing tables to stdout
 
 # Custom CUDA Extensions
 ninja
-
-# for FP6-LLM (can be removed once we remove fp16_to_fp6_original())
-qtorch

docs/source/api_ref_dtypes.rst (-2)

@@ -12,8 +12,6 @@ torchao.dtypes
 
     to_nf4
     UInt4Tensor
-    to_float6_e3m2
-    from_float6_e3m2
 
 ..
   _NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring

setup.py (+1 -2)

@@ -49,12 +49,11 @@ def get_extensions():
     use_cuda = torch.cuda.is_available() and CUDA_HOME is not None
     extension = CUDAExtension if use_cuda else CppExtension
 
-    extra_link_args = ["-fopenmp"]
+    extra_link_args = []
     extra_compile_args = {
         "cxx": [
             "-O3" if not debug_mode else "-O0",
             "-fdiagnostics-color=always",
-            "-fopenmp",
         ],
         "nvcc": [
             "-O3" if not debug_mode else "-O0",

test/dtypes/test_float6_e3m2.py (-134)

This file was deleted.

test/prototype/mx_formats/test_custom_cast.py (+27 -2)

@@ -46,8 +46,6 @@
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
 from torchao.utils import TORCH_VERSION_AFTER_2_4
 
-if not TORCH_VERSION_AFTER_2_4:
-    pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
 torch.manual_seed(0)
 
@@ -322,6 +320,7 @@ def test_fp4_pack_unpack():
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
+@pytest.mark.skipif(not TORCH_VERSION_AFTER_2_4, reason="requires PyTorch >= 2.4")
 def test_fp4_triton_unscaled_cast():
     packed_vals = torch.arange(0, 255, dtype=torch.uint8, device="cuda")
     f32_ref = f4_unpacked_to_f32(unpack_uint4(packed_vals))
@@ -331,6 +330,7 @@ def test_fp4_triton_unscaled_cast():
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
+@pytest.mark.skipif(not TORCH_VERSION_AFTER_2_4, reason="requires PyTorch >= 2.4")
 def test_fp4_triton_scaled_cast():
     size = (256,)
     orig_vals = torch.randn(size, dtype=torch.float, device="cuda") * 100
@@ -386,3 +386,28 @@ def test_fp6_values(dtype_name):
     else:
         raise AssertionError("unsupported")
     torch.testing.assert_close(f32, f32_ref, rtol=0, atol=0)
+
+
+@pytest.mark.parametrize(
+    "device",
+    [
+        "cpu",
+        pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")),
+    ]
+)
+@pytest.mark.parametrize(
+    "f32_val,f6_e3m2_enc",
+    [
+        (29.0, 0b011111),  # normal round down
+        (26.0, 0b011110),  # normal round to nearest even
+        (0.1251, 0b000010),  # subnormal round down
+        (0.0314, 0b000001),  # subnormal round up
+        (0.03, 0b000000),  # underflow
+    ]
+)
+def test_fp6_e3m2_rounding(f32_val, f6_e3m2_enc, device):
+    f6_e3m2_unpacked = f32_to_f6_e3m2_unpacked(torch.tensor(f32_val, device=device))
+    assert f6_e3m2_unpacked.item() == f6_e3m2_enc
+
+    f6_e3m2_unpacked = f32_to_f6_e3m2_unpacked(torch.tensor(-f32_val, device=device))
+    assert f6_e3m2_unpacked.item() == (f6_e3m2_enc | 0b100000)
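
A note on the new rounding cases above: FP6 E3M2 has a sign bit, 3 exponent bits (bias 3) and 2 mantissa bits, so the largest normal value is 28 and subnormals step by 0.0625; 26.0 sits exactly halfway between the two nearest codes (24 and 28) and ties to the even mantissa, which is why the test expects 0b011110. Below is a minimal sketch that decodes an unpacked code by hand to check that case; it only assumes the f32_to_f6_e3m2_unpacked helper already exercised by the test, and the decoder itself is illustrative, not a torchao API:

import torch
from torchao.prototype.mx_formats.custom_cast import f32_to_f6_e3m2_unpacked

def decode_f6_e3m2(enc: int) -> float:
    # Hand-rolled decode of one unpacked FP6 E3M2 code: 1 sign, 3 exponent (bias 3), 2 mantissa bits.
    sign = -1.0 if (enc >> 5) & 0b1 else 1.0
    exp = (enc >> 2) & 0b111
    man = enc & 0b11
    if exp == 0:
        return sign * (man / 4) * 2.0 ** (1 - 3)    # subnormal: 0.mm * 2^(1 - bias)
    return sign * (1 + man / 4) * 2.0 ** (exp - 3)  # normal: 1.mm * 2^(exp - bias)

enc = f32_to_f6_e3m2_unpacked(torch.tensor(26.0)).item()
print(bin(enc), decode_f6_e3m2(enc))  # expect 0b11110 -> 24.0 (round half to even)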

test/quantization/test_fp6_llm.py (+17 -10)

@@ -7,9 +7,14 @@
     parametrize,
     run_tests,
 )
-from torchao.dtypes.float6_e3m2 import to_float6_e3m2, from_float6_e3m2
-from torchao.quantization.fp6_llm import to_tc_float6_e3m2, from_tc_float6_e3m2, Fp6LlmLinear, convert_fp6_llm
-from torchao.ops import prepack_fp6_weight
+from torchao.quantization.fp6_llm import (
+    to_tc_float6_e3m2,
+    from_tc_float6_e3m2,
+    _to_tc_float6_e3m2_ref,
+    Fp6LlmLinear,
+    convert_fp6_llm,
+)
+from torchao.prototype.mx_formats.custom_cast import f6_e3m2_unpacked_to_f32, f32_to_f6_e3m2_unpacked
 
 
 _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
@@ -20,9 +25,9 @@ class TestFp6LlmLinear(TestCase):
     def test_to_tc_float6_e3m2_correctness(self, device):
         x = torch.randn(256, 64, device=device)
 
-        expected = prepack_fp6_weight(to_float6_e3m2(x.cpu()).view(torch.int32)).view(torch.uint8)
+        expected = _to_tc_float6_e3m2_ref(x)
         actual = to_tc_float6_e3m2(x)
-        torch.testing.assert_close(actual.view(-1).cpu(), expected.view(-1))
+        torch.testing.assert_close(actual, expected)
 
     @parametrize("device", _DEVICES)
     def test_to_tc_float6_e3m2_compile(self, device):
@@ -35,18 +40,20 @@ def test_to_tc_float6_e3m2_compile(self, device):
     @parametrize("device", _DEVICES)
     def test_from_tc_float6_e3m2_correctness(self, device):
         x = torch.randn(256, 64, device=device)
-        x = from_float6_e3m2(to_float6_e3m2(x)) # quantize and dequantize so that the values are exactly representable in FP6
 
-        actual = from_tc_float6_e3m2(to_tc_float6_e3m2(x), *x.shape)
+        # quantize and dequantize so that the values are exactly representable in FP6
+        x = f6_e3m2_unpacked_to_f32(f32_to_f6_e3m2_unpacked(x))
+
+        actual = from_tc_float6_e3m2(to_tc_float6_e3m2(x))
         torch.testing.assert_close(actual, x)
 
     @parametrize("device", _DEVICES)
     def test_from_tc_float6_e3m2_compile(self, device):
         M, N = 256, 64
-        x = torch.randint(256, size=(M * N * 3 // 4,), dtype=torch.uint8, device=device)
+        x = torch.randint(256, size=(M, N * 3 // 4), dtype=torch.uint8, device=device)
 
-        expected = from_tc_float6_e3m2(x, M, N)
-        actual = torch.compile(from_tc_float6_e3m2)(x, M, N)
+        expected = from_tc_float6_e3m2(x)
+        actual = torch.compile(from_tc_float6_e3m2)(x)
         torch.testing.assert_close(actual, expected)
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
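
For readers following the API change in this hunk: from_tc_float6_e3m2 no longer takes the original M, N dimensions; the packed uint8 tensor of shape (M, N * 3 // 4) now carries the shape itself. A minimal round-trip sketch under that assumption, using only the names imported in the test above:

import torch
from torchao.quantization.fp6_llm import to_tc_float6_e3m2, from_tc_float6_e3m2
from torchao.prototype.mx_formats.custom_cast import f32_to_f6_e3m2_unpacked, f6_e3m2_unpacked_to_f32

x = torch.randn(256, 64)
# Snap x onto the FP6 E3M2 grid first so the pack/unpack round trip is exact.
x = f6_e3m2_unpacked_to_f32(f32_to_f6_e3m2_unpacked(x))

packed = to_tc_float6_e3m2(x)           # uint8, tensor-core friendly layout, 3 bytes per 4 values
restored = from_tc_float6_e3m2(packed)  # shape is recovered from the packed tensor, no M/N arguments
torch.testing.assert_close(restored, x)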

test/test_ops.py (+10 -66)

@@ -2,7 +2,7 @@
 from torch.testing._internal.common_utils import TestCase, IS_FBCODE
 from torch.testing._internal.optests import opcheck
 import torchao
-from torchao.utils import TORCH_VERSION_AFTER_2_4
+from torchao.quantization.fp6_llm import from_tc_float6_e3m2
 import unittest
 from parameterized import parameterized
 import pytest
@@ -18,94 +18,38 @@
 @pytest.mark.filterwarnings("ignore:create_unbacked_symint is deprecated, please use new_dynamic_size instead:UserWarning")
 @unittest.skipIf(IS_FBCODE, "Skipping the test in fbcode since we don't have TARGET file for kernels")
 class TestOps(TestCase):
-    def _create_tensors_with_iou(self, N, iou_thresh):
-        # force last box to have a pre-defined iou with the first box
-        # let b0 be [x0, y0, x1, y1], and b1 be [x0, y0, x1 + d, y1],
-        # then, in order to satisfy ops.iou(b0, b1) == iou_thresh,
-        # we need to have d = (x1 - x0) * (1 - iou_thresh) / iou_thresh
-        # Adjust the threshold upward a bit with the intent of creating
-        # at least one box that exceeds (barely) the threshold and so
-        # should be suppressed.
-        boxes = torch.rand(N, 4) * 100
-        boxes[:, 2:] += boxes[:, :2]
-        boxes[-1, :] = boxes[0, :]
-        x0, y0, x1, y1 = boxes[-1].tolist()
-        iou_thresh += 1e-5
-        boxes[-1, 2] += (x1 - x0) * (1 - iou_thresh) / iou_thresh
-        scores = torch.rand(N)
-        return boxes, scores
-
-    def _create_fp6_inputs(self, BS: int, OC: int, IC: int):
+    def _create_fp6_inputs(self, BS: int, OC: int, IC: int, device):
         # Randomly initialize each bytes. The highest value for randint() is set the the max value of uint32_t.
         fp6_weight = torch.randint(4294967295, (OC, IC // 16 * 3)).to(torch.int)
         fp16_scale = torch.rand(OC).half() + 0.5
         fp16_activation = torch.rand(BS, IC).half() + 0.5
-        return fp6_weight, fp16_scale, fp16_activation
-
-    def test_prepack_fp6_weight(self):
-        OC = 256
-        IC = 256
-        fp6_weight, _, _ = self._create_fp6_inputs(0, OC, IC)
-
-        # smoke test
-        torchao.ops.prepack_fp6_weight(fp6_weight)
-
-        # comprehensive testing
-        test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
-        opcheck(torch.ops.torchao.prepack_fp6_weight, (fp6_weight,), test_utils=test_utils)
-
-    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
-    def test_fp16_to_fp6_original(self):
-        OC = 256
-        IC = 256
-        fp16_weight = torch.randn((OC, IC), dtype=torch.float16)
-
-        # the original FP16->FP6 kernel checks for overflow/underflow
-        fp16_weight.clip_(-28.0, 28.0)
-        fp16_weight[fp16_weight.abs() < 0.0625] = 0.0
-
-        # smoke test
-        torchao.ops.fp16_to_fp6_original(fp16_weight)
-
-        # comprehensive testing
-        test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
-        opcheck(torch.ops.torchao.fp16_to_fp6_original, (fp16_weight,), test_utils=test_utils)
+        return fp6_weight.to(device), fp16_scale.to(device), fp16_activation.to(device)
 
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_fp16act_fp6weight_linear(self):
         BS = 2
         OC = 256
         IC = 256
         splitK = 1
-        fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC)
-
-        fp6_weight_packed = torchao.ops.prepack_fp6_weight(fp6_weight)
-        act_cuda = fp16_activation.cuda()
-        weight_cuda = fp6_weight_packed.cuda()
-        scale_cuda = fp16_scale.cuda()
+        fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC, "cuda")
 
         # smoke test
-        torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, splitK)
+        torchao.ops.fp16act_fp6weight_linear(fp16_activation, fp6_weight, fp16_scale, splitK)
 
         # comprehensive testing
         test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
-        opcheck(torch.ops.torchao.fp16act_fp6weight_linear, (act_cuda, weight_cuda, scale_cuda, splitK), test_utils=test_utils)
+        opcheck(torch.ops.torchao.fp16act_fp6weight_linear, (fp16_activation, fp6_weight, fp16_scale, splitK), test_utils=test_utils)
 
     # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/main/tests/python/kernel_test.py
     @parameterized.expand([(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_fp6_matmul_correctness(self, BS, OC, IC, splitK):
-        fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC)
-
-        fp6_weight_packed = torchao.ops.prepack_fp6_weight(fp6_weight)
-        act_cuda = fp16_activation.cuda()
-        weight_cuda = fp6_weight_packed.cuda()
-        scale_cuda = fp16_scale.cuda()
+        fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC, "cuda")
 
-        results_fp6 = torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, splitK)
+        results_fp6 = torchao.ops.fp16act_fp6weight_linear(fp16_activation, fp6_weight, fp16_scale, splitK)
 
-        fp16_weight = torchao.dtypes.from_float6_e3m2(fp6_weight.view(torch.uint8), dtype=torch.float16) * fp16_scale[:, None]
-        results_fp16 = act_cuda @ fp16_weight.cuda().T
+        fp16_weight = from_tc_float6_e3m2(fp6_weight.view(torch.uint8), dtype=torch.float16) * fp16_scale[:, None]
+        results_fp16 = fp16_activation @ fp16_weight.T
 
         error = (results_fp6 - results_fp16).abs()
         relative_error = error / results_fp16.abs()
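
Condensing the updated correctness test into a standalone sketch (same op names as in the diff, shapes from the smoke test above; the FP6 weight is random packed bytes here, so only the call pattern is meaningful, and a CUDA device is assumed):

import torch
import torchao
from torchao.quantization.fp6_llm import from_tc_float6_e3m2

BS, OC, IC, splitK = 2, 256, 256, 1
fp6_weight = torch.randint(4294967295, (OC, IC // 16 * 3)).to(torch.int).cuda()
fp16_scale = (torch.rand(OC).half() + 0.5).cuda()
fp16_activation = (torch.rand(BS, IC).half() + 0.5).cuda()

# Fused FP16-activation x FP6-weight linear kernel.
results_fp6 = torchao.ops.fp16act_fp6weight_linear(fp16_activation, fp6_weight, fp16_scale, splitK)

# Reference path: dequantize the packed FP6 weight to FP16 and do a plain matmul.
fp16_weight = from_tc_float6_e3m2(fp6_weight.view(torch.uint8), dtype=torch.float16) * fp16_scale[:, None]
results_fp16 = fp16_activation @ fp16_weight.T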
File renamed without changes.
