Commit 0f51ee6

update linear variant, kernel detection test
- Configs are updated to global variants
1 parent caaba7a commit 0f51ee6
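Note: with the global config variants, the tests no longer build Int8Tensor.from_hp directly; they quantize a plain nn.Linear in place and read the quantized weight off the module. A minimal sketch of that pattern, assuming a CUDA device and a torchao build where these configs accept version=2:

import torch
from torchao.quantization import Int8DynamicActivationInt8WeightConfig, quantize_

# Quantize a plain linear layer in place with the global config variant.
linear = torch.nn.Linear(20, 32, bias=False, dtype=torch.bfloat16, device="cuda")
quantize_(linear, Int8DynamicActivationInt8WeightConfig(version=2))

# The quantized weight is read off the module rather than constructed by hand;
# its int8 payload lives in `qdata`, as the updated assertions in the diff check.
tensor = linear.weight
assert tensor.qdata.dtype == torch.int8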

File tree

2 files changed (+65, -90 lines)

test/quantization/quantize_/workflows/int8/test_int8_tensor.py

Lines changed: 54 additions & 83 deletions
@@ -16,9 +16,6 @@
     Int8WeightOnlyConfig,
     quantize_,
 )
-from torchao.quantization.quantize_.workflows.int8.int8_tensor import (
-    Int8Tensor,
-)
 from torchao.quantization.utils import compute_error
 from torchao.testing.utils import TorchAOIntegrationTestCase
@@ -56,7 +53,7 @@ class TestInt8Tensor(TorchAOIntegrationTestCase):
     def setUp(self):
         super().setUp()

-        self.test_shape = (4, 3)
+        self.test_shape = (32, 20)
         self.dtype = torch.bfloat16
         self.batch_size = 32
@@ -66,9 +63,26 @@ def setUp(self):
         self.bias = torch.randn(self.test_shape[0], dtype=self.dtype)
         self.block_size = list(self.test_shape)

-    def test_creation_and_attributes(self):
+    @common_utils.parametrize(
+        "config",
+        [
+            Int8DynamicActivationInt8WeightConfig(version=2),
+            Int8WeightOnlyConfig(version=2),
+        ],
+    )
+    def test_creation_and_attributes(self, config):
         """Test tensor creation, dtypes, and ranges"""
-        tensor = Int8Tensor.from_hp(self.weight_fp, self.block_size)
+        linear = torch.nn.Linear(
+            self.test_shape[1],
+            self.test_shape[0],
+            bias=False,
+            dtype=self.dtype,
+            device="cuda",
+        )
+        linear.weight.data = self.weight_fp.cuda()
+        quantize_(linear, config)
+
+        tensor = linear.weight

         self.assertEqual(tensor.shape, self.test_shape)
         self.assertEqual(tensor.qdata.dtype, torch.int8)
@@ -117,46 +131,6 @@ def test_int8_linear_variants(
             f"Quantization error is too high got a SQNR of {compute_error(output_fp, output_quantized)}"
         )

-    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
-    @common_utils.parametrize(
-        "config",
-        [
-            Int8DynamicActivationInt8WeightConfig(version=2),
-            Int8WeightOnlyConfig(version=2),
-        ],
-    )
-    @common_utils.parametrize(
-        "sizes",
-        [
-            ((128,), 256, 128),  # 2D
-            ((32, 128), 64, 256),  # 3D
-        ],
-    )
-    def test_int8_linear_quantization_accuracy(
-        self,
-        dtype: torch.dtype,
-        sizes: tuple,
-        config,
-    ):
-        """Test quantization preserves reasonable accuracy"""
-        M, N, K = sizes
-        input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
-
-        # Create a linear layer
-        m = ToyTwoLinearModel(K, N, K, dtype=dtype, device="cuda").eval()
-        m_q = copy.deepcopy(m)
-
-        # Quantize
-        quantize_(m_q, config)
-
-        output_fp = m(input_tensor)
-        output_quantized = m_q(input_tensor)
-
-        error = compute_error(output_fp, output_quantized)
-        assert error > 20, (
-            f"Quantization quality is too low, SQNR: {error}dB (expected > {20}dB)"
-        )
-
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
     @common_utils.parametrize(
         "config",
@@ -218,36 +192,42 @@ def test_slice(self, config, device, dtype):
         with self.assertRaises(NotImplementedError):
             _ = dummy.weight[::2]

-    def test_index_select(self):
-        """test that `x_0 = x[0]` works when `x` is a 2D `Int8Tensor`."""
+    @common_utils.parametrize(
+        "config",
+        [
+            Int8DynamicActivationInt8WeightConfig(version=2),
+            Int8WeightOnlyConfig(version=2),
+        ],
+    )
+    def test_index_select(self, config):
+        """test that `x_0 = x[0]` works when `x` is a 2D quantized tensor."""
         N, K = 256, 512
         x = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
-        x_int8 = Int8Tensor.from_hp(x, block_size=[N, K])
+        linear = torch.nn.Linear(K, N, bias=False, dtype=torch.bfloat16, device="cuda")
+        linear.weight.data = x
+        quantize_(linear, config)
+
+        x_int8 = linear.weight
         x_int8_0 = x_int8[0]
         torch.testing.assert_close(
             x_int8.dequantize()[0], x_int8_0.dequantize(), atol=0, rtol=0
         )

-    def test_invalid_input_handling(self):
-        """Test input validation with specific error types"""
-        invalid_tensor = torch.randn(5)
-        incompatible_block_size = [1]
-
-        with self.assertRaises(
-            ValueError, msg="Should reject incompatible tensor dimensions"
-        ):
-            Int8Tensor.from_hp(invalid_tensor, incompatible_block_size)
-
-        with self.assertRaises(
-            ValueError, msg="Should reject mismatched block size dimensions"
-        ):
-            Int8Tensor.from_hp(self.weight_fp, [1])
-
-    def test_dequantization_accuracy(self):
+    @common_utils.parametrize(
+        "config",
+        [
+            Int8DynamicActivationInt8WeightConfig(version=2),
+            Int8WeightOnlyConfig(version=2),
+        ],
+    )
+    def test_dequantization_accuracy(self, config):
         """Test dequantization accuracy separately"""
-        test_data = torch.tensor([[1.0, -1.0]], dtype=torch.bfloat16)
-        tensor = Int8Tensor.from_hp(test_data, [1, 2])
+        test_data = torch.tensor([[1.0, -1.0]], dtype=torch.bfloat16, device="cuda")
+        linear = torch.nn.Linear(2, 1, bias=False, dtype=torch.bfloat16, device="cuda")
+        linear.weight.data = test_data
+        quantize_(linear, config)

+        tensor = linear.weight
         dequantized = tensor.dequantize()
         self.assertEqual(dequantized.shape, test_data.shape)
         self.assertLess(
@@ -268,23 +248,14 @@ def test_available_gpu_kernels(self):
         x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)

         out, code = run_and_get_code(m, x)
-        kernels = {}
-
-        # Check for Triton kernels
-        if "torch.ops.triton" in code[0]:
-            kernels["triton"] = True
-            print("Triton kernels are available for int8 quantization")
-        else:
-            kernels["triton"] = False
-            print("Triton kernels are NOT available for int8 quantization")
+        has_triton = "triton" in code[0].lower()  # Triton
+        has_fbgemm = "fbgemm" in code[0].lower()  # FBGEMM
+        has_int_mm = "_int_mm" in code[0]  # Int8 MatMul

-        # Check for FBGEMM kernels
-        if "torch.ops.fbgemm" in code[0]:
-            kernels["fbgemm"] = True
-            print("FBGEMM kernels are available for int8 quantization")
-        else:
-            kernels["fbgemm"] = False
-            print("FBGEMM kernels are NOT available for int8 quantization")
+        self.assertTrue(
+            has_triton or has_fbgemm or has_int_mm,
+            f"No int8 quantization kernels found. has_triton={has_triton}, has_fbgemm={has_fbgemm}, has_int_mm={has_int_mm}",
+        )


 if __name__ == "__main__":
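The kernel-detection test above folds the per-backend if/else printing into three substring checks over the compiled code plus a single assertion. A standalone sketch of the same check, assuming a CUDA build of torchao; the model and input below are hypothetical stand-ins for whatever the test actually compiles:

import torch
from torch._inductor.utils import run_and_get_code
from torchao.quantization import Int8DynamicActivationInt8WeightConfig, quantize_

# Hypothetical model/input; the test's own module is defined elsewhere in the file.
model = torch.nn.Sequential(torch.nn.Linear(128, 64, dtype=torch.bfloat16, device="cuda"))
quantize_(model, Int8DynamicActivationInt8WeightConfig(version=2))
m = torch.compile(model)
x = torch.randn(32, 128, dtype=torch.bfloat16, device="cuda")

out, code = run_and_get_code(m, x)        # `code` holds the Inductor-generated source
has_triton = "triton" in code[0].lower()  # Triton kernels
has_fbgemm = "fbgemm" in code[0].lower()  # FBGEMM ops
has_int_mm = "_int_mm" in code[0]         # aten._int_mm int8 matmul
assert has_triton or has_fbgemm or has_int_mm, "no int8 kernels found"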

torchao/quantization/quantize_/workflows/int8/int8_tensor.py

Lines changed: 11 additions & 7 deletions
@@ -143,10 +143,11 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
             output_dtype = self.dtype

         qdata_fp = self.qdata.to(output_dtype)
-        # Reshape scale to broadcast if granularity is block-wise
-        scale_expanded = _maybe_expand_scale_to_tensor_shape(
-            self.scale, self.qdata.shape
-        )
+        scale = self.scale
+        while scale.ndim < qdata_fp.ndim:
+            scale = scale.unsqueeze(-1)
+
+        scale_expanded = _maybe_expand_scale_to_tensor_shape(scale, qdata_fp.shape)
         return qdata_fp * scale_expanded.to(output_dtype)
@@ -276,16 +277,19 @@ def _(func, types, args, kwargs):
     self, dim, index = args
     assert dim == 0, f"Only dim=0 supported, got {dim}"

-    selected_scale = self.scale if self.scale.ndim == 0 else self.scale[index]
+    selected_qdata = self.qdata[index]
+    selected_scale = _slice_scale_for_dimension(
+        self.scale, self.qdata.shape, dim, index, index + 1, step=1
+    ).squeeze(0)

     return return_and_correct_aliasing(
         func,
         args,
         kwargs,
         Int8Tensor(
-            self.qdata[index],
+            selected_qdata,
             selected_scale,
-            self.block_size[1:],
+            [selected_qdata.shape[-1]],
             self.act_quant_kwargs,
             self.dtype,
         ),
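In the int8_tensor.py hunks above, dequantize now pads the scale with trailing unit dimensions before expanding it, so a per-row scale broadcasts against higher-rank qdata, and the select/index override slices the scale via _slice_scale_for_dimension and carries a 1-D block_size. A small standalone illustration of the broadcasting step (plain torch; the variable names are illustrative, not torchao APIs):

import torch

qdata = torch.randint(-128, 128, (32, 20), dtype=torch.int8)  # quantized payload
scale = torch.rand(32)                                         # one scale per output row

# Pad trailing unit dims until the scale's rank matches the data's rank,
# so (32,) becomes (32, 1) and broadcasts across the K dimension.
while scale.ndim < qdata.ndim:
    scale = scale.unsqueeze(-1)

dequantized = qdata.to(torch.bfloat16) * scale.to(torch.bfloat16)
print(dequantized.shape)  # torch.Size([32, 20])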
