 from torch.testing._internal.common_utils import TestCase, IS_FBCODE
 from torch.testing._internal.optests import opcheck
 import torchao
-from torchao.utils import TORCH_VERSION_AFTER_2_4
+from torchao.quantization.fp6_llm import from_tc_float6_e3m2
 import unittest
 from parameterized import parameterized
 import pytest

 @pytest.mark.filterwarnings("ignore:create_unbacked_symint is deprecated, please use new_dynamic_size instead:UserWarning")
 @unittest.skipIf(IS_FBCODE, "Skipping the test in fbcode since we don't have TARGET file for kernels")
 class TestOps(TestCase):
-    def _create_tensors_with_iou(self, N, iou_thresh):
-        # force last box to have a pre-defined iou with the first box
-        # let b0 be [x0, y0, x1, y1], and b1 be [x0, y0, x1 + d, y1],
-        # then, in order to satisfy ops.iou(b0, b1) == iou_thresh,
-        # we need to have d = (x1 - x0) * (1 - iou_thresh) / iou_thresh
-        # Adjust the threshold upward a bit with the intent of creating
-        # at least one box that exceeds (barely) the threshold and so
-        # should be suppressed.
-        boxes = torch.rand(N, 4) * 100
-        boxes[:, 2:] += boxes[:, :2]
-        boxes[-1, :] = boxes[0, :]
-        x0, y0, x1, y1 = boxes[-1].tolist()
-        iou_thresh += 1e-5
-        boxes[-1, 2] += (x1 - x0) * (1 - iou_thresh) / iou_thresh
-        scores = torch.rand(N)
-        return boxes, scores
-
-    def _create_fp6_inputs(self, BS: int, OC: int, IC: int):
+    def _create_fp6_inputs(self, BS: int, OC: int, IC: int, device):
         # Randomly initialize each byte. The highest value for randint() is set to the max value of uint32_t.
         fp6_weight = torch.randint(4294967295, (OC, IC // 16 * 3)).to(torch.int)
         fp16_scale = torch.rand(OC).half() + 0.5
         fp16_activation = torch.rand(BS, IC).half() + 0.5
-        return fp6_weight, fp16_scale, fp16_activation
-
-    def test_prepack_fp6_weight(self):
-        OC = 256
-        IC = 256
-        fp6_weight, _, _ = self._create_fp6_inputs(0, OC, IC)
-
-        # smoke test
-        torchao.ops.prepack_fp6_weight(fp6_weight)
-
-        # comprehensive testing
-        test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
-        opcheck(torch.ops.torchao.prepack_fp6_weight, (fp6_weight,), test_utils=test_utils)
-
-    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
-    def test_fp16_to_fp6_original(self):
-        OC = 256
-        IC = 256
-        fp16_weight = torch.randn((OC, IC), dtype=torch.float16)
-
-        # the original FP16->FP6 kernel checks for overflow/underflow
-        fp16_weight.clip_(-28.0, 28.0)
-        fp16_weight[fp16_weight.abs() < 0.0625] = 0.0
-
-        # smoke test
-        torchao.ops.fp16_to_fp6_original(fp16_weight)
-
-        # comprehensive testing
-        test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
-        opcheck(torch.ops.torchao.fp16_to_fp6_original, (fp16_weight,), test_utils=test_utils)
+        return fp6_weight.to(device), fp16_scale.to(device), fp16_activation.to(device)

     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_fp16act_fp6weight_linear(self):
         BS = 2
         OC = 256
         IC = 256
         splitK = 1
-        fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC)
-
-        fp6_weight_packed = torchao.ops.prepack_fp6_weight(fp6_weight)
-        act_cuda = fp16_activation.cuda()
-        weight_cuda = fp6_weight_packed.cuda()
-        scale_cuda = fp16_scale.cuda()
+        fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC, "cuda")

         # smoke test
-        torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, splitK)
+        torchao.ops.fp16act_fp6weight_linear(fp16_activation, fp6_weight, fp16_scale, splitK)

         # comprehensive testing
         test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
-        opcheck(torch.ops.torchao.fp16act_fp6weight_linear, (act_cuda, weight_cuda, scale_cuda, splitK), test_utils=test_utils)
+        opcheck(torch.ops.torchao.fp16act_fp6weight_linear, (fp16_activation, fp6_weight, fp16_scale, splitK), test_utils=test_utils)

     # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/main/tests/python/kernel_test.py
     @parameterized.expand([(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_fp6_matmul_correctness(self, BS, OC, IC, splitK):
-        fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC)
-
-        fp6_weight_packed = torchao.ops.prepack_fp6_weight(fp6_weight)
-        act_cuda = fp16_activation.cuda()
-        weight_cuda = fp6_weight_packed.cuda()
-        scale_cuda = fp16_scale.cuda()
+        fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC, "cuda")

-        results_fp6 = torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, splitK)
+        results_fp6 = torchao.ops.fp16act_fp6weight_linear(fp16_activation, fp6_weight, fp16_scale, splitK)

-        fp16_weight = torchao.dtypes.from_float6_e3m2(fp6_weight.view(torch.uint8), dtype=torch.float16) * fp16_scale[:, None]
-        results_fp16 = act_cuda @ fp16_weight.cuda().T
+        fp16_weight = from_tc_float6_e3m2(fp6_weight.view(torch.uint8), dtype=torch.float16) * fp16_scale[:, None]
+        results_fp16 = fp16_activation @ fp16_weight.T

         error = (results_fp6 - results_fp16).abs()
         relative_error = error / results_fp16.abs()
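
For reference, a minimal standalone sketch (not part of the diff) of the flow the updated tests exercise: the diff drops the separate prepack_fp6_weight / .cuda() steps, has _create_fp6_inputs place tensors on the requested device, and compares the fused kernel against an FP16 reference built with from_tc_float6_e3m2. This assumes a CUDA build of torchao with the fp6_llm extension compiled; all names are taken from the diff itself, and the printed mean relative error stands in for whatever tolerance the test asserts.

# Sketch only: mirrors test_fp6_matmul_correctness above; requires a CUDA build of torchao.
import torch
import torchao
from torchao.quantization.fp6_llm import from_tc_float6_e3m2

if torch.cuda.is_available():
    BS, OC, IC, splitK = 2, 256, 256, 1

    # 16 six-bit values take 16 * 6 = 96 bits = 3 int32 words, hence IC // 16 * 3
    # packed int32 words per output channel in the tensor-core FP6 layout.
    fp6_weight = torch.randint(4294967295, (OC, IC // 16 * 3)).to(torch.int).cuda()
    fp16_scale = (torch.rand(OC).half() + 0.5).cuda()
    fp16_activation = (torch.rand(BS, IC).half() + 0.5).cuda()

    # Fused FP16-activation x FP6-weight linear kernel under test.
    results_fp6 = torchao.ops.fp16act_fp6weight_linear(fp16_activation, fp6_weight, fp16_scale, splitK)

    # FP16 reference: dequantize the packed FP6 weight, scale it, then use a plain matmul.
    fp16_weight = from_tc_float6_e3m2(fp6_weight.view(torch.uint8), dtype=torch.float16) * fp16_scale[:, None]
    results_fp16 = fp16_activation @ fp16_weight.T

    relative_error = (results_fp6 - results_fp16).abs() / results_fp16.abs()
    print(relative_error.mean())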