Skip to content

Commit

Permalink
Clean up FP6-LLM (pytorch#304)
Browse files Browse the repository at this point in the history
* override load from state dict

* fix prefix

* migrate to mx primitive

* remove unneeded code

* comment out test

* remove

* add rounding test for f6_e3m2

* update tests

* remove openmp flag

* update benchmark script

* test negative number

* remove qtorch dep

* fix type casting

* add view

* fix strange pytest behavior

* only skip tests requiring PyTorch 2.4

* remove weight loading magic
  • Loading branch information
gau-nernst authored Jun 9, 2024
1 parent 000a0fd commit cd8f647
Show file tree
Hide file tree
Showing 17 changed files with 87 additions and 1,190 deletions.
3 changes: 0 additions & 3 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,3 @@ tabulate # QOL for printing tables to stdout

# Custom CUDA Extensions
ninja

# for FP6-LLM (can be removed once we remove fp16_to_fp6_original())
qtorch
2 changes: 0 additions & 2 deletions docs/source/api_ref_dtypes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ torchao.dtypes

to_nf4
UInt4Tensor
to_float6_e3m2
from_float6_e3m2

..
_NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring
Expand Down
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,11 @@ def get_extensions():
use_cuda = torch.cuda.is_available() and CUDA_HOME is not None
extension = CUDAExtension if use_cuda else CppExtension

extra_link_args = ["-fopenmp"]
extra_link_args = []
extra_compile_args = {
"cxx": [
"-O3" if not debug_mode else "-O0",
"-fdiagnostics-color=always",
"-fopenmp",
],
"nvcc": [
"-O3" if not debug_mode else "-O0",
Expand Down
134 changes: 0 additions & 134 deletions test/dtypes/test_float6_e3m2.py

This file was deleted.

29 changes: 27 additions & 2 deletions test/prototype/mx_formats/test_custom_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.utils import TORCH_VERSION_AFTER_2_4

if not TORCH_VERSION_AFTER_2_4:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)

torch.manual_seed(0)

Expand Down Expand Up @@ -322,6 +320,7 @@ def test_fp4_pack_unpack():

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(not TORCH_VERSION_AFTER_2_4, reason="requires PyTorch >= 2.4")
def test_fp4_triton_unscaled_cast():
packed_vals = torch.arange(0, 255, dtype=torch.uint8, device="cuda")
f32_ref = f4_unpacked_to_f32(unpack_uint4(packed_vals))
Expand All @@ -331,6 +330,7 @@ def test_fp4_triton_unscaled_cast():

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(not TORCH_VERSION_AFTER_2_4, reason="requires PyTorch >= 2.4")
def test_fp4_triton_scaled_cast():
size = (256,)
orig_vals = torch.randn(size, dtype=torch.float, device="cuda") * 100
Expand Down Expand Up @@ -386,3 +386,28 @@ def test_fp6_values(dtype_name):
else:
raise AssertionError("unsupported")
torch.testing.assert_close(f32, f32_ref, rtol=0, atol=0)


@pytest.mark.parametrize(
"device",
[
"cpu",
pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")),
]
)
@pytest.mark.parametrize(
"f32_val,f6_e3m2_enc",
[
(29.0, 0b011111), # normal round down
(26.0, 0b011110), # normal round to nearest even
(0.1251, 0b000010), # subnormal round down
(0.0314, 0b000001), # subnormal round up
(0.03, 0b000000), # underflow
]
)
def test_fp6_e3m2_rounding(f32_val, f6_e3m2_enc, device):
f6_e3m2_unpacked = f32_to_f6_e3m2_unpacked(torch.tensor(f32_val, device=device))
assert f6_e3m2_unpacked.item() == f6_e3m2_enc

f6_e3m2_unpacked = f32_to_f6_e3m2_unpacked(torch.tensor(-f32_val, device=device))
assert f6_e3m2_unpacked.item() == (f6_e3m2_enc | 0b100000)
27 changes: 17 additions & 10 deletions test/quantization/test_fp6_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,14 @@
parametrize,
run_tests,
)
from torchao.dtypes.float6_e3m2 import to_float6_e3m2, from_float6_e3m2
from torchao.quantization.fp6_llm import to_tc_float6_e3m2, from_tc_float6_e3m2, Fp6LlmLinear, convert_fp6_llm
from torchao.ops import prepack_fp6_weight
from torchao.quantization.fp6_llm import (
to_tc_float6_e3m2,
from_tc_float6_e3m2,
_to_tc_float6_e3m2_ref,
Fp6LlmLinear,
convert_fp6_llm,
)
from torchao.prototype.mx_formats.custom_cast import f6_e3m2_unpacked_to_f32, f32_to_f6_e3m2_unpacked


_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
Expand All @@ -20,9 +25,9 @@ class TestFp6LlmLinear(TestCase):
def test_to_tc_float6_e3m2_correctness(self, device):
x = torch.randn(256, 64, device=device)

expected = prepack_fp6_weight(to_float6_e3m2(x.cpu()).view(torch.int32)).view(torch.uint8)
expected = _to_tc_float6_e3m2_ref(x)
actual = to_tc_float6_e3m2(x)
torch.testing.assert_close(actual.view(-1).cpu(), expected.view(-1))
torch.testing.assert_close(actual, expected)

@parametrize("device", _DEVICES)
def test_to_tc_float6_e3m2_compile(self, device):
Expand All @@ -35,18 +40,20 @@ def test_to_tc_float6_e3m2_compile(self, device):
@parametrize("device", _DEVICES)
def test_from_tc_float6_e3m2_correctness(self, device):
x = torch.randn(256, 64, device=device)
x = from_float6_e3m2(to_float6_e3m2(x)) # quantize and dequantize so that the values are exactly representable in FP6

actual = from_tc_float6_e3m2(to_tc_float6_e3m2(x), *x.shape)
# quantize and dequantize so that the values are exactly representable in FP6
x = f6_e3m2_unpacked_to_f32(f32_to_f6_e3m2_unpacked(x))

actual = from_tc_float6_e3m2(to_tc_float6_e3m2(x))
torch.testing.assert_close(actual, x)

@parametrize("device", _DEVICES)
def test_from_tc_float6_e3m2_compile(self, device):
M, N = 256, 64
x = torch.randint(256, size=(M * N * 3 // 4,), dtype=torch.uint8, device=device)
x = torch.randint(256, size=(M, N * 3 // 4), dtype=torch.uint8, device=device)

expected = from_tc_float6_e3m2(x, M, N)
actual = torch.compile(from_tc_float6_e3m2)(x, M, N)
expected = from_tc_float6_e3m2(x)
actual = torch.compile(from_tc_float6_e3m2)(x)
torch.testing.assert_close(actual, expected)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
Expand Down
76 changes: 10 additions & 66 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from torch.testing._internal.common_utils import TestCase, IS_FBCODE
from torch.testing._internal.optests import opcheck
import torchao
from torchao.utils import TORCH_VERSION_AFTER_2_4
from torchao.quantization.fp6_llm import from_tc_float6_e3m2
import unittest
from parameterized import parameterized
import pytest
Expand All @@ -18,94 +18,38 @@
@pytest.mark.filterwarnings("ignore:create_unbacked_symint is deprecated, please use new_dynamic_size instead:UserWarning")
@unittest.skipIf(IS_FBCODE, "Skipping the test in fbcode since we don't have TARGET file for kernels")
class TestOps(TestCase):
def _create_tensors_with_iou(self, N, iou_thresh):
# force last box to have a pre-defined iou with the first box
# let b0 be [x0, y0, x1, y1], and b1 be [x0, y0, x1 + d, y1],
# then, in order to satisfy ops.iou(b0, b1) == iou_thresh,
# we need to have d = (x1 - x0) * (1 - iou_thresh) / iou_thresh
# Adjust the threshold upward a bit with the intent of creating
# at least one box that exceeds (barely) the threshold and so
# should be suppressed.
boxes = torch.rand(N, 4) * 100
boxes[:, 2:] += boxes[:, :2]
boxes[-1, :] = boxes[0, :]
x0, y0, x1, y1 = boxes[-1].tolist()
iou_thresh += 1e-5
boxes[-1, 2] += (x1 - x0) * (1 - iou_thresh) / iou_thresh
scores = torch.rand(N)
return boxes, scores

def _create_fp6_inputs(self, BS: int, OC: int, IC: int):
def _create_fp6_inputs(self, BS: int, OC: int, IC: int, device):
# Randomly initialize each bytes. The highest value for randint() is set the the max value of uint32_t.
fp6_weight = torch.randint(4294967295, (OC, IC // 16 * 3)).to(torch.int)
fp16_scale = torch.rand(OC).half() + 0.5
fp16_activation = torch.rand(BS, IC).half() + 0.5
return fp6_weight, fp16_scale, fp16_activation

def test_prepack_fp6_weight(self):
OC = 256
IC = 256
fp6_weight, _, _ = self._create_fp6_inputs(0, OC, IC)

# smoke test
torchao.ops.prepack_fp6_weight(fp6_weight)

# comprehensive testing
test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
opcheck(torch.ops.torchao.prepack_fp6_weight, (fp6_weight,), test_utils=test_utils)

@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_fp16_to_fp6_original(self):
OC = 256
IC = 256
fp16_weight = torch.randn((OC, IC), dtype=torch.float16)

# the original FP16->FP6 kernel checks for overflow/underflow
fp16_weight.clip_(-28.0, 28.0)
fp16_weight[fp16_weight.abs() < 0.0625] = 0.0

# smoke test
torchao.ops.fp16_to_fp6_original(fp16_weight)

# comprehensive testing
test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
opcheck(torch.ops.torchao.fp16_to_fp6_original, (fp16_weight,), test_utils=test_utils)
return fp6_weight.to(device), fp16_scale.to(device), fp16_activation.to(device)

@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_fp16act_fp6weight_linear(self):
BS = 2
OC = 256
IC = 256
splitK = 1
fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC)

fp6_weight_packed = torchao.ops.prepack_fp6_weight(fp6_weight)
act_cuda = fp16_activation.cuda()
weight_cuda = fp6_weight_packed.cuda()
scale_cuda = fp16_scale.cuda()
fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC, "cuda")

# smoke test
torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, splitK)
torchao.ops.fp16act_fp6weight_linear(fp16_activation, fp6_weight, fp16_scale, splitK)

# comprehensive testing
test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
opcheck(torch.ops.torchao.fp16act_fp6weight_linear, (act_cuda, weight_cuda, scale_cuda, splitK), test_utils=test_utils)
opcheck(torch.ops.torchao.fp16act_fp6weight_linear, (fp16_activation, fp6_weight, fp16_scale, splitK), test_utils=test_utils)

# adapted from https://github.com/usyd-fsalab/fp6_llm/blob/main/tests/python/kernel_test.py
@parameterized.expand([(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_fp6_matmul_correctness(self, BS, OC, IC, splitK):
fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC)

fp6_weight_packed = torchao.ops.prepack_fp6_weight(fp6_weight)
act_cuda = fp16_activation.cuda()
weight_cuda = fp6_weight_packed.cuda()
scale_cuda = fp16_scale.cuda()
fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC, "cuda")

results_fp6 = torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, splitK)
results_fp6 = torchao.ops.fp16act_fp6weight_linear(fp16_activation, fp6_weight, fp16_scale, splitK)

fp16_weight = torchao.dtypes.from_float6_e3m2(fp6_weight.view(torch.uint8), dtype=torch.float16) * fp16_scale[:, None]
results_fp16 = act_cuda @ fp16_weight.cuda().T
fp16_weight = from_tc_float6_e3m2(fp6_weight.view(torch.uint8), dtype=torch.float16) * fp16_scale[:, None]
results_fp16 = fp16_activation @ fp16_weight.T

error = (results_fp6 - results_fp16).abs()
relative_error = error / results_fp16.abs()
Expand Down
File renamed without changes.
Loading

0 comments on commit cd8f647

Please sign in to comment.