@@ -58,20 +58,28 @@ def forward(self, x):
5858class TestInt8Tensor (TorchAOIntegrationTestCase ):
def setUp(self):
    """Create the shared fixtures used by the tests in this class.

    Stores the common shape, dtype, batch size and int8 value bounds on
    ``self``, then seeds the RNG and draws the floating-point weight,
    input and bias tensors plus the matching ``block_size`` list.
    """
    super().setUp()

    self.test_shape = (4, 3)
    self.dtype = torch.bfloat16
    self.batch_size = 32

    # Take the representable int8 range from torch rather than hand-typing
    # -128 / 127, so the bounds cannot drift from the target dtype.
    int8_info = torch.iinfo(torch.int8)
    self.int8_min = int8_info.min
    self.int8_max = int8_info.max

    # Seed before drawing so every test sees the same random fixtures.
    torch.manual_seed(42)
    self.weight_fp = torch.randn(*self.test_shape, dtype=self.dtype)
    self.input_fp = torch.randn(*self.test_shape, dtype=self.dtype)
    self.bias = torch.randn(self.test_shape[0], dtype=self.dtype)
    # block_size has one entry per dimension of test_shape.
    self.block_size = list(self.test_shape)
6673
def test_creation_and_attributes(self):
    """Test tensor creation, dtypes, and ranges.

    Builds an ``Int8Tensor`` from the fixture weight via ``from_hp`` and
    checks its shape, the dtype of ``qdata``, and that every quantized
    value lies within ``[self.int8_min, self.int8_max]``.
    """
    tensor = Int8Tensor.from_hp(self.weight_fp, self.block_size)

    self.assertEqual(tensor.shape, self.test_shape)
    self.assertEqual(tensor.qdata.dtype, torch.int8)

    # Compare the observed extremes instead of assertTrue(torch.all(...)):
    # a failure then reports the offending value, not "False is not true".
    qdata = tensor.qdata
    self.assertGreaterEqual(qdata.min().item(), self.int8_min)
    self.assertLessEqual(qdata.max().item(), self.int8_max)
7684
7785 @common_utils .parametrize ("dtype" , [torch .bfloat16 , torch .float16 ])
@@ -88,12 +96,13 @@ def test_creation_and_attributes(self):
8896 Int8WeightOnlyConfig (version = 2 ),
8997 ],
9098 )
91- def test_int8_linear_variants (
99+ def test_int8_linear_quantization_accuracy (
92100 self ,
93101 dtype : torch .dtype ,
94102 sizes : tuple ,
95103 config ,
96104 ):
105+ """Test quantization preserves reasonable accuracy"""
97106 M , N , K = sizes
98107 input_tensor = torch .randn (* M , K , dtype = dtype , device = "cuda" )
99108
@@ -108,14 +117,16 @@ def test_int8_linear_variants(
108117 output_quantized = m_q (input_tensor )
109118
110119 error = compute_error (output_original , output_quantized )
111- assert error > 20 , f"Quantization error is too high got a SQNR of { error } "
120+ assert error > 20 , (
121+ f"Quantization quality is too low, SQNR: { error } dB (expected > { 20 } dB)"
122+ )
112123
113124 @common_utils .parametrize ("dtype" , [torch .bfloat16 , torch .float16 ])
114- def test_static_dynamic_quantization (self , dtype ):
115- """Test static and dynamic quantization"""
125+ def test_quantization_shapes (self , dtype ):
126+ """Test static and dynamic quantization output shapes """
116127 K , N = 128 , 64
117128 weight = torch .randn (N , K , dtype = dtype , device = "cuda" )
118- input_tensor = torch .randn (32 , K , dtype = dtype , device = "cuda" )
129+ input_tensor = torch .randn (self . batch_size , K , dtype = dtype , device = "cuda" )
119130
120131 # Dynamic quantization (runtime scale computation)
121132 dynamic_tensor = Int8Tensor .from_hp (weight , block_size = [N , K ])
@@ -126,8 +137,8 @@ def test_static_dynamic_quantization(self, dtype):
126137 mapping_type = MappingType .SYMMETRIC ,
127138 block_size = (input_tensor .shape [0 ], K ),
128139 target_dtype = torch .int8 ,
129- quant_min = - 128 ,
130- quant_max = 127 ,
140+ quant_min = self . int8_min ,
141+ quant_max = self . int8_max ,
131142 scale_dtype = dtype ,
132143 zero_point_dtype = torch .int8 ,
133144 )
@@ -145,25 +156,29 @@ def test_static_dynamic_quantization(self, dtype):
145156 dynamic_output = torch .nn .functional .linear (input_tensor , dynamic_tensor )
146157 static_output = torch .nn .functional .linear (input_tensor , static_tensor )
147158
148- self .assertEqual (dynamic_output .shape , (32 , N ))
149- self .assertEqual (static_output .shape , (32 , N ))
159+ expected_shape = (self .batch_size , N )
160+ self .assertEqual (dynamic_output .shape , expected_shape )
161+ self .assertEqual (static_output .shape , expected_shape )
150162 self .assertEqual (dynamic_output .dtype , dtype )
151163 self .assertEqual (static_output .dtype , dtype )
152164
153165 @unittest .skip ("granularity parameter not supported in current API" )
154166 @common_utils .parametrize ("granularity" , [PerTensor (), PerRow ()])
155167 def test_slice_preserves_aliasing (self , granularity ):
168+ slice_size = 512
169+ tensor_size = 1024
170+
156171 config = Int8DynamicActivationInt8WeightConfig (
157172 granularity = granularity , version = 2
158173 )
159- l = torch .nn .Linear (1024 , 1024 ).to ("cuda" ).to (torch .bfloat16 )
174+ l = torch .nn .Linear (tensor_size , tensor_size ).to ("cuda" ).to (torch .bfloat16 )
160175 l .weight = torch .nn .Parameter (
161- torch .zeros (1024 , 1024 , dtype = torch .bfloat16 , device = "cuda" )
176+ torch .zeros (tensor_size , tensor_size , dtype = torch .bfloat16 , device = "cuda" )
162177 )
163178 quantize_ (l , config )
164179 param = l .weight
165180 param_data = param .data
166- param_data = param_data .narrow (0 , 0 , 512 )
181+ param_data = param_data .narrow (0 , 0 , slice_size )
167182 # Making sure the aliasing is preserved in sliced quantized Tensor
168183 assert param .data .qdata .data_ptr () == param_data .qdata .data_ptr ()
169184 assert param .data .scale .data_ptr () == param_data .scale .data_ptr ()
@@ -179,20 +194,27 @@ def test_slice_preserves_aliasing(self, granularity):
179194 @common_utils .parametrize ("dtype" , [torch .bfloat16 , torch .float16 ])
180195 def test_slice (self , config , device , dtype ):
181196 """Test tensor slicing"""
182- dummy = torch .nn .Linear (256 , 256 , bias = False , dtype = dtype , device = device )
197+ tensor_size = 256
198+ slice_sizes = (64 , 128 )
199+
200+ dummy = torch .nn .Linear (
201+ tensor_size , tensor_size , bias = False , dtype = dtype , device = device
202+ )
183203 quantize_ (dummy , config )
184204
185- weight1 = dummy .weight .clone ().narrow (0 , 0 , 64 )
186- weight2 = dummy .weight .clone ().narrow (1 , 0 , 128 )
205+ weight1 = dummy .weight .clone ().narrow (0 , 0 , slice_sizes [ 0 ] )
206+ weight2 = dummy .weight .clone ().narrow (1 , 0 , slice_sizes [ 1 ] )
187207
188- self .assertEqual (weight1 .qdata , dummy .weight .qdata .narrow (0 , 0 , 64 ))
189- self .assertEqual (weight2 .qdata , dummy .weight .qdata .narrow (1 , 0 , 128 ))
208+ self .assertEqual (weight1 .qdata , dummy .weight .qdata .narrow (0 , 0 , slice_sizes [ 0 ] ))
209+ self .assertEqual (weight2 .qdata , dummy .weight .qdata .narrow (1 , 0 , slice_sizes [ 1 ] ))
190210
191211 # Int8DynamicActivationInt8WeightConfig uses per-row (PerRow)
192212 # Int8WeightOnlyConfig uses per-tensor (PerTensor)
193213 if isinstance (config , Int8DynamicActivationInt8WeightConfig ):
194214 # PerRow: dim 0 slicing affects scale, dim 1 doesn't
195- self .assertEqual (weight1 .scale , dummy .weight .scale .narrow (0 , 0 , 64 ))
215+ self .assertEqual (
216+ weight1 .scale , dummy .weight .scale .narrow (0 , 0 , slice_sizes [0 ])
217+ )
196218 self .assertEqual (weight2 .scale , dummy .weight .scale )
197219 else :
198220 # PerTensor: scale unchanged by slicing
@@ -211,20 +233,33 @@ def test_index_select(self):
211233 x_int8 .dequantize ()[0 ], x_int8_0 .dequantize (), atol = 0 , rtol = 0
212234 )
213235
def test_invalid_input_handling(self):
    """Test input validation with specific error types.

    Each case passes ``Int8Tensor.from_hp`` a tensor / block-size pair
    that is expected to be rejected with ``ValueError``.
    """
    # NOTE(review): the previous version of this test accepted
    # AssertionError, ValueError or RuntimeError. Confirm that from_hp
    # raises ValueError for both inputs before relying on the narrower
    # expectation.
    cases = (
        ("Should reject incompatible tensor dimensions", torch.randn(5), [1]),
        (
            "Should reject mismatched block size dimensions",
            self.weight_fp,
            [1],
        ),
    )

    # subTest keeps the cases independent: a failure in the first no
    # longer hides the result of the second.
    for reason, hp_tensor, block_size in cases:
        with self.subTest(reason=reason):
            with self.assertRaises(ValueError, msg=reason):
                Int8Tensor.from_hp(hp_tensor, block_size)
221250
def test_dequantization_accuracy(self):
    """Test dequantization accuracy separately.

    Quantizes a small ``[[1.0, -1.0]]`` bfloat16 tensor with a ``[1, 2]``
    block size, dequantizes it, and checks that the shape is preserved
    and the largest absolute element-wise error is below the tolerance.
    """
    tolerance = 0.1  # single source for both the check and its message

    test_data = torch.tensor([[1.0, -1.0]], dtype=torch.bfloat16)
    tensor = Int8Tensor.from_hp(test_data, [1, 2])

    dequantized = tensor.dequantize()
    self.assertEqual(dequantized.shape, test_data.shape)

    max_abs_error = torch.abs(dequantized - test_data).max().item()
    self.assertLess(
        max_abs_error,
        tolerance,
        msg=(
            f"Dequantization error exceeds tolerance of {tolerance} "
            f"(max abs error: {max_abs_error})"
        ),
    )
228263
229264
230265if __name__ == "__main__" :
0 commit comments