Commit 305c3a9

update default granularity, kernel test

1 parent 0f51ee6

File tree: 2 files changed (+18 −34 lines)

test/quantization/quantize_/workflows/int8/test_int8_tensor.py

Lines changed: 14 additions & 33 deletions

@@ -9,6 +9,7 @@

 import torch
 from torch._inductor.utils import run_and_get_code
+from torch.testing import FileCheck
 from torch.testing._internal import common_utils

 from torchao.quantization import (
@@ -145,12 +146,8 @@ def test_per_row_scale_shape(self, dtype, config):
         linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device="cuda")
         quantize_(linear, config)

-        # Dynamic: per-row (1D scale [N]), Weight-only: per-tensor (scalar)
-        if isinstance(config, Int8DynamicActivationInt8WeightConfig):
-            self.assertEqual(linear.weight.scale.shape, (N,))
-            self.assertEqual(linear.weight.scale.ndim, 1)
-        else:
-            self.assertEqual(linear.weight.scale.numel(), 1)
+        self.assertEqual(linear.weight.scale.shape, (N,))
+        self.assertEqual(linear.weight.scale.ndim, 1)

     @common_utils.parametrize(
         "config",
@@ -162,7 +159,7 @@ def test_per_row_scale_shape(self, dtype, config):
     @common_utils.parametrize("device", ["cpu", "cuda"])
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
     def test_slice(self, config, device, dtype):
-        """Test tensor slicing"""
+        """Test tensor slicing with per-row quantization"""
         tensor_size = 256
         slice_sizes = (64, 128)

@@ -176,19 +173,8 @@ def test_slice(self, config, device, dtype):

         self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, slice_sizes[0]))
         self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, slice_sizes[1]))
-
-        # Int8DynamicActivationInt8WeightConfig uses per-row (PerRow)
-        # Int8WeightOnlyConfig uses per-tensor (PerTensor)
-        if isinstance(config, Int8DynamicActivationInt8WeightConfig):
-            # PerRow: dim 0 slicing affects scale, dim 1 doesn't
-            self.assertEqual(
-                weight1.scale, dummy.weight.scale.narrow(0, 0, slice_sizes[0])
-            )
-            self.assertEqual(weight2.scale, dummy.weight.scale)
-        else:
-            # PerTensor: scale unchanged by slicing
-            self.assertEqual(weight1.scale, dummy.weight.scale)
-            self.assertEqual(weight2.scale, dummy.weight.scale)
+        self.assertEqual(weight1.scale, dummy.weight.scale.narrow(0, 0, slice_sizes[0]))
+        self.assertEqual(weight2.scale, dummy.weight.scale)
         with self.assertRaises(NotImplementedError):
             _ = dummy.weight[::2]

@@ -230,13 +216,15 @@ def test_dequantization_accuracy(self, config):
         tensor = linear.weight
         dequantized = tensor.dequantize()
         self.assertEqual(dequantized.shape, test_data.shape)
-        self.assertLess(
-            torch.abs(dequantized - test_data).max().item(),
-            0.1,
-            msg=f"Dequantization error exceeds tolerance of {0.1}",
+        assert compute_error(dequantized, test_data) > 20, (
+            f"Dequantization error is too high to get a SQNR of {compute_error(dequantized, test_data)}"
         )

-    def test_available_gpu_kernels(self):
+    @common_utils.parametrize(
+        "kernel",
+        ["triton_per_fused", "extern_kernels._int_mm", "triton_poi_fused"],
+    )
+    def test_available_gpu_kernels(self, kernel):
         """Check which GPU kernels are available"""
         M, K, N = 128, 256, 512
         m = torch.nn.Sequential(
@@ -248,14 +236,7 @@ def test_available_gpu_kernels(self):
         x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)

         out, code = run_and_get_code(m, x)
-        has_triton = "triton" in code[0].lower()  # Trition
-        has_fbgemm = "fbgemm" in code[0].lower()  # FB-GEMM
-        has_int_mm = "_int_mm" in code[0]  # Int8 MatMul
-
-        self.assertTrue(
-            has_triton or has_fbgemm or has_int_mm,
-            f"No int8 quantization kernels found. has_triton={has_triton}, has_fbgemm={has_fbgemm}, has_int_mm={has_int_mm}",
-        )
+        FileCheck().check(kernel).run(code[0])


 if __name__ == "__main__":
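Note: the diff above swaps in two test mechanisms: the dequantization check now asserts a minimum SQNR via torchao's compute_error, and the kernel check matches a pattern against Inductor's generated code with torch.testing.FileCheck. A minimal sketch of both patterns, assuming compute_error returns SQNR in dB (the generated_code string below is a made-up stand-in for run_and_get_code output):

    import torch
    from torch.testing import FileCheck

    def sqnr_db(ref: torch.Tensor, approx: torch.Tensor) -> torch.Tensor:
        # Signal-to-quantization-noise ratio in dB; higher means a closer match.
        return 20 * torch.log10(torch.linalg.norm(ref) / torch.linalg.norm(ref - approx))

    x = torch.randn(256, 256)
    x_hat = x + 1e-2 * torch.randn_like(x)   # small reconstruction error
    assert sqnr_db(x, x_hat) > 20            # 20 dB means error energy is at most ~1% of signal energy

    # FileCheck asserts that each .check() pattern appears, in order, in the string passed to .run().
    generated_code = "buf0 = extern_kernels._int_mm(arg0, arg1)"  # stand-in for run_and_get_code output
    FileCheck().check("extern_kernels._int_mm").run(generated_code)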

torchao/quantization/quant_api.py

Lines changed: 4 additions & 1 deletion

@@ -1387,7 +1387,10 @@ def _int8_weight_only_quantize_tensor(weight, config):
         )
     else:
         assert config.version == 2, f"Unexpected version: {config.version}"
-        block_size = [weight.shape[0], weight.shape[1]]
+        group_size = config.group_size
+        if group_size is None:
+            group_size = weight.shape[-1]
+        block_size = tuple([1 for x in range(weight.dim() - 1)] + [group_size])
         new_weight = Int8Tensor.from_hp(weight, block_size=block_size)
         return new_weight

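Note: this is the "default granularity" part of the commit: with group_size=None, the weight-only block_size now covers one row of the last dimension instead of the whole tensor, i.e. per-row rather than per-tensor scales. A minimal sketch of the shape arithmetic (the helper below only mirrors the expression in the diff; it is not torchao code):

    import torch

    def default_block_size(weight: torch.Tensor, group_size=None):
        # Mirrors the new expression in quant_api.py: every leading dim gets block 1,
        # the last dim gets group_size (defaulting to the full row width).
        if group_size is None:
            group_size = weight.shape[-1]
        return tuple([1 for _ in range(weight.dim() - 1)] + [group_size])

    w = torch.randn(512, 256)                        # [N, K] linear weight
    print(default_block_size(w))                     # (1, 256): per-row, one scale per output row
    print(default_block_size(w, group_size=64))      # (1, 64): groupwise along K
    # The old default, [weight.shape[0], weight.shape[1]] == (512, 256), was a single per-tensor block.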
