Skip to content

Commit b516304

Browse files
committed
merge test cases with cleanup
1 parent 3ab38ba commit b516304

File tree

1 file changed

+7
-28
lines changed

1 file changed

+7
-28
lines changed

test/quantization/quantize_/workflows/int8/test_int8_tensor.py

Lines changed: 7 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,6 @@ def setUp(self):
5959
self.batch_size = 32
6060

6161
torch.manual_seed(42)
62-
self.weight_fp = torch.randn(*self.test_shape, dtype=self.dtype)
63-
self.input_fp = torch.randn(*self.test_shape, dtype=self.dtype)
64-
self.bias = torch.randn(self.test_shape[0], dtype=self.dtype)
65-
self.block_size = list(self.test_shape)
6662

6763
@common_utils.parametrize(
6864
"config",
@@ -80,16 +76,13 @@ def test_creation_and_attributes(self, config):
8076
dtype=self.dtype,
8177
device="cuda",
8278
)
83-
linear.weight.data = self.weight_fp.cuda()
8479
quantize_(linear, config)
8580

86-
tensor = linear.weight
81+
w = linear.weight
8782

88-
self.assertEqual(tensor.shape, self.test_shape)
89-
self.assertEqual(tensor.qdata.dtype, torch.int8)
90-
self.assertTrue(
91-
torch.all(tensor.qdata >= -128) and torch.all(tensor.qdata <= 127)
92-
)
83+
self.assertEqual(w.shape, self.test_shape)
84+
self.assertEqual(w.qdata.dtype, torch.int8)
85+
self.assertTrue(torch.all(w.qdata >= -128) and torch.all(w.qdata <= 127))
9386

9487
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
9588
@common_utils.parametrize("compile", [True, False])
@@ -122,6 +115,9 @@ def test_int8_linear_variants(
122115

123116
quantize_(model_q, config)
124117

118+
self.assertEqual(model_q.linear2.weight.scale.shape, (K,))
119+
self.assertEqual(model_q.linear2.weight.scale.ndim, 1)
120+
125121
if compile:
126122
model_q = torch.compile(model_q, fullgraph=True)
127123

@@ -132,23 +128,6 @@ def test_int8_linear_variants(
132128
f"Quantization error is too high got a SQNR of {compute_error(output_fp, output_quantized)}"
133129
)
134130

135-
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
136-
@common_utils.parametrize(
137-
"config",
138-
[
139-
Int8DynamicActivationInt8WeightConfig(version=2),
140-
Int8WeightOnlyConfig(version=2),
141-
],
142-
)
143-
def test_per_row_scale_shape(self, dtype, config):
144-
"""Test per-row quantization maintains 1D scale"""
145-
N, K = 64, 128
146-
linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device="cuda")
147-
quantize_(linear, config)
148-
149-
self.assertEqual(linear.weight.scale.shape, (N,))
150-
self.assertEqual(linear.weight.scale.ndim, 1)
151-
152131
@common_utils.parametrize(
153132
"config",
154133
[

0 commit comments

Comments (0)