Commit 680cec9

build setup for unit tests, enable per-row/per-tensor granularity
1 parent 062f3cc commit 680cec9

2 files changed: 89 additions & 38 deletions

test/quantization/quantize_/workflows/int8/test_int8_tensor.py

Lines changed: 65 additions & 30 deletions
@@ -58,20 +58,28 @@ def forward(self, x):
 class TestInt8Tensor(TorchAOIntegrationTestCase):
     def setUp(self):
         super().setUp()
+
+        self.test_shape = (4, 3)
+        self.dtype = torch.bfloat16
+        self.batch_size = 32
+        self.int8_min = -128
+        self.int8_max = 127
+
         torch.manual_seed(42)
-        self.weight_fp = torch.randn(4, 3, dtype=torch.bfloat16)
-        self.input_fp = torch.randn(4, 3, dtype=torch.bfloat16)
-        self.bias = torch.randn(4, dtype=torch.bfloat16)
-        self.block_size = [4, 3]
+        self.weight_fp = torch.randn(*self.test_shape, dtype=self.dtype)
+        self.input_fp = torch.randn(*self.test_shape, dtype=self.dtype)
+        self.bias = torch.randn(self.test_shape[0], dtype=self.dtype)
+        self.block_size = list(self.test_shape)

     def test_creation_and_attributes(self):
         """Test tensor creation, dtypes, and ranges"""
         tensor = Int8Tensor.from_hp(self.weight_fp, self.block_size)

-        self.assertEqual(tensor.shape, (4, 3))
+        self.assertEqual(tensor.shape, self.test_shape)
         self.assertEqual(tensor.qdata.dtype, torch.int8)
         self.assertTrue(
-            torch.all(tensor.qdata >= -128) and torch.all(tensor.qdata <= 127)
+            torch.all(tensor.qdata >= self.int8_min)
+            and torch.all(tensor.qdata <= self.int8_max)
         )

     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
@@ -88,12 +96,13 @@ def test_creation_and_attributes(self):
             Int8WeightOnlyConfig(version=2),
         ],
     )
-    def test_int8_linear_variants(
+    def test_int8_linear_quantization_accuracy(
         self,
         dtype: torch.dtype,
         sizes: tuple,
         config,
     ):
+        """Test quantization preserves reasonable accuracy"""
         M, N, K = sizes
         input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")

@@ -108,14 +117,16 @@ def test_int8_linear_variants(
         output_quantized = m_q(input_tensor)

         error = compute_error(output_original, output_quantized)
-        assert error > 20, f"Quantization error is too high got a SQNR of {error}"
+        assert error > 20, (
+            f"Quantization quality is too low, SQNR: {error}dB (expected > {20}dB)"
+        )

     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
-    def test_static_dynamic_quantization(self, dtype):
-        """Test static and dynamic quantization"""
+    def test_quantization_shapes(self, dtype):
+        """Test static and dynamic quantization output shapes"""
         K, N = 128, 64
         weight = torch.randn(N, K, dtype=dtype, device="cuda")
-        input_tensor = torch.randn(32, K, dtype=dtype, device="cuda")
+        input_tensor = torch.randn(self.batch_size, K, dtype=dtype, device="cuda")

         # Dynamic quantization (runtime scale computation)
         dynamic_tensor = Int8Tensor.from_hp(weight, block_size=[N, K])
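Note: compute_error above reports an SQNR in dB. As a rough illustration only (sqnr_db below is a hypothetical helper, not the torchao implementation), such a metric compares signal power to quantization-noise power; 20 dB means the noise carries about 1% of the signal power.

import torch

def sqnr_db(reference: torch.Tensor, quantized: torch.Tensor) -> float:
    # ratio of signal power to quantization-noise power, in decibels
    signal = torch.linalg.vector_norm(reference.float()) ** 2
    noise = torch.linalg.vector_norm(reference.float() - quantized.float()) ** 2
    return (10 * torch.log10(signal / noise)).item()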
@@ -126,8 +137,8 @@ def test_static_dynamic_quantization(self, dtype):
             mapping_type=MappingType.SYMMETRIC,
             block_size=(input_tensor.shape[0], K),
             target_dtype=torch.int8,
-            quant_min=-128,
-            quant_max=127,
+            quant_min=self.int8_min,
+            quant_max=self.int8_max,
             scale_dtype=dtype,
             zero_point_dtype=torch.int8,
         )
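Note: for context on the quant_min/quant_max values used above, a minimal sketch of symmetric per-tensor int8 quantization (illustrative only; the test itself goes through torchao's quantization utilities). The MappingType.SYMMETRIC mapping keeps the zero point at 0 and clamps rounded values to [quant_min, quant_max].

import torch

def symmetric_quantize_int8(x: torch.Tensor, quant_min: int = -128, quant_max: int = 127):
    # symmetric: zero_point is 0, the max magnitude maps to quant_max
    scale = x.abs().amax() / quant_max
    q = torch.clamp(torch.round(x / scale), quant_min, quant_max).to(torch.int8)
    return q, scale

x = torch.randn(32, 128)
q, scale = symmetric_quantize_int8(x)
x_hat = q.float() * scale  # dequantized approximation of x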
@@ -145,25 +156,29 @@ def test_static_dynamic_quantization(self, dtype):
         dynamic_output = torch.nn.functional.linear(input_tensor, dynamic_tensor)
         static_output = torch.nn.functional.linear(input_tensor, static_tensor)

-        self.assertEqual(dynamic_output.shape, (32, N))
-        self.assertEqual(static_output.shape, (32, N))
+        expected_shape = (self.batch_size, N)
+        self.assertEqual(dynamic_output.shape, expected_shape)
+        self.assertEqual(static_output.shape, expected_shape)
         self.assertEqual(dynamic_output.dtype, dtype)
         self.assertEqual(static_output.dtype, dtype)

     @unittest.skip("granularity parameter not supported in current API")
     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     def test_slice_preserves_aliasing(self, granularity):
+        slice_size = 512
+        tensor_size = 1024
+
         config = Int8DynamicActivationInt8WeightConfig(
             granularity=granularity, version=2
         )
-        l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
+        l = torch.nn.Linear(tensor_size, tensor_size).to("cuda").to(torch.bfloat16)
         l.weight = torch.nn.Parameter(
-            torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
+            torch.zeros(tensor_size, tensor_size, dtype=torch.bfloat16, device="cuda")
         )
         quantize_(l, config)
         param = l.weight
         param_data = param.data
-        param_data = param_data.narrow(0, 0, 512)
+        param_data = param_data.narrow(0, 0, slice_size)
         # Making sure the aliasing is preserved in sliced quantized Tensor
         assert param.data.qdata.data_ptr() == param_data.qdata.data_ptr()
         assert param.data.scale.data_ptr() == param_data.scale.data_ptr()
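Note: the data_ptr assertions above rely on narrow() returning a view of the same storage rather than a copy; a plain-tensor sketch of the property being checked:

import torch

t = torch.zeros(1024, 1024)
view = t.narrow(0, 0, 512)              # narrow returns a view, not a copy
assert view.data_ptr() == t.data_ptr()  # same underlying storage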
@@ -179,20 +194,27 @@ def test_slice_preserves_aliasing(self, granularity):
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
     def test_slice(self, config, device, dtype):
         """Test tensor slicing"""
-        dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)
+        tensor_size = 256
+        slice_sizes = (64, 128)
+
+        dummy = torch.nn.Linear(
+            tensor_size, tensor_size, bias=False, dtype=dtype, device=device
+        )
         quantize_(dummy, config)

-        weight1 = dummy.weight.clone().narrow(0, 0, 64)
-        weight2 = dummy.weight.clone().narrow(1, 0, 128)
+        weight1 = dummy.weight.clone().narrow(0, 0, slice_sizes[0])
+        weight2 = dummy.weight.clone().narrow(1, 0, slice_sizes[1])

-        self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, 64))
-        self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, 128))
+        self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, slice_sizes[0]))
+        self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, slice_sizes[1]))

         # Int8DynamicActivationInt8WeightConfig uses per-row (PerRow)
         # Int8WeightOnlyConfig uses per-tensor (PerTensor)
         if isinstance(config, Int8DynamicActivationInt8WeightConfig):
             # PerRow: dim 0 slicing affects scale, dim 1 doesn't
-            self.assertEqual(weight1.scale, dummy.weight.scale.narrow(0, 0, 64))
+            self.assertEqual(
+                weight1.scale, dummy.weight.scale.narrow(0, 0, slice_sizes[0])
+            )
             self.assertEqual(weight2.scale, dummy.weight.scale)
         else:
             # PerTensor: scale unchanged by slicing
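Note: a shape-only sketch (illustrative, not the torchao kernels) of why a dim-0 slice must also slice a per-row scale while a per-tensor scale is untouched by any slice:

import torch

weight = torch.randn(256, 256)
per_row_scale = weight.abs().amax(dim=1, keepdim=True) / 127  # shape (256, 1), one scale per output row
per_tensor_scale = weight.abs().amax() / 127                  # 0-d scalar shared by the whole tensor

rows = weight.narrow(0, 0, 64)
rows_scale = per_row_scale.narrow(0, 0, 64)  # dim-0 slice: the per-row scale is sliced the same way
cols = weight.narrow(1, 0, 128)              # dim-1 slice: the per-row scale is unchanged
# per_tensor_scale applies as-is to rows, cols, or the full weight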
@@ -211,20 +233,33 @@ def test_index_select(self):
             x_int8.dequantize()[0], x_int8_0.dequantize(), atol=0, rtol=0
         )

-    def test_error_handling_and_dequant(self):
-        """Test input validation and dequantization accuracy"""
-        with self.assertRaises((AssertionError, ValueError, RuntimeError)):
-            Int8Tensor.from_hp(torch.randn(5), [1])
+    def test_invalid_input_handling(self):
+        """Test input validation with specific error types"""
+        invalid_tensor = torch.randn(5)
+        incompatible_block_size = [1]

-        with self.assertRaises((AssertionError, ValueError, RuntimeError)):
+        with self.assertRaises(
+            ValueError, msg="Should reject incompatible tensor dimensions"
+        ):
+            Int8Tensor.from_hp(invalid_tensor, incompatible_block_size)
+
+        with self.assertRaises(
+            ValueError, msg="Should reject mismatched block size dimensions"
+        ):
             Int8Tensor.from_hp(self.weight_fp, [1])

+    def test_dequantization_accuracy(self):
+        """Test dequantization accuracy separately"""
         test_data = torch.tensor([[1.0, -1.0]], dtype=torch.bfloat16)
         tensor = Int8Tensor.from_hp(test_data, [1, 2])

         dequantized = tensor.dequantize()
         self.assertEqual(dequantized.shape, test_data.shape)
-        self.assertLess(torch.abs(dequantized - test_data).max().item(), 0.1)
+        self.assertLess(
+            torch.abs(dequantized - test_data).max().item(),
+            0.1,
+            msg=f"Dequantization error exceeds tolerance of {0.1}",
+        )


 if __name__ == "__main__":
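Note: for intuition about the 0.1 tolerance above, assuming symmetric quantization with no clamping, the round-trip error per element is bounded by half a quantization step (scale / 2); a small illustrative check:

import torch

test_data = torch.tensor([[1.0, -1.0]])
scale = test_data.abs().amax() / 127
q = torch.clamp(torch.round(test_data / scale), -128, 127)
dequantized = q * scale
assert (dequantized - test_data).abs().max() <= scale / 2  # well under the 0.1 tolerance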

torchao/quantization/quantize_/workflows/int8/int8_tensor.py

Lines changed: 24 additions & 8 deletions
@@ -133,7 +133,7 @@ def from_hp(

         if tuple(block_size) == w.shape:
             # per-tensor
-            scale = scale.expand(w.shape)
+            pass
         elif len(scale.shape) == 1:
             # per-row, 1D -> 2D
             scale = scale.unsqueeze(-1)
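Note: with this change the per-tensor scale is no longer materialized to the weight's full shape; it stays a single element and is broadcast where it is used. A shape-only sketch of the two granularities (assumed from this diff, not the full from_hp logic):

import torch

w = torch.randn(4, 3)

# per-tensor: a 0-d scale broadcasts against w without expand()
per_tensor_scale = w.abs().amax() / 127

# per-row: a (4,) scale is unsqueezed to (4, 1) so it broadcasts across columns
per_row_scale = (w.abs().amax(dim=1) / 127).unsqueeze(-1)

w_q = torch.clamp(torch.round(w / per_row_scale), -128, 127).to(torch.int8)
w_hat = w_q * per_row_scale  # broadcasting handles both scale shapes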
@@ -190,22 +190,38 @@ def _(func, types, args, kwargs):
         w_vals_t = weight_tensor.qdata.contiguous().t()
         w_scales = weight_tensor.scale

-        tmp = x_vals.reshape(-1, x_vals.shape[-1])
-        x_scales_dtype = x_scales.dtype
+        tmp_shape = (-1, x_vals.shape[-1])
+        tmp = x_vals.view(tmp_shape)

         # Cast fp16 scale to float
         intermediate_dtype = (
-            torch.float if x_scales_dtype == torch.half else x_scales_dtype
+            torch.float if x_scales.dtype == torch.half else x_scales.dtype
         )
         # Note: CUDA doesn't support int32/int64 matmul, so we convert to float
         # Error message is NotImplementedError: "addmm_cuda" not implemented for 'Int'
         # This may introduce minor numerical differences compared to int arithmetic
         y_dot = torch.mm(tmp.to(intermediate_dtype), w_vals_t.to(intermediate_dtype))
-        y_dot_scaled = y_dot * x_scales.reshape(-1, 1).to(intermediate_dtype)

-        result = (y_dot_scaled * w_scales).reshape(
-            *x_vals.shape[:-1], y_dot_scaled.shape[-1]
-        )
+        # Apply activation scale
+        is_per_tensor_act = x_scales.numel() == 1
+        if is_per_tensor_act:
+            y_dot.mul_(x_scales.to(intermediate_dtype))
+        else:
+            # For block-wise activation scale, reshape to match y_dot
+            x_scales_reshaped = x_scales.view(y_dot.shape[0], -1)
+            y_dot.mul_(x_scales_reshaped.to(intermediate_dtype))
+
+        # Apply weight scale
+        is_per_tensor_weight = w_scales.numel() == 1
+        if is_per_tensor_weight:
+            result = y_dot.mul_(w_scales.to(intermediate_dtype))
+        else:
+            # Per-row weight scale - transpose and broadcast
+            w_scales_broadcast = w_scales.t().expand_as(y_dot)
+            result = y_dot.mul_(w_scales_broadcast.to(intermediate_dtype))
+
+        # Reshape back to original shape
+        result = result.view(*x_vals.shape[:-1], result.shape[-1])
         result = result.to(activation_tensor.dtype)
     else:
         # FP × INT8 (weight-only)
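Note: a shape-level sketch of the rescaling order the new branches implement for the per-row case (illustrative only, not the dispatch code above): scaling the raw int8 product by the activation scale along rows and by the weight scale along output columns is equivalent to dequantizing both operands before the matmul.

import torch

M, K, N = 32, 128, 64
x_q = torch.randint(-128, 128, (M, K), dtype=torch.int8)
w_q = torch.randint(-128, 128, (N, K), dtype=torch.int8)
x_scale = torch.rand(M, 1)  # per-row activation scale
w_scale = torch.rand(N, 1)  # per-row weight scale

# emulate the int8 matmul in float (CUDA addmm has no int path), then rescale
y = torch.mm(x_q.float(), w_q.float().t())  # (M, N)
y = y * x_scale                             # one scale per activation row
y = y * w_scale.t()                         # one scale per weight row (output column)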
