Revert "Clean up FP6-LLM" #338

Merged · 1 commit · Jun 9, 2024
3 changes: 3 additions & 0 deletions dev-requirements.txt
@@ -14,3 +14,6 @@ tabulate # QOL for printing tables to stdout

 # Custom CUDA Extensions
 ninja
+
+# for FP6-LLM (can be removed once we remove fp16_to_fp6_original())
+qtorch
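Note that qtorch is only a dev/test dependency here: it provides a simulated low-precision float that the FP6 tests can compare against. As a rough illustration (not part of this diff), an FP6 E3M2 reference could be produced with qtorch roughly as follows, assuming its `float_quantize(x, exp, man, rounding)` API; exact argument names may vary between qtorch versions.

```python
# Illustrative only, not part of this PR. Simulates FP6 E3M2 (3 exponent bits,
# 2 mantissa bits) with round-to-nearest as a reference for torchao's FP6 path.
import torch
from qtorch.quant import float_quantize

x = torch.randn(8, 8)
x_fp6_ref = float_quantize(x.float(), exp=3, man=2, rounding="nearest")
print(x_fp6_ref)
```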
2 changes: 2 additions & 0 deletions docs/source/api_ref_dtypes.rst
@@ -12,6 +12,8 @@ torchao.dtypes

     to_nf4
     UInt4Tensor
+    to_float6_e3m2
+    from_float6_e3m2

 ..
   _NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring
3 changes: 2 additions & 1 deletion setup.py
@@ -49,11 +49,12 @@ def get_extensions():
     use_cuda = torch.cuda.is_available() and CUDA_HOME is not None
     extension = CUDAExtension if use_cuda else CppExtension

-    extra_link_args = []
+    extra_link_args = ["-fopenmp"]
     extra_compile_args = {
         "cxx": [
             "-O3" if not debug_mode else "-O0",
             "-fdiagnostics-color=always",
+            "-fopenmp",
         ],
         "nvcc": [
             "-O3" if not debug_mode else "-O0",
134 changes: 134 additions & 0 deletions test/dtypes/test_float6_e3m2.py
@@ -0,0 +1,134 @@
import torch
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

try:
    import torchao.ops
except RuntimeError:
    pytest.skip("torchao.ops not available")


from torchao.dtypes.float6_e3m2 import to_float6_e3m2, from_float6_e3m2


_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])


class TestFloat6E3M2(TestCase):

    @parametrize("device", _DEVICES)
    @parametrize("dtype", _DTYPES)
    @parametrize(
        "input_output",
        [
            (0.0, 0b000000),  # exact values
            (1.0, 0b001100),  # normal numbers
            (1.25, 0b001101),
            (28.0, 0b011111),  # max
            (0.1875, 0b000011),  # subnormal number
            (0.0625, 0b000001),  # min
            (29.0, 0b011111),  # normal round down
            (26.0, 0b011110),  # normal round to nearest even
            (0.1251, 0b000010),  # subnormal round down
            (0.0314, 0b000001),  # subnormal round up
            (0.03, 0b000000),  # underflow
        ],
    )
    def test_to_float6_e3m2_no_bit_packing_correctness(self, device, dtype, input_output):
        input, output = input_output
        input = torch.tensor(input, device=device, dtype=dtype)
        assert to_float6_e3m2(input, no_bit_packing=True).item() == output

    @parametrize("device", _DEVICES)
    @parametrize("dtype", _DTYPES)
    def test_to_float6_e3m2_bit_packing_correctness(self, device, dtype):
        x = torch.randn(128, 128, device=device, dtype=dtype)
        results_unpacked = to_float6_e3m2(x, no_bit_packing=True)
        results_packed = to_float6_e3m2(x)

        val0, val1, val2, val3 = results_unpacked.unflatten(-1, (-1, 4)).unbind(-1)
        bits0 = (val0 << 2) | (val1 >> 4)  # 0000 0011
        bits1 = (val1 << 4) | (val2 >> 2)  # 1111 2222
        bits2 = (val2 << 6) | (val3)       # 2233 3333

        expected_packed = torch.stack([bits0, bits1, bits2], dim=-1).flatten(-2)
        assert (results_packed == expected_packed).all()

    @parametrize("device", _DEVICES)
    @parametrize("shape", [(), (0,), (10,), (20, 20)])
    def test_to_float6_e3m2_no_bit_packing_shape(self, device, shape):
        x = torch.randn(shape, device=device)
        result = to_float6_e3m2(x, no_bit_packing=True)
        assert result.shape == shape

    @parametrize("device", _DEVICES)
    @parametrize("shape", [(4,), (20, 20)])
    def test_to_float6_e3m2_bit_packing_shape(self, device, shape):
        x = torch.randn(shape, device=device)
        result = to_float6_e3m2(x)
        assert result.shape == shape[:-1] + (shape[-1] // 4 * 3,)

    @parametrize("device", _DEVICES)
    @parametrize("dtype", _DTYPES)
    @parametrize("no_bit_packing", [False, True])
    def test_to_float6_e3m2_compile(self, device, dtype, no_bit_packing):
        x = torch.randn(20, 20, device=device, dtype=dtype)
        expected = to_float6_e3m2(x, no_bit_packing=no_bit_packing)

        to_float6_e3m2_compiled = torch.compile(to_float6_e3m2)
        actual = to_float6_e3m2_compiled(x, no_bit_packing=no_bit_packing)
        torch.testing.assert_close(actual, expected)

    @parametrize("device", _DEVICES)
    @parametrize(
        "input_output",
        [
            (0b000000, 0.0),
            (0b001100, 1.0),
            (0b011111, 28.0),    # max
            (0b000001, 0.0625),  # min
            (0b001110, 1.5),
            (0b000011, 0.1875),  # subnormal
        ],
    )
    def test_from_float6_e3m2_no_bit_packing_correctness(self, device, input_output):
        input, output = input_output
        input = torch.tensor(input, device=device, dtype=torch.uint8)
        assert from_float6_e3m2(input, no_bit_packing=True).item() == output

    @parametrize("device", _DEVICES)
    def test_from_float6_e3m2_bit_packing_correctness(self, device):
        x = torch.randint(256, (128, 128 // 4 * 3), device=device, dtype=torch.uint8)
        actual = from_float6_e3m2(x)

        bits0, bits1, bits2 = x.unflatten(-1, (-1, 3)).unbind(-1)
        x_unpacked0 = bits0 >> 2
        x_unpacked1 = ((bits0 & 0x3) << 4) | (bits1 >> 4)
        x_unpacked2 = ((bits1 & 0xF) << 2) | (bits2 >> 6)
        x_unpacked3 = bits2 & 0x3F

        x_unpacked = torch.stack([x_unpacked0, x_unpacked1, x_unpacked2, x_unpacked3], dim=-1).flatten(-2)
        expected = from_float6_e3m2(x_unpacked, no_bit_packing=True)
        torch.testing.assert_close(actual, expected)

    @parametrize("device", _DEVICES)
    @parametrize("no_bit_packing", [False, True])
    def test_from_float6_e3m2_compile(self, device, no_bit_packing):
        x = torch.randint(256, size=(20, 15), device=device, dtype=torch.uint8)
        expected = from_float6_e3m2(x, no_bit_packing=no_bit_packing)

        from_float6_e3m2_compiled = torch.compile(from_float6_e3m2)
        actual = from_float6_e3m2_compiled(x, no_bit_packing=no_bit_packing)
        torch.testing.assert_close(actual, expected)


instantiate_parametrized_tests(TestFloat6E3M2)


if __name__ == "__main__":
    run_tests()
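The bit-packing tests above check that four 6-bit values fit into three bytes (the `0000 0011`, `1111 2222`, `2233 3333` comments describe which value contributes which bits of each byte). As a self-contained sketch of that layout, using hypothetical `pack_fp6`/`unpack_fp6` helpers rather than torchao's own functions, the arithmetic the tests expect looks like this:

```python
import torch

# Pack four 6-bit values (a, b, c, d) into three bytes:
#   byte0 = aaaaaabb   byte1 = bbbbcccc   byte2 = ccdddddd
def pack_fp6(vals: torch.Tensor) -> torch.Tensor:
    # vals: uint8 tensor, each element < 64, last dim a multiple of 4
    a, b, c, d = vals.unflatten(-1, (-1, 4)).unbind(-1)
    byte0 = (a << 2) | (b >> 4)
    byte1 = ((b & 0xF) << 4) | (c >> 2)
    byte2 = ((c & 0x3) << 6) | d
    return torch.stack([byte0, byte1, byte2], dim=-1).flatten(-2)

def unpack_fp6(packed: torch.Tensor) -> torch.Tensor:
    b0, b1, b2 = packed.unflatten(-1, (-1, 3)).unbind(-1)
    a = b0 >> 2
    b = ((b0 & 0x3) << 4) | (b1 >> 4)
    c = ((b1 & 0xF) << 2) | (b2 >> 6)
    d = b2 & 0x3F
    return torch.stack([a, b, c, d], dim=-1).flatten(-2)

vals = torch.randint(64, (8,), dtype=torch.uint8)
assert torch.equal(unpack_fp6(pack_fp6(vals)), vals)  # lossless round trip
```

This also makes the shape test above easy to read: packing maps a last dimension of length `n` (a multiple of 4) to `n // 4 * 3` bytes.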
29 changes: 2 additions & 27 deletions test/prototype/mx_formats/test_custom_cast.py
@@ -46,6 +46,8 @@
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
 from torchao.utils import TORCH_VERSION_AFTER_2_4

+if not TORCH_VERSION_AFTER_2_4:
+    pytest.skip("Unsupported PyTorch version", allow_module_level=True)

 torch.manual_seed(0)

@@ -320,7 +322,6 @@ def test_fp4_pack_unpack():

 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
-@pytest.mark.skipif(not TORCH_VERSION_AFTER_2_4, reason="requires PyTorch >= 2.4")
 def test_fp4_triton_unscaled_cast():
     packed_vals = torch.arange(0, 255, dtype=torch.uint8, device="cuda")
     f32_ref = f4_unpacked_to_f32(unpack_uint4(packed_vals))
@@ -330,7 +331,6 @@ def test_fp4_triton_unscaled_cast():

 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
-@pytest.mark.skipif(not TORCH_VERSION_AFTER_2_4, reason="requires PyTorch >= 2.4")
 def test_fp4_triton_scaled_cast():
     size = (256,)
     orig_vals = torch.randn(size, dtype=torch.float, device="cuda") * 100
@@ -386,28 +386,3 @@ def test_fp6_values(dtype_name):
     else:
         raise AssertionError("unsupported")
     torch.testing.assert_close(f32, f32_ref, rtol=0, atol=0)
-
-
-@pytest.mark.parametrize(
-    "device",
-    [
-        "cpu",
-        pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")),
-    ]
-)
-@pytest.mark.parametrize(
-    "f32_val,f6_e3m2_enc",
-    [
-        (29.0, 0b011111),  # normal round down
-        (26.0, 0b011110),  # normal round to nearest even
-        (0.1251, 0b000010),  # subnormal round down
-        (0.0314, 0b000001),  # subnormal round up
-        (0.03, 0b000000),  # underflow
-    ]
-)
-def test_fp6_e3m2_rounding(f32_val, f6_e3m2_enc, device):
-    f6_e3m2_unpacked = f32_to_f6_e3m2_unpacked(torch.tensor(f32_val, device=device))
-    assert f6_e3m2_unpacked.item() == f6_e3m2_enc
-
-    f6_e3m2_unpacked = f32_to_f6_e3m2_unpacked(torch.tensor(-f32_val, device=device))
-    assert f6_e3m2_unpacked.item() == (f6_e3m2_enc | 0b100000)
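The rounding cases removed here (and re-added in `test/dtypes/test_float6_e3m2.py` above) follow the FP6 E3M2 layout implied by the expected encodings: 1 sign bit, 3 exponent bits with a bias of 3, 2 mantissa bits, and exponent field 0 reserved for subnormals. A small illustrative decoder, not torchao code, reproduces the constants used by the tests:

```python
# Illustrative decoder for an unpacked FP6 E3M2 byte (low 6 bits used):
# sign (1 bit) | exponent (3 bits, bias 3) | mantissa (2 bits).
def decode_fp6_e3m2(bits: int) -> float:
    sign = -1.0 if (bits >> 5) & 0x1 else 1.0
    exp = (bits >> 2) & 0x7
    man = bits & 0x3
    if exp == 0:  # subnormal: 0.mm * 2**(1 - bias)
        return sign * (man / 4) * 2 ** (1 - 3)
    return sign * (1 + man / 4) * 2 ** (exp - 3)  # normal: 1.mm * 2**(exp - bias)

assert decode_fp6_e3m2(0b011111) == 28.0     # max normal, so 29.0 rounds down to it
assert decode_fp6_e3m2(0b011110) == 24.0     # 26.0 ties between 24 and 28, rounds to even
assert decode_fp6_e3m2(0b000001) == 0.0625   # smallest subnormal; 0.03 underflows to 0
assert decode_fp6_e3m2(0b100001) == -0.0625  # 0b100000 is the sign bit checked above
```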
27 changes: 10 additions & 17 deletions test/quantization/test_fp6_llm.py
@@ -7,14 +7,9 @@
     parametrize,
     run_tests,
 )
-from torchao.quantization.fp6_llm import (
-    to_tc_float6_e3m2,
-    from_tc_float6_e3m2,
-    _to_tc_float6_e3m2_ref,
-    Fp6LlmLinear,
-    convert_fp6_llm,
-)
-from torchao.prototype.mx_formats.custom_cast import f6_e3m2_unpacked_to_f32, f32_to_f6_e3m2_unpacked
+from torchao.dtypes.float6_e3m2 import to_float6_e3m2, from_float6_e3m2
+from torchao.quantization.fp6_llm import to_tc_float6_e3m2, from_tc_float6_e3m2, Fp6LlmLinear, convert_fp6_llm
+from torchao.ops import prepack_fp6_weight


 _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
@@ -25,9 +20,9 @@ class TestFp6LlmLinear(TestCase):
     def test_to_tc_float6_e3m2_correctness(self, device):
         x = torch.randn(256, 64, device=device)

-        expected = _to_tc_float6_e3m2_ref(x)
+        expected = prepack_fp6_weight(to_float6_e3m2(x.cpu()).view(torch.int32)).view(torch.uint8)
         actual = to_tc_float6_e3m2(x)
-        torch.testing.assert_close(actual, expected)
+        torch.testing.assert_close(actual.view(-1).cpu(), expected.view(-1))

     @parametrize("device", _DEVICES)
     def test_to_tc_float6_e3m2_compile(self, device):
@@ -40,20 +35,18 @@ def test_to_tc_float6_e3m2_compile(self, device):
     @parametrize("device", _DEVICES)
     def test_from_tc_float6_e3m2_correctness(self, device):
         x = torch.randn(256, 64, device=device)
+        x = from_float6_e3m2(to_float6_e3m2(x))  # quantize and dequantize so that the values are exactly representable in FP6

-        # quantize and dequantize so that the values are exactly representable in FP6
-        x = f6_e3m2_unpacked_to_f32(f32_to_f6_e3m2_unpacked(x))
-
-        actual = from_tc_float6_e3m2(to_tc_float6_e3m2(x))
+        actual = from_tc_float6_e3m2(to_tc_float6_e3m2(x), *x.shape)
         torch.testing.assert_close(actual, x)

     @parametrize("device", _DEVICES)
     def test_from_tc_float6_e3m2_compile(self, device):
         M, N = 256, 64
-        x = torch.randint(256, size=(M, N * 3 // 4), dtype=torch.uint8, device=device)
+        x = torch.randint(256, size=(M * N * 3 // 4,), dtype=torch.uint8, device=device)

-        expected = from_tc_float6_e3m2(x)
-        actual = torch.compile(from_tc_float6_e3m2)(x)
+        expected = from_tc_float6_e3m2(x, M, N)
+        actual = torch.compile(from_tc_float6_e3m2)(x, M, N)
         torch.testing.assert_close(actual, expected)

     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
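The visible effect of the revert in this file is the `from_tc_float6_e3m2` signature: the packed tensor-core layout is again a flat uint8 buffer that does not carry the matrix shape, so `(M, N)` must be passed back explicitly. A usage sketch distilled from `test_from_tc_float6_e3m2_correctness`, assuming the post-revert APIs shown in this diff:

```python
import torch
from torchao.dtypes.float6_e3m2 import to_float6_e3m2, from_float6_e3m2
from torchao.quantization.fp6_llm import to_tc_float6_e3m2, from_tc_float6_e3m2

M, N = 256, 64
x = torch.randn(M, N)
x = from_float6_e3m2(to_float6_e3m2(x))  # snap values onto the FP6 E3M2 grid first

# Pack into the tensor-core ("tc") layout and back; the packed buffer holds
# M * N * 3 // 4 bytes, so the matrix shape is supplied when unpacking.
packed = to_tc_float6_e3m2(x)
x_roundtrip = from_tc_float6_e3m2(packed, M, N)
torch.testing.assert_close(x_roundtrip, x)
```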
76 changes: 66 additions & 10 deletions test/test_ops.py
@@ -2,7 +2,7 @@
 from torch.testing._internal.common_utils import TestCase, IS_FBCODE
 from torch.testing._internal.optests import opcheck
 import torchao
-from torchao.quantization.fp6_llm import from_tc_float6_e3m2
+from torchao.utils import TORCH_VERSION_AFTER_2_4
 import unittest
 from parameterized import parameterized
 import pytest
@@ -18,38 +18,94 @@
 @pytest.mark.filterwarnings("ignore:create_unbacked_symint is deprecated, please use new_dynamic_size instead:UserWarning")
 @unittest.skipIf(IS_FBCODE, "Skipping the test in fbcode since we don't have TARGET file for kernels")
 class TestOps(TestCase):
-    def _create_fp6_inputs(self, BS: int, OC: int, IC: int, device):
+    def _create_tensors_with_iou(self, N, iou_thresh):
+        # force last box to have a pre-defined iou with the first box
+        # let b0 be [x0, y0, x1, y1], and b1 be [x0, y0, x1 + d, y1],
+        # then, in order to satisfy ops.iou(b0, b1) == iou_thresh,
+        # we need to have d = (x1 - x0) * (1 - iou_thresh) / iou_thresh
+        # Adjust the threshold upward a bit with the intent of creating
+        # at least one box that exceeds (barely) the threshold and so
+        # should be suppressed.
+        boxes = torch.rand(N, 4) * 100
+        boxes[:, 2:] += boxes[:, :2]
+        boxes[-1, :] = boxes[0, :]
+        x0, y0, x1, y1 = boxes[-1].tolist()
+        iou_thresh += 1e-5
+        boxes[-1, 2] += (x1 - x0) * (1 - iou_thresh) / iou_thresh
+        scores = torch.rand(N)
+        return boxes, scores
+
+    def _create_fp6_inputs(self, BS: int, OC: int, IC: int):
         # Randomly initialize each bytes. The highest value for randint() is set the the max value of uint32_t.
         fp6_weight = torch.randint(4294967295, (OC, IC // 16 * 3)).to(torch.int)
         fp16_scale = torch.rand(OC).half() + 0.5
         fp16_activation = torch.rand(BS, IC).half() + 0.5
-        return fp6_weight.to(device), fp16_scale.to(device), fp16_activation.to(device)
+        return fp6_weight, fp16_scale, fp16_activation

+    def test_prepack_fp6_weight(self):
+        OC = 256
+        IC = 256
+        fp6_weight, _, _ = self._create_fp6_inputs(0, OC, IC)
+
+        # smoke test
+        torchao.ops.prepack_fp6_weight(fp6_weight)
+
+        # comprehensive testing
+        test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
+        opcheck(torch.ops.torchao.prepack_fp6_weight, (fp6_weight,), test_utils=test_utils)
+
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    def test_fp16_to_fp6_original(self):
+        OC = 256
+        IC = 256
+        fp16_weight = torch.randn((OC, IC), dtype=torch.float16)
+
+        # the original FP16->FP6 kernel checks for overflow/underflow
+        fp16_weight.clip_(-28.0, 28.0)
+        fp16_weight[fp16_weight.abs() < 0.0625] = 0.0
+
+        # smoke test
+        torchao.ops.fp16_to_fp6_original(fp16_weight)
+
+        # comprehensive testing
+        test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
+        opcheck(torch.ops.torchao.fp16_to_fp6_original, (fp16_weight,), test_utils=test_utils)
+
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_fp16act_fp6weight_linear(self):
         BS = 2
         OC = 256
         IC = 256
         splitK = 1
-        fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC, "cuda")
+        fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC)

+        fp6_weight_packed = torchao.ops.prepack_fp6_weight(fp6_weight)
+        act_cuda = fp16_activation.cuda()
+        weight_cuda = fp6_weight_packed.cuda()
+        scale_cuda = fp16_scale.cuda()
+
         # smoke test
-        torchao.ops.fp16act_fp6weight_linear(fp16_activation, fp6_weight, fp16_scale, splitK)
+        torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, splitK)

         # comprehensive testing
         test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
-        opcheck(torch.ops.torchao.fp16act_fp6weight_linear, (fp16_activation, fp6_weight, fp16_scale, splitK), test_utils=test_utils)
+        opcheck(torch.ops.torchao.fp16act_fp6weight_linear, (act_cuda, weight_cuda, scale_cuda, splitK), test_utils=test_utils)

     # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/main/tests/python/kernel_test.py
     @parameterized.expand([(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_fp6_matmul_correctness(self, BS, OC, IC, splitK):
-        fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC, "cuda")
+        fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC)

+        fp6_weight_packed = torchao.ops.prepack_fp6_weight(fp6_weight)
+        act_cuda = fp16_activation.cuda()
+        weight_cuda = fp6_weight_packed.cuda()
+        scale_cuda = fp16_scale.cuda()
+
-        results_fp6 = torchao.ops.fp16act_fp6weight_linear(fp16_activation, fp6_weight, fp16_scale, splitK)
+        results_fp6 = torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, splitK)

-        fp16_weight = from_tc_float6_e3m2(fp6_weight.view(torch.uint8), dtype=torch.float16) * fp16_scale[:, None]
-        results_fp16 = fp16_activation @ fp16_weight.T
+        fp16_weight = torchao.dtypes.from_float6_e3m2(fp6_weight.view(torch.uint8), dtype=torch.float16) * fp16_scale[:, None]
+        results_fp16 = act_cuda @ fp16_weight.cuda().T

         error = (results_fp6 - results_fp16).abs()
         relative_error = error / results_fp16.abs()
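For reference, the correctness test above amounts to the following end-to-end flow: pack the FP6 weight into the layout the CUDA kernel expects, run the FP16-activation x FP6-weight linear op, then compare against a plain FP16 matmul on the dequantized weight. This is a sketch assuming the post-revert ops exercised in this diff (CUDA and the compiled torchao extension required):

```python
import torch
import torchao

BS, OC, IC, splitK = 1, 2048, 4096, 5

# Random packed FP6 payload plus per-channel scales, as in _create_fp6_inputs above.
fp6_weight = torch.randint(4294967295, (OC, IC // 16 * 3)).to(torch.int)
fp16_scale = torch.rand(OC).half() + 0.5
fp16_activation = torch.rand(BS, IC).half() + 0.5

# 1) Rearrange the packed FP6 weight for the tensor-core kernel, then move to GPU.
weight_cuda = torchao.ops.prepack_fp6_weight(fp6_weight).cuda()
act_cuda = fp16_activation.cuda()
scale_cuda = fp16_scale.cuda()

# 2) FP16 activation x FP6 weight linear kernel.
out_fp6 = torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, splitK)

# 3) FP16 reference: dequantize the FP6 weight and do a plain matmul.
fp16_weight = torchao.dtypes.from_float6_e3m2(fp6_weight.view(torch.uint8), dtype=torch.float16) * fp16_scale[:, None]
out_ref = act_cuda @ fp16_weight.cuda().T

print((out_fp6 - out_ref).abs().max())
```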